天天看點

生成式對抗網絡(GANs)綜述

GAN

GAN簡介

生成式對抗網絡(Generative adversarial networks,GANs)的核心思想源自于零和博弈,包括生成器和判别器兩個部分。生成器接收随機變量并生成“假”樣本,判别器則用于判斷輸入的樣本是真實的還是合成的。兩者通過互相對抗來獲得彼此性能的提升。判别器所作的其實就是一個二分類任務,我們可以計算他的損失并進行反向傳播求出梯度,進而進行參數更新。

生成式對抗網絡(GANs)綜述

GAN的優化目标可以寫作:

min ⁡ G max ⁡ D V ( D , G ) = E x ∼ p d a t a [ log ⁡ D ( x ) ] + E z ∼ p z ( z ) [ l o g ( 1 − D ( G ( z ) ) ) ] \large {\min_G\max_DV(D,G)= \mathbb{E}_{x\sim p_{data}}[\log D(x)]+\mathbb{E}_{z\sim p_z(z)}[log(1-D(G(z)))]} Gmin​Dmax​V(D,G)=Ex∼pdata​​[logD(x)]+Ez∼pz​(z)​[log(1−D(G(z)))]

其中 log ⁡ D ( x ) \log D(x) logD(x)代表了判别器鑒别真實樣本的能力,而 D ( G ( z ) ) D(G(z)) D(G(z))則代表了生成器欺騙判别器的能力。在實際的訓練中,生成器和判别器采取交替訓練,即先訓練D,然後訓練G,不斷往複。

WGAN

在上一部分我們給出了GAN的優化目标,這個目标的本質是在最小化生成樣本與真實樣本之間的JS距離。但是在實驗中發現,GAN的訓練非常的不穩定,經常會陷入坍縮模式。這是因為,在高維空間中,并不是每個點都可以表示一個樣本,而是存在着大量不代表真實資訊的無用空間。當兩個分布沒有重疊時,JS距離不能準确的提供兩個分布之間的差異。這樣的生成器,很難“捕捉”到低維空間中的真實資料分布。是以,WGAN(Wasserstein GAN)的作者提出了Wasserstein距離(推土機距離)的概念,其公式可以進行如下表示:

W ( P r , P g ) = inf ⁡ γ ∼ ∏ P r , P g E ( x , y )   γ [ ∥ x − y ∥ ] W(\mathbb P_r,\mathbb P_g)=\inf_{\gamma\sim\prod{\mathbb P_r,\mathbb P_g}}\mathbb E_{(x,y)~\gamma}[\|x-y\|] W(Pr​,Pg​)=γ∼∏Pr​,Pg​inf​E(x,y) γ​[∥x−y∥]

這裡 ∏ P r , P g \prod{\mathbb P_r,\mathbb P_g} ∏Pr​,Pg​指的是真實分布 P r \mathbb P_r Pr​和生成分布 P g \mathbb P_g Pg​的聯合分布所構成的集合, ( x , y ) (x,y) (x,y)是從 γ \gamma γ中取得的一個樣本。枚舉兩者之間所有可能的聯合分布,計算其中樣本間的距離 ∥ x − y ∥ \|x-y\| ∥x−y∥,并取其期望。而Wasserstein距離就是兩個分布樣本距離期望的下界值。這個簡單的改進,使得生成樣本在任意位置下都能給生成器帶來合适的梯度,進而對參數進行優化。

DCGAN

卷積神經網絡近年來取得了耀眼的成績,展現了其在圖像處理領域獨特的優勢。很自然的會想到,如果将卷積神經網絡引入GAN中,是否可以帶來效果上的提升呢?DCGAN(Deep Convolutional GANs)在GAN的基礎上優化了網絡結構,用完全的卷積替代了全連接配接層,去掉池化層,并采用批标準化(Batch Normalization,BN)等技術,使得網絡更容易訓練。

生成式對抗網絡(GANs)綜述

用DCGAN生成圖像

為了更友善準确的說明DCGAN的關鍵環節,這裡用一個簡化版的模型執行個體來說明。代碼基于pytorch深度學習架構,資料集采用MNIST

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torchvision.utils import save_image
import os 
#定義一些超參數
nc = 1    				#輸入圖像的通道數
nz = 100   				#輸入噪聲的次元
num_epochs = 200		#疊代次數
batch_size = 64			#批量大小
sample_dir = 'gan_samples'
# 結果的儲存目錄
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)
# 加載MNIST資料集
trans = transforms.Compose([
                transforms.ToTensor(),transforms.Normalize([0.5], [0.5])])
