天天看點

(七)SN-GAN論文筆記與實戰

(七)SN-GAN論文筆記與實戰

        • 一、論文筆記
        • 二、完整代碼
        • 三、遇到的問題及解決

一、論文筆記

在WGAN-GP中使用gradient penalty 的方法來限制判别器,但這種放法隻能對生成資料分布與真實分布之間的分布空間的資料做梯度懲罰,無法對整個空間的資料做懲罰。這會導緻随着訓練的進行,生成資料分布與真實資料分布之間的空間會逐漸變化,進而導緻gradient penalty 正則化方式不穩定。此外,WGAN-GP涉及比較多的運算,是以訓練WGAN-GP的網絡也比較耗時。

SN-GAN提出使用Spectral Normalization(譜歸一化)的方法來讓判别器D滿足Lipschitz限制,簡單而言,SN-GAN隻需要改變判别器權值矩陣的最大奇異值,這種方法可以最大限度地儲存判别器權值矩陣的資訊,這個優勢可以讓SN-GAN使用類别較多的資料集作為訓練資料,依舊可以獲得比較好的生成效果。

從SN-GAN 論文中的實際效果來看,SN-GAN是目前僅有的可以使用單個生成器與判别器從ImageNet資料集(其中的圖像有非常多的類别)生成高品質圖像的GAN模型,WGAN、WGAN-GP等GAN模型在多類别圖像中無法生成高品質的圖像。其中一個可能的原因就是,在訓練過程中,WGAN、WGAN-GP等GAN模型喪失了較多的原始資訊。

簡單而言,SN-GAN具有如下優勢:

  1. 以Spectral Normalization 方法讓判别器D滿足Lipschitz限制,Lipschitz的常數K是唯一需要調整的超參數。
  2. 整體上SN-GAN隻改變判别器權值矩陣的最大奇異值,進而可以最大限度地保留原始資訊。
  3. 具體訓練模型時,使用power iteration(疊代法),加快訓練速度,可比WGAN-GP快許多。WGAN-GP慢的原因是使用gradient penalty後,模型在梯度下降的過程中相當于計算兩次梯度,計算量更大,是以整體訓練速度就變慢了。

判别器的目标就是對判别器的所有權重都做 W ‘ W^‘ W‘ = W W W / ∣ ∣ W ∣ ∣ 2 ||W||_2 ∣∣W∣∣2​ (譜歸一化) 證明見原論文

但是由于直接計算譜範數 ∣ ∣ W ∣ ∣ 2 ||W||_2 ∣∣W∣∣2​是比較耗時的,是以為了讓訓練模型速度更快,就需要使用一個技巧。power iteration(幂疊代)方法通過疊代計算的思想可以比較快的計算出譜範數的近似值。

因為譜範數 ∣ ∣ W ∣ ∣ 2 ||W||_2 ∣∣W∣∣2​等于 W T W^T WT W W W的最大特征根,是以要求解譜範數,就可以轉變為求 W T W^T WT W W W的最大特征根,使用power iteration的方法如下:

(七)SN-GAN論文筆記與實戰

二、完整代碼

代碼跑不出來效果,沒找到原因,不知道問題出在哪裡,希望後面随着學習的深入再來看看。

import torch
import torchvision
import torch.nn as nn
import argparse
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
import os
from torch.optim.lr_scheduler import LambdaLR
import random
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from torch.autograd import Variable
import numpy as np
parser = argparse.ArgumentParser()
parser.add_argument('--n_epochs', type = int, default = 200)
parser.add_argument('--batch_size', type = int, default = 64)
parser.add_argument('--lr', type = float, default = 0.0002)
parser.add_argument('--b1', type = float, default = 0.5)
parser.add_argument('--b2', type = float, default = 0.999)
parser.add_argument('--decay_epochs', type = int, default=100)
parser.add_argument('--z_dim', type = int, default=128, help = 'latent vector')

opt = parser.parse_args(args = [])
print(opt)
random.seed(22)
torch.manual_seed(22)
os.makedirs('Picture/SNGAN', exist_ok = True)
os.makedirs('runs/SNGAN', exist_ok = True)
os.makedirs('Model/SNGAN', exist_ok = True)
device = torch.device('cuda: 0' if torch.cuda.is_available() else 'cpu')

