天天看點

【PyTorch】50行代碼實作GAN——PyTorch一、什麼是GAN二、用PyTorch訓練GAN

本文來源于PyTorch中文網。

一直想了解GAN到底是個什麼東西,卻一直沒能騰出時間來認真研究,前幾日正好搜到一篇關于PyTorch實作GAN訓練的文章,特将學習記錄如下,本文主要包含兩個部分:GAN原理介紹和技術層面實作。

一、什麼是GAN

2014 年,Ian Goodfellow 和他在蒙特利爾大學的同僚發表了一篇震撼學界的論文。沒錯,我說的就是《Generative Adversarial Nets》,這标志着生成對抗網絡(GAN)的誕生,而這是通過對計算圖和博弈論的創新性結合。他們的研究展示,給定充分的模組化能力,兩個博弈模型能夠通過簡單的反向傳播(backpropagation)來協同訓練。

這兩個模型的角色定位十分鮮明。給定真實資料集 R,G 是生成器(generator),它的任務是生成能以假亂真的假資料;而 D 是判别器 (discriminator),它從真實資料集或者 G 那裡擷取資料, 然後做出判别真假的标記。Ian Goodfellow 的比喻是,G 就像一個赝品作坊,想要讓做出來的東西盡可能接近真品,蒙混過關。而 D 就是文物鑒定專家,要能區分出真品和高仿(但在這個例子中,造假者 G 看不到原始資料,而隻有 D 的鑒定結果——前者是在盲幹)。

【PyTorch】50行代碼實作GAN——PyTorch一、什麼是GAN二、用PyTorch訓練GAN

理想情況下,D 和 G 都會随着不斷訓練,做得越來越好——直到 G 基本上成為了一個“赝品制造大師”,而 D 因無法正确區分兩種資料分布輸給 G。

實踐中,Ian Goodfellow 展示的這項技術在本質上是:G 能夠對原始資料集進行一種無監督學習,找到以更低次元的方式(lower-dimensional manner)來表示資料的某種方法。而無監督學習之是以重要,就好像 Yann LeCun 的那句話:“無監督學習是蛋糕的糕體”。這句話中的蛋糕,指的是無數學者、開發者苦苦追尋的“真正的 AI”。

二、用PyTorch訓練GAN

Dev Nag:在表面上,GAN 這門如此強大、複雜的技術,看起來需要編寫天量的代碼來執行,但事實未必如此。我們使用 PyTorch,能夠在 50 行代碼以内建立出簡單的 GAN 模型。這之中,其實隻有五個部分需要考慮:

  • R:原始、真實資料集
  • I:作為熵的一項來源,進入生成器的随機噪音
  • G:生成器,試圖模仿原始資料
  • D:判别器,試圖差別 G 的生成資料和 R
  • 我們教 G 糊弄 D、教 D 當心 G 的“訓練”環。

R:在我們的例子裡,從最簡單的 R 着手——貝爾曲線(bell curve)。它把平均數(mean)和标準差(standard deviation)作為輸入,然後輸出能提供樣本資料正确圖形(從 Gaussian 用這些參數獲得 )的函數。在我們的代碼例子中,我們使用 4 的平均數和 1.25 的标準差。

def get_distribution_sampler(mu, sigma):
    return lambda n: torch.Tensor(np.random.normal(mu, sigma, (1, n)))  # Gaussian
           

I:生成器的輸入是随機的,為提高點難度,我們使用均勻分布(uniform distribution )而非标準分布。這意味着,我們的 Model G 不能簡單地改變輸入(放大/縮小、平移)來複制 R,而需要用非線性的方式來改造資料。

def get_generator_input_sampler():
    return lambda m, n: torch.rand(m, n)  # Uniform-dist data into generator, _NOT_ Gaussian
           

G: 該生成器是個标準的前饋圖(feedforward graph)——兩層隐層,三個線性映射(linear maps)。我們使用了 ELU (exponential linear unit)。G 将從 I 獲得平均分布的資料樣本,然後找到某種方式來模仿 R 中标準分布的樣本。

class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Generator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
 
    def forward(self, x):
        x = F.elu(self.map1(x))
        x = F.sigmoid(self.map2(x))
        return self.map3(x)
           

D: 判别器的代碼和 G 的生成器代碼很接近。一個有兩層隐層和三個線性映射的前饋圖。它會從 R 或 G 那裡獲得樣本,然後輸出 0 或 1 的判别值,對應反例和正例。這幾乎是神經網絡的最弱版本了。