mnist = torchvision.datasets.MNIST(root=r'G:\VsCode\ml\mnist',
                                   train=True,
                                   transform=trans,
                                   download=False)
data_loader = torch.utils.data.DataLoader(dataset=mnist,
                                          batch_size=batch_size, 
                                          shuffle=True)
           

判别器&生成器

判别器使用LeakyReLU作為激活函數,最後經過Sigmoid輸出,用于真假二分類

生成器使用ReLU作為激活函數,最後經過tanh将輸出映射在 [ − 1 , 1 ] [-1,1] [−1,1]之間

# 建構判别器
class Discriminator(nn.Module):
    def __init__(self, in_channel=1, num_classes=1):
        super(Discriminator, self).__init__()
        self.conv = nn.Sequential(
            # 28 -> 14
            nn.Conv2d(nc, 512, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
            # 14 -> 7
            nn.Conv2d(512, 256, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            # 7 -> 4
            nn.Conv2d(256, 128, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.AvgPool2d(4),
        )
        self.fc = nn.Sequential(
            # reshape input, 128 -> 1
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )
    
    def forward(self, x, label=None):
        y_ = self.conv(x)
        y_ = y_.view(y_.size(0), -1)
        y_ = self.fc(y_)
        return y_

# 建構生成器
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(nz, 4*4*512),
            nn.ReLU(),
        )
        self.conv = nn.Sequential(
            # input: 4 by 4, output: 7 by 7
            nn.ConvTranspose2d(512, 256, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            # input: 7 by 7, output: 14 by 14
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            # input: 14 by 14, output: 28 by 28
            nn.ConvTranspose2d(128, 1, 4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        )
        
    def forward(self, x, label=None):
        x = x.view(x.size(0), -1)
        y_ = self.fc(x)
        y_ = y_.view(y_.size(0), 512, 4, 4)
        y_ = self.conv(y_)
        return y_
           

訓練模型

# 使用GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
D = Discriminator().to(device)
G = Generator().to(device)
# 損失函數及優化器
criterion = nn.BCELoss()
D_opt = torch.optim.Adam(D.parameters(), lr=0.001, betas=(0.5, 0.999))
G_opt = torch.optim.Adam(G.parameters(), lr=0.001, betas=(0.5, 0.999))

def denorm(x):
    out = (x + 1) / 2
    return out.clamp(0, 1)

def reset_grad():
    d_optimizer.zero_grad()
    g_optimizer.zero_grad()

for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(data_loader):
        images = images.to(device)
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)
        #————————————————————訓練判别器——————————————————————
        #鑒别真實樣本
        outputs = D(images)
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs
        #鑒别生成樣本
        z = torch.randn(batch_size, nz).to(device)
        fake_images = G(z)
        outputs = D(fake_images)
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs        
     	#計算梯度及更新
        d_loss = d_loss_real + d_loss_fake      
        reset_grad()
        d_loss.backward()
        d_optimizer.step()

        #————————————————————訓練生成器——————————————————————
        z = torch.randn(batch_size, nz).to(device)
        fake_images = G(z)
        outputs = D(fake_images)
        g_loss = criterion(outputs, real_labels)
        #計算梯度及更新
        reset_grad()
        g_loss.backward()
        g_optimizer.step()
        
        if (i+1) % 200 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}' 
                  .format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(), 
                          real_score.mean().item(), fake_score.mean().item()))
    # 儲存生成圖檔
    fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
    save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))
# 儲存模型
torch.save(G.state_dict(), 'G.ckpt')
torch.save(D.state_dict(), 'D.ckpt')
           

可視化結果

reconsPath = './gan_samples/fake_images-200.png'
Image = mpimg.imread(reconsPath)
plt.imshow(Image)
plt.axis('off')
plt.show()
           

cGAN