'''加載資料集'''
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))
])
# 50000張圖檔用作訓練集
train_set = torchvision.datasets.CIFAR10(root = '../dataset', train=True, transform=transform, download=False)
train_loader = DataLoader(train_set, batch_size=opt.batch_size, shuffle=True, num_workers=0)
print(train_set[0][0].shape)

'''自定義學習率類'''
class LambdaLR:
    def __init__(self, n_epochs, decay_epochs):
        self.n_epochs = n_epochs
        self.decay_epochs = decay_epochs
    def step(self, epoch):
        return 1.0 - max(0, (epoch - self.decay_epochs)/(self.n_epochs - self.decay_epochs))
    
'''Spectral Normalization -- 譜歸一化類'''
class SpectralNorm(nn.Module):
    def __init__(self, layer, name = 'weight', power_iterations = 1):
        super(SpectralNorm, self).__init__()
        '''params:
        layer: 傳入的需要使得參數譜歸一化的網路層
        name : 譜歸一化的參數
        power_iterations:幂疊代的次數,論文中提到,實際上疊代一次已足夠
        '''
        self.layer = layer
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params(): # 如果疊代參數未初始化,則初始化
            self._make_params()
            
    def _update_u_v(self):
        u = getattr(self.layer, self.name+'_u')
        v = getattr(self.layer, self.name+'_v')
        w = getattr(self.layer, self.name+'_bar')
        height = w.data.shape[0]
        for _ in range(self.power_iterations):
            v.data = self.l2Norm(torch.mv(torch.t(w.view(height, -1).data), u.data)) # 計算:v <- (W^t*u)/||W^t*u||   2範數
            u.data = self.l2Norm(torch.mv(w.view(height, -1).data, v.data)) # 計算:u <- (Wv)/||Wv||
        sigma = u.dot(w.view(height, -1).mv(v)) # 計算 W的譜範數 ≈ u^t * W * v
        setattr(self.layer, self.name, w/sigma.expand_as(w))
        
    def _made_params(self):
        # 存在這些參數則傳回True, 否則傳回False
        try:
            u = getattr(self.layer, self.name + '_u')
            v = getattr(self.layer, self.name + '_v')
            w = getattr(self.layer, self.name + '_bar')
            return True
        except AttributeError:
            return False
    def _make_params(self):
        w = getattr(self.layer, self.name)
        height = w.data.shape[0] # 輸出的卷積核的數目
        width = w.view(height, -1).data.shape[1] # width為 in_feature*kernel*kernel 的值
        # .new()建立一個新的Tensor,該Tensor的type和device都和原有Tensor一緻
        u = nn.Parameter(w.data.new(height).normal_(0, 1), requires_grad = False)
        v = nn.Parameter(w.data.new(width).normal_(0, 1), requires_grad = False)
        u.data = self.l2Norm(u.data)
        v.data = self.l2Norm(v.data)
        w_bar = nn.Parameter(w.data)
        del self.layer._parameters[self.name] # 删除以前的weight參數
        # 注冊參數
        self.layer.register_parameter(self.name+'_u', u) # 傳入的值u,v必須是Parameter類型
        self.layer.register_parameter(self.name+'_v', v)
        self.layer.register_parameter(self.name+'_bar', w_bar)
        
    def l2Norm(self, v, eps = 1e-12): # 用于計算例如:v/||v||
        return v/(v.norm() + eps) 
    
    def forward(self, *args):
        self._update_u_v()
        return self.layer.forward(*args)