class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Discriminator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
 
    def forward(self, x):
        x = F.elu(self.map1(x))
        x = F.elu(self.map2(x))
        return F.sigmoid(self.map3(x))
           

最後,訓練環在兩個模式中變幻:第一步,用被準确标記的真實資料 vs. 假資料訓練 D;随後,訓練 G 來騙過 D,這裡是用的不準确标記。道友們,這是正邪之間的較量。

for epoch in range(num_epochs):
    for d_index in range(d_steps):
        # 1. Train D on real+fake
        D.zero_grad()
 
        #  1A: Train D on real
        d_real_data = Variable(d_sampler(d_input_size))
        d_real_decision = D(preprocess(d_real_data))
        d_real_error = criterion(d_real_decision, Variable(torch.ones(1)))  # ones = true
        d_real_error.backward() # compute/store gradients, but don't change params
 
        #  1B: Train D on fake
        d_gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
        d_fake_data = G(d_gen_input).detach()  # detach to avoid training G on these labels
        d_fake_decision = D(preprocess(d_fake_data.t()))
        d_fake_error = criterion(d_fake_decision, Variable(torch.zeros(1)))  # zeros = fake
        d_fake_error.backward()
        d_optimizer.step()     # Only optimizes D's parameters; changes based on stored gradients from backward()
 
    for g_index in range(g_steps):
        # 2. Train G on D's response (but DO NOT train D on these labels)
        G.zero_grad()
 
        gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
        g_fake_data = G(gen_input)
        dg_fake_decision = D(preprocess(g_fake_data.t()))
        g_error = criterion(dg_fake_decision, Variable(torch.ones(1)))  # we want to fool, so pretend it's all genuine
 
        g_error.backward()
        g_optimizer.step()  # Only optimizes G's parameters
           

即便你從沒接觸過 PyTorch,大概也能明白發生了什麼。在第一部分,我們讓兩種類型的資料經過 D,并對 D 的猜測 vs. 真實标記執行不同的評判标準。這是 “forward” 那一步;随後我們需要 “backward()” 來計算梯度,然後把這用來在 d_optimizer step() 中更新 D 的參數。這裡,G 被使用但尚未被訓練。

在最後的部分,我們對 G 執行同樣的操作——注意我們要讓 G 的輸出穿過 D (這其實是送給造假者一個鑒定專家來練手)。但在這一步,我們并不優化、或者改變 D。我們不想讓鑒定者 D 學習到錯誤的标記。是以,我們隻執行 g_optimizer.step()。

這就是全部了。還有一些其他樣闆代碼,但GAN特定的東西隻是那5個元件,沒有别的了。

在 D 和 G 之間幾千輪交手之後,我們會得到什麼?判别器 D 會快速改進,而 G 的進展要緩慢許多。但當模型達到一定性能之後,G 才有了個配得上的對手,并開始提升,巨幅提升。

兩萬輪訓練之後,G 的輸入平均值超過 4,但會傳回到相當平穩、合理的範圍(左圖)。同樣的,标準差一開始在錯誤的方向降低,但随後攀升至理想中的 1.25 區間(右圖),達到 R 的層次。

【PyTorch】50行代碼實作GAN——PyTorch一、什麼是GAN二、用PyTorch訓練GAN

是以,基礎資料最終會與 R 吻合。那麼,那些比 R 更高的時候呢?資料分布的形狀看起來合理嗎?畢竟,你一定可以得到有 4.0 的平均值和 1.25 标準內插補點的均勻分布,但那不會真的符合 R。我們一起來看看 G 生成的最終分布。

【PyTorch】50行代碼實作GAN——PyTorch一、什麼是GAN二、用PyTorch訓練GAN

結果是不錯的。左側的尾巴比右側長一些,但偏離程度和峰值與原始 Gaussian 十分相近。G 接近完美地再現了原始分布 R——D 落于下風,無法分辨真相和假相。而這就是我們想要得到的結果——使用不到 50 行代碼。

該說的都說完了,老司機請上 GitHub 把玩全套代碼。

位址:https://github.com/devnag/pytorch-generative-adversarial-networks

附所有代碼供參考:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
 
# Data params
data_mean = 4
data_stddev = 1.25
 