在之前介紹的幾種模型中,我們注意到生成器的輸入都是一個随機的噪聲。可以認為這個高維噪聲向量提供了一些關鍵資訊,而生成器根據自己的了解将這些資訊進行補充,最終生成需要的圖像。生成器生成圖檔的過程是完全随機的。例如上述的MNIST資料集,我們不能控制它生成的是哪一個數字。那麼,有沒有方法可以對其做一定的限制限制,來讓生成器生成我們想要的結果呢?cGAN(Conditional Generative Adversarial Nets)通過增一個額外的向量y對生成器進行限制。以MNIST分類為例,限制資訊y可以取10維的向量,對于類别進行one-hot編碼,并與噪聲進行拼接一起輸入生成器。同樣的,判别器也将原來的輸入和y進行拼接。作者通過各種實驗證明了這個簡單的改進确實可以起到對生成器的限制作用。

生成式對抗網絡(GANs)綜述

判别器&生成器

隻需要在前向傳播的過程中加入限制變量y,我們很容易就能得到cGAN的生成器和判别器模型

class Discriminator(nn.Module):
    ...
    def forward(self, x, label):
        label = label.unsqueeze(2).unsqueeze(3)
        label = label.repeat(1, 1, x.size(2), x.size(3))
        x = torch.cat(tensors=(x, label), dim=1)
        y_ = self.conv(x)
        ...
class Generator(nn.Module):
    ...
    def forward(self, x, label):
        x = x.unsqueeze(2).unsqueeze(3)
        label = label.unsqueeze(2).unsqueeze(3)
        x = torch.cat(tensors=(x, label), dim=1)
        y_ = self.fc(x)
        ...
           

Pix2Pix

在上面的cGAN例子中,我們的控制資訊取的是想要圖像的标簽,如果這個控制資訊更加的豐富,例如輸入一整張圖像,那麼它能否完成一些更加進階的任務?Pix2Pix(Image-to-Image Translation with Conditional Adversarial Networks)将這一類問題歸納為圖像到圖像的翻譯,其使用改進後的U-net作為生成器,并設計了新穎的Patch-D判别器結構來輸出高清的圖像。Patch-D是指,不管網絡所使用的輸入圖像有多大,都将其切割成若幹個固定大小的Patch,判别器隻需對這些Patch的真假進行判斷。因為L1損失已經可以衡量生成圖像和真實圖像的全局差異,是以作者認為判别器隻需要用Patch-D這樣更關注于局部差異的結構即可。同時Patch-D的結構使得網絡的輸入變小,減少了計算量并且增大了架構的擴充性。

生成式對抗網絡(GANs)綜述

CycleGAN

Pix2Pix雖然可以生成高清的圖像,但其存在一個緻命的缺點:需要互相配對的圖檔x與y。在現實生活中,這樣成對的圖檔很難或者根本不可能搜集到,這就大大的限制了Pix2Pix的應用。對此,CycleGAN(Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks)提出了不需要配對的圖像翻譯方法。

生成式對抗網絡(GANs)綜述

CycleGAN其實就是一個X->Y的單向GAN上再加一個Y->X的單向GAN,構成一個“循環”。網絡的結構和單次訓練過程如下(圖檔來自于量子位):

生成式對抗網絡(GANs)綜述
生成式對抗網絡(GANs)綜述

除了經典的基礎GAN損失之外,CycleGAN還引入了Consistency loss的概念。循環一緻損失使得X->Y轉變的過程中必須保留有X的部分特性。循環損失的公式如下:

L c y c ( G , F ) = E x ∼ p d a t a ( x ) [ ∥ F ( G ( x ) ) − x ∥ 1 ] + E y ∼ p d a t a ( y ) [ ∥ G ( F ( x ) ) − y ∥ 1 ] L_{cyc}(G,F)=\mathbb E_{x\sim p_{data}(x)}[\|F(G(x))-x\|_1]+\mathbb E_{y\sim p_{data}(y)}[\|G(F(x))-y\|_1] Lcyc​(G,F)=Ex∼pdata​(x)​[∥F(G(x))−x∥1​]+Ey∼pdata​(y)​[∥G(F(x))−y∥1​]

兩個判别器的損失表示如下:

L G A N ( G , D Y , X , Y ) = E y ∼ p d a t a ( y ) [ l o g D Y ( y ) ] + E x ∼ p d a t a ( x ) [ l o g ( 1 − D Y ( G ( x ) ) ) ] \textit{L}_{GAN}(G,D_Y,X,Y)=\mathbb E_{y\sim p_{data}(y)}[logD_Y(y)]+\mathbb E_{x\sim p_{data}(x)}[log(1-D_Y(G(x)))] LGAN​(G,DY​,X,Y)=Ey∼pdata​(y)​[logDY​(y)]+Ex∼pdata​(x)​[log(1−DY​(G(x)))]