'''網絡模型'''
# DCGAN-like generator and discriminator
class Generator(nn.Module):
    def __init__(self, z_dim):
        self.z_dim = z_dim
        super(Generator,self).__init__()
        self.model = nn.Sequential( # 輸入shape[b, z_dim, 1, 1]
            nn.ConvTranspose2d(z_dim, 512, 4, stride=1, bias=False), # --> [b, 512, 4, 4]
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1, bias=False), # --> [b, 256, 8, 8]
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1, bias=False), # --> [b, 128. 16, 16]
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1, bias=False), # --> [b, 64, 32, 32]
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(64, 3, 3, stride=1, padding=1, bias=False), # --> [b, 3, 32, 32]
            nn.Tanh()
        )
    def forward(self, z):
        return self.model(z)
    
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv1 = SpectralNorm(nn.Conv2d(3, 64, 3, 1, 1))
        self.conv2 = SpectralNorm(nn.Conv2d(64, 64, 4, 2, 1))
        self.conv3 = SpectralNorm(nn.Conv2d(64, 128, 3, 1, 1))
        self.conv4 = SpectralNorm(nn.Conv2d(128, 128, 4, 2, 1))
        self.conv5 = SpectralNorm(nn.Conv2d(128, 256, 3, 1, 1))
        self.conv6 = SpectralNorm(nn.Conv2d(256, 256, 4, 2, 1))
        self.conv7 = SpectralNorm(nn.Conv2d(256, 512, 3, 1, 1))
        self.fc = SpectralNorm(nn.Linear(4*4*512, 1))
        
    def forward(self, img):
        img = nn.LeakyReLU(0.1)(self.conv1(img))
        img = nn.LeakyReLU(0.1)(self.conv2(img))
        img = nn.LeakyReLU(0.1)(self.conv3(img))
        img = nn.LeakyReLU(0.1)(self.conv4(img))
        img = nn.LeakyReLU(0.1)(self.conv5(img))
        img = nn.LeakyReLU(0.1)(self.conv6(img))
        img = nn.LeakyReLU(0.1)(self.conv7(img))
        
        return self.fc(img.view(-1, 4*4*512))
    
generator = Generator(opt.z_dim).to(device)
discriminator = Discriminator().to(device)
print(generator)
print(discriminator)

test_noise = torch.randn(64, opt.z_dim, 1, 1, device = device)

# because the spectral normalization module creates parameters that don't require gradients (u and v), we don't want to 
# optimize these using sgd. We only let the optimizer operate on parameters that _do_ require gradients
# TODO: replace Parameters with buffers, which aren't returned from .parameters() method.
optim_G = torch.optim.Adam(generator.parameters(), lr = opt.lr, betas=(opt.b1, opt.b2))
optim_D = torch.optim.Adam(filter(lambda p : p.requires_grad, discriminator.parameters()), 
                           lr = opt.lr, betas=(opt.b1, opt.b2))
lr_schedual_G = torch.optim.lr_scheduler.LambdaLR(optim_G, lr_lambda=LambdaLR(opt.n_epochs, opt.decay_epochs).step)
lr_schedual_D = torch.optim.lr_scheduler.LambdaLR(optim_D, lr_lambda=LambdaLR(opt.n_epochs, opt.decay_epochs).step)

'''訓練'''
writer = SummaryWriter('runs/SNGAN')
for epoch in range(opt.n_epochs):
    for i, (imgs, labels) in enumerate(train_loader):
        ############################
        #     discriminator
        ###########################
        b_size = imgs.size(0)
        optim_D.zero_grad()
        z = torch.randn(b_size, opt.z_dim, 1, 1, device = device)
        real_imgs = imgs.to(device)
        fake_imgs = generator(z).detach()

        loss_D = torch.mean(discriminator(fake_imgs)) - torch.mean(discriminator(real_imgs))
        loss_D.backward()
        optim_D.step()
        
        ############################
        #      generator
        ###########################
        if i % 5 == 0:
            optim_G.zero_grad()
            fake_imgs = generator(z)
            loss_G = -torch.mean(discriminator(fake_imgs))
            loss_G.backward()
            optim_G.step()

            print('[Epoch {}/{}] [step {}/{}] [D_loss {}] [G_loss {}]'.format(epoch, opt.n_epochs, 
                            i, len(train_loader), loss_D, loss_G))
        writer.add_scalar('D_loss', loss_D, epoch)
        writer.add_scalar('G_loss', loss_G, epoch)
    
    lr_schedual_D.step()
    lr_schedual_G.step()
    
    with torch.no_grad():
        gen_imgs = generator(test_noise)
        torchvision.utils.save_image(gen_imgs.data, 'Picture/SNGAN/generator_{}.png'.format(epoch), nrow=8, normalize=True)

           

三、遇到的問題及解決

一、Python的hasattr() getattr() setattr() 函數使用方法詳解

二、Pytorch中.new()的作用詳解

三、檢視模型的層和參數資訊的幾種方式

四、pytorch中tensor.expand()和tensor.expand_as()函數解讀

五、pytorch中的register_parameter()和parameter()

六、為什麼spectral norm對應的SNGAN未使用WGAN的loss?

七、[

繼續閱讀