# Model params
g_input_size = 1     # Random noise dimension coming into generator, per output vector
g_hidden_size = 50   # Generator complexity
g_output_size = 1    # size of generated output vector
d_input_size = 100   # Minibatch size - cardinality of distributions
d_hidden_size = 50   # Discriminator complexity
d_output_size = 1    # Single dimension for 'real' vs. 'fake'
minibatch_size = d_input_size
 
d_learning_rate = 2e-4  # 2e-4
g_learning_rate = 2e-4
optim_betas = (0.9, 0.999)
num_epochs = 30000
print_interval = 200
d_steps = 1  # 'k' steps in the original GAN paper. Can put the discriminator on higher training freq than generator
g_steps = 1
 
# ### Uncomment only one of these
#(name, preprocess, d_input_func) = ("Raw data", lambda data: data, lambda x: x)
(name, preprocess, d_input_func) = ("Data and variances", lambda data: decorate_with_diffs(data, 2.0), lambda x: x * 2)
 
print("Using data [%s]" % (name))
 
# ##### DATA: Target data and generator input data
 
def get_distribution_sampler(mu, sigma):
    return lambda n: torch.Tensor(np.random.normal(mu, sigma, (1, n)))  # Gaussian
 
def get_generator_input_sampler():
    return lambda m, n: torch.rand(m, n)  # Uniform-dist data into generator, _NOT_ Gaussian
 
# ##### MODELS: Generator model and discriminator model
 
class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Generator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
 
    def forward(self, x):
        x = F.elu(self.map1(x))
        x = F.sigmoid(self.map2(x))
        return self.map3(x)
 
class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Discriminator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
 
    def forward(self, x):
        x = F.elu(self.map1(x))
        x = F.elu(self.map2(x))
        return F.sigmoid(self.map3(x))
 
def extract(v):
    return v.data.storage().tolist()
 
def stats(d):
    return [np.mean(d), np.std(d)]
 
def decorate_with_diffs(data, exponent):
    mean = torch.mean(data.data, 1)
    mean_broadcast = torch.mul(torch.ones(data.size()), mean.tolist()[0][0])
    diffs = torch.pow(data - Variable(mean_broadcast), exponent)
    return torch.cat([data, diffs], 1)
 
d_sampler = get_distribution_sampler(data_mean, data_stddev)
gi_sampler = get_generator_input_sampler()
G = Generator(input_size=g_input_size, hidden_size=g_hidden_size, output_size=g_output_size)
D = Discriminator(input_size=d_input_func(d_input_size), hidden_size=d_hidden_size, output_size=d_output_size)
criterion = nn.BCELoss()  # Binary cross entropy: http://pytorch.org/docs/nn.html#bceloss
d_optimizer = optim.Adam(D.parameters(), lr=d_learning_rate, betas=optim_betas)
g_optimizer = optim.Adam(G.parameters(), lr=g_learning_rate, betas=optim_betas)
 
for epoch in range(num_epochs):
    for d_index in range(d_steps):
        # 1. Train D on real+fake
        D.zero_grad()
 
        #  1A: Train D on real
        d_real_data = Variable(d_sampler(d_input_size))
        d_real_decision = D(preprocess(d_real_data))
        d_real_error = criterion(d_real_decision, Variable(torch.ones(1)))  # ones = true
        d_real_error.backward() # compute/store gradients, but don't change params
 
        #  1B: Train D on fake
        d_gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
        d_fake_data = G(d_gen_input).detach()  # detach to avoid training G on these labels
        d_fake_decision = D(preprocess(d_fake_data.t()))
        d_fake_error = criterion(d_fake_decision, Variable(torch.zeros(1)))  # zeros = fake
        d_fake_error.backward()
        d_optimizer.step()     # Only optimizes D's parameters; changes based on stored gradients from backward()
 
    for g_index in range(g_steps):
        # 2. Train G on D's response (but DO NOT train D on these labels)
        G.zero_grad()
 
        gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
        g_fake_data = G(gen_input)
        dg_fake_decision = D(preprocess(g_fake_data.t()))
        g_error = criterion(dg_fake_decision, Variable(torch.ones(1)))  # we want to fool, so pretend it's all genuine
 
        g_error.backward()
        g_optimizer.step()  # Only optimizes G's parameters
 
    if epoch % print_interval == 0:
        print("%s: D: %s/%s G: %s (Real: %s, Fake: %s) " % (epoch,
                                                            extract(d_real_error)[0],
                                                            extract(d_fake_error)[0],
                                                            extract(g_error)[0],
                                                            stats(extract(d_real_data)),
                                                            stats(extract(d_fake_data))))