L G A N ( F , D X , Y , X ) = E x ∼ p d a t a ( x ) [ l o g D X ( x ) ] + E y ∼ p d a t a ( y ) [ l o g ( 1 − D X ( F ( y ) ) ) ] \textit{L}_{GAN}(F,D_X,Y,X)=\mathbb E_{x\sim p_{data}(x)}[logD_X(x)]+\mathbb E_{y\sim p_{data}(y)}[log(1-D_X(F(y)))] LGAN​(F,DX​,Y,X)=Ex∼pdata​(x)​[logDX​(x)]+Ey∼pdata​(y)​[log(1−DX​(F(y)))]

最後網絡的優化目标可以表示為

min ⁡ G X → Y , G Y → X max ⁡ D X , D Y L ( G , F , D x , D y ) = L G A N ( G , D Y , X , Y ) + L G A N ( F , D X , Y , X ) + λ L c y c ( G , F ) \min _{G_{X\rightarrow Y},G_{Y\rightarrow X}}\max_{D_X,D_Y} L(G,F,D_x,D_y)=L_{GAN}(G,D_Y,X,Y)+L_{GAN}(F,D_X,Y,X)+\lambda L_{cyc}(G,F) GX→Y​,GY→X​min​DX​,DY​max​L(G,F,Dx​,Dy​)=LGAN​(G,DY​,X,Y)+LGAN​(F,DX​,Y,X)+λLcyc​(G,F)

Pix2Pix以及CycleGAN的官方複現入口:https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix

StarGAN

Pix2Pix解決了有配對圖像的翻譯問題,CycleGAN解決了無配對圖像的翻譯問題,然而他們所作的圖像到圖像翻譯,都是一對一。假設現在需要将人臉轉換為喜怒哀樂四個表情,那麼他們就需要進行4次不同的訓練,這無疑會耗費巨大的計算資源。針對于這個問題,StarGAN(StarGAN: Unified Generative Adversarial Networks for Multi-Domain Image-to-Image Translation)借助cGAN的思想,在網絡輸入中加入一個域的控制資訊。對于判别器,其不僅需要鑒别樣本是否真實,還需要判斷輸入的圖像來自哪個域。StarGAN的訓練過程如下:

  1. 将原始圖檔 c c c和目标生成域 c c c進行拼接後丢入生成器得到生成圖像 G ( x , c ) G(x,c) G(x,c)
  2. 将生成圖像 G ( x , c ) G(x,c) G(x,c)和真實圖像 y y y分别丢入判别器D,判别器除了需要判斷輸入圖像的真僞之外,還需要判斷它來自哪個域
  3. 将生成圖像 G ( x , c ) G(x,c) G(x,c)和原始生成域 c ′ c' c′丢入生成器生成重構圖檔(為了對生成器生成的圖像做進一步的限制,與CycleGAN的重構損失類似)
生成式對抗網絡(GANs)綜述

了解了StarGAN的訓練過程,我們很容易得到其損失函數各項的表達形式

首先是GAN的一般損失,這裡作者采用了前文所述的WGAN的損失形式:

KaTeX parse error: Got function '\hat' with no arguments as subscript at position 110: … x}[(\|\nabla _\̲h̲a̲t̲ ̲xD_{src}(\hat x…

對于判别器,我們需要鼓勵其将輸入圖像正确的分類到目标域c‘(原始生成域):

L s r c r = E x , c ′ [ − l o g D c l s ( c ′ ∣ x ) ] L_{src}^r=\mathbb E_{x,c'}[-logD_{cls}(c'|x)] Lsrcr​=Ex,c′​[−logDcls​(c′∣x)]

對于生成器,我們需要鼓勵其成功欺騙判别器将圖檔分類到目标域c(目标生成域),此外,生成器還需要在以生成圖像和原始生成域c’的輸入下成功将圖像還原回去,這兩部分的損失表示如下:

L s r c f = R x , c [ − l o g D c l s ( c ∣ G ( x , c ) ) ] L_{src}^f=\mathbb R_{x,c}[-logD_{cls}(c|G(x,c))] Lsrcf​=Rx,c​[−logDcls​(c∣G(x,c))]

L r e c = E x , c , c ′ [ ∥ x − G ( G ( x , c ) , c ′ ) ∥ 1 ] L_{rec}=\mathbb E_{x,c,c'}[\|x-G(G(x,c),c')\|_1] Lrec​=Ex,c,c′​[∥x−G(G(x,c),c′)∥1​]

各部分損失乘上自己的權重加總後就構成了判别器和生成器的總損失:

L D = − L a d v + λ c l s L c l s r L_D=-L_{adv}+\lambda_{cls}L_{cls}^{r} LD​=−Ladv​+λcls​Lclsr​

L G = L a d v + λ c l s L c l a s f + λ r e c L r e c L_G=L_{adv}+\lambda_{cls}L_{clas}^f+\lambda_{rec}L_{rec} LG​=Ladv​+λcls​Lclasf​+λrec​Lrec​

此外,為了更具備通用性,作者還加入了mask vector來适應不同的資料集之間的訓練。

總結

名稱 創新點
DCGAN 首次将卷積神經網絡引入GAN中
cGAN 通過拼接标簽資訊來控制生成器的輸出
Pix2Pix 提出了一種圖像到圖像翻譯的通用方法
CycleGAN 解決了Pix2Pix需要圖像配對的問題
StarGAN 提出了一種一對多的圖像到圖像的翻譯方法
InfoGAN 基于cGAN改進,提出一種無監督的生成方法,适用于不知道圖像标簽的情況
LSGAN 用最小二乘損失函數代替原始GAN的損失函數,緩解了訓練不穩定、生成圖像缺乏多樣性的問題
ProGAN 在訓練期間逐漸添加新的高分辨率層,可以生成高分辨率的圖像
SAGAN 将注意力機制引入GAN當中,簡約高效利用了全局資訊

本文列舉了生成式對抗網絡在發展過程中一些具有代表性的網絡結構。GANs如今已廣泛應用于圖像生成、圖像去噪、超分辨、文本到圖像的翻譯等各個領域,且在近幾年的研究中湧現了很多優秀的論文。感興趣的同學可以從下面的連結中pick自己想要了解的GAN~

  • THE-GAN-ZOO:彙總了各種GAN的論文及代碼位址。
  • GAN Timeline:按照時間線對不同的GAN進行了排序。
  • Browse state-of-the-art:将ArXiv上的最新論文與GitHub代碼相關聯,并做了比較排序,涉及了深度學習的各個方面。

參考文獻

  1. Goodfellow I, Pouget-Abadie J, Mirza M, et al. Generative adversarial nets[C]//Advances in neural information processing systems. 2014: 2672-2680.
  2. Arjovsky M, Chintala S, Bottou L. Wasserstein gan[J]. arXiv preprint arXiv:1701.07875, 2017.
  3. Radford A, Metz L, Chintala S. Unsupervised representation learning with deep convolutional generative adversarial networks[J]. arXiv preprint arXiv:1511.06434, 2015.
  4. Mirza M, Osindero S. Conditional generative adversarial nets[J]. arXiv preprint arXiv:1411.1784, 2014.
  5. Isola P, Zhu J Y, Zhou T, et al. Image-to-image translation with conditional adversarial networks[C]//Proceedings of the IEEE conference on computer vision and pattern recognition. 2017: 1125-1134.
  6. Zhu J Y, Park T, Isola P, et al. Unpaired image-to-image translation using cycle-consistent adversarial networks[C]//Proceedings of the IEEE international conference on computer vision. 2017: 2223-2232.
  7. Choi Y, Choi M, Kim M, et al. Stargan: Unified generative adversarial networks for multi-domain image-to-image translation[C]//Proceedings of the IEEE conference on computer vision and pattern recognition. 2018: 8789-8797.
  8. Mao X, Li Q, Xie H, et al. Least squares generative adversarial networks[C]//Proceedings of the IEEE international conference on computer vision. 2017: 2794-2802.
  9. Karras T, Aila T, Laine S, et al. Progressive growing of gans for improved quality, stability, and variation[J]. arXiv preprint arXiv:1710.10196, 2017.
  10. Chen X, Duan Y, Houthooft R, et al. Infogan: Interpretable representation learning by information maximizing generative adversarial nets[C]//Advances in neural information processing systems. 2016: 2172-2180.
  11. Zhang H, Goodfellow I, Metaxas D, et al. Self-attention generative adversarial networks[C]//International Conference on Machine Learning. 2019: 7354-7363.

繼續閱讀