
Getting Started with GANs, Part 1: Implementing a GAN in 50 Lines of Code (PyTorch)

This is a blog post by a real expert. It has been translated before, but the existing translation reads worse than machine output, so I am working through it myself. It is a must-read classic for anyone getting started with GANs.


What is a GAN?

In 2014, Ian Goodfellow and his colleagues at the University of Montreal published a stunning paper introducing the world to GANs, or generative adversarial networks. Through an innovative combination of computational graphs and game theory, they showed that, given enough modeling power, two models fighting against each other would be able to co-train through plain old backpropagation.

The models play two distinct (literally, adversarial) roles. Given some real data set R, G is the generator, trying to create fake data that looks just like the genuine data, while D is the discriminator, getting data from either the real set or G and labeling the difference. Goodfellow's metaphor (and a fine one it is) was that G was like a team of forgers trying to match real paintings with their output, while D was the team of detectives trying to tell the difference. (Except that in this case, the forgers G never get to see the original data, only the judgments of D. They're like blind forgers.)


In the ideal case, both D and G would get better over time until G had essentially become a "master forger" of the genuine article and D was at a loss, "unable to differentiate between the two distributions."

In practice, what Goodfellow had shown was that G would be able to perform a form of unsupervised learning on the original dataset, finding some way of representing that data in a (possibly) much lower-dimensional manner. And as Yann LeCun famously stated, unsupervised learning is the "cake" of true AI.

Training a GAN with PyTorch

This powerful technique seems like it must require a metric ton of code just to get started, right? Nope. Using PyTorch, we can actually create a very simple GAN in under 50 lines of code. There are really only 5 components to think about:

R: The original, genuine data set

I: The random noise that goes into the generator as a source of entropy

G: The generator, which tries to copy/mimic the original data set

D: The discriminator, which tries to tell apart G's output from R

The actual 'training' loop, where we teach G to trick D and D to beware of G.

1) R: In our case, we'll start with the simplest possible R: a bell curve. This function takes a mean and a standard deviation and returns a function which provides the right shape of sample data from a Gaussian with those parameters. In our sample code, we'll use a mean of 4.0 and a standard deviation of 1.25.

def get_distribution_sampler(mu, sigma):
    return lambda n: torch.Tensor(np.random.normal(mu, sigma, (1, n))) # Gaussian
           

2) I: The input into the generator is also random, but to make our job a little bit harder, let's use a uniform distribution rather than a normal one. This means that our model G can't simply shift/scale the input to copy R, but has to reshape the data in a non-linear way.

def get_generator_input_sampler():
    return lambda m, n: torch.rand(m, n) # Uniform-dist data into generator, _NOT_ Gaussian
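
As a quick sanity check (not part of the original 50 lines), the two sampler factories can be used like this; the variable names and batch sizes here are arbitrary, for illustration only:

d_sampler = get_distribution_sampler(4.0, 1.25)   # "real" data: Gaussian with mean 4.0, std 1.25
gi_sampler = get_generator_input_sampler()        # uniform noise source for G

real_batch = d_sampler(100)        # torch.Tensor of shape (1, 100)
noise_batch = gi_sampler(100, 1)   # torch.Tensor of shape (100, 1), values uniform in [0, 1)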
           

3) G: The generator is a standard feedforward graph: two hidden layers, three linear maps. We're using an ELU (exponential linear unit). G is going to get the uniformly distributed data samples from I and somehow mimic the normally distributed samples from R.

class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Generator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = F.elu(self.map1(x))
        x = torch.sigmoid(self.map2(x))
        return self.map3(x)

4) D: The discriminator code is very similar to G's generator code: a feedforward graph with two hidden layers and three linear maps. It's going to get samples from either R or G and will output a single scalar between 0 and 1, interpreted as 'fake' vs. 'real'.

class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Discriminator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = F.elu(self.map1(x))
        x = F.elu(self.map2(x))
        return torch.sigmoid(self.map3(x))

5) Finally, the training loop alternates between two modes: first training D on real data vs. fake data, with accurate labels; and then training G to fool D, with inaccurate labels.

for epoch in range(num_epochs):
    for d_index in range(d_steps):
        # 1. Train D on real+fake
        D.zero_grad()

        #  1A: Train D on real
        d_real_data = Variable(d_sampler(d_input_size))
        d_real_decision = D(preprocess(d_real_data))
        d_real_error = criterion(d_real_decision, Variable(torch.ones(1)))  # ones = true
        d_real_error.backward() # compute/store gradients, but don't change params

        #  1B: Train D on fake
        d_gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
        d_fake_data = G(d_gen_input).detach()  # detach to avoid training G on these labels
        d_fake_decision = D(preprocess(d_fake_data.t()))
        d_fake_error = criterion(d_fake_decision, Variable(torch.zeros(1)))  # zeros = fake
        d_fake_error.backward()
        d_optimizer.step()     # Only optimizes D's parameters; changes based on stored gradients from backward()

    for g_index in range(g_steps):
        # 2. Train G on D's response (but DO NOT train D on these labels)
        G.zero_grad()

        gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
        g_fake_data = G(gen_input)
        dg_fake_decision = D(preprocess(g_fake_data.t()))
        g_error = criterion(dg_fake_decision, Variable(torch.ones(1)))  # we want to fool, so pretend it's all genuine

        g_error.backward()
        g_optimizer.step() # Only optimizes G's parameters
           

Even if you haven't seen PyTorch before, you can probably tell what's going on. In the first section, we push both types of data through D and apply a differentiable criterion to D's guesses vs. the actual labels. That pushing is the 'forward' step; we then call backward() explicitly in order to calculate gradients, which are then used to update D's parameters in the d_optimizer.step() call. G is used but isn't trained here.

Then in the last section, we do the same thing for G. Note that we also run G's output through D (we're essentially giving the forger a detective to practice on), but we do not optimize or change D at this step; we don't want the detective D to learn the wrong labels. Hence, we only call g_optimizer.step().

And… that's all. There's some other boilerplate code, but the GAN-specific stuff is just those 5 components, nothing else.
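
For reference, the surrounding boilerplate might look roughly like the sketch below. This is a minimal sketch, not the author's exact script: the hyperparameter values, the choice of SGD and BCELoss, and the identity preprocess are illustrative assumptions (the full script supplies its own preprocess; an identity function is the simplest stand-in). On recent PyTorch versions, Variable is a no-op wrapper, and the label tensors in the loop above may need reshaping to match D's output shape.

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

# Illustrative hyperparameters -- assumptions, not necessarily the original values
g_input_size, g_hidden_size, g_output_size = 1, 50, 1
d_input_size, d_hidden_size, d_output_size = 100, 50, 1   # D sees a whole batch as one vector
minibatch_size = d_input_size
num_epochs, d_steps, g_steps = 5000, 1, 1

d_sampler = get_distribution_sampler(4.0, 1.25)            # R: real data
gi_sampler = get_generator_input_sampler()                  # I: noise into G
G = Generator(g_input_size, g_hidden_size, g_output_size)
D = Discriminator(d_input_size, d_hidden_size, d_output_size)

criterion = nn.BCELoss()                                     # binary cross-entropy on D's 0..1 output
d_optimizer = optim.SGD(D.parameters(), lr=1e-3)
g_optimizer = optim.SGD(G.parameters(), lr=1e-3)

preprocess = lambda data: data                               # simplest stand-in; the full script uses its own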

After a few thousand rounds of this forbidden dance between D and G, what do we get? The discriminator D gets good very quickly (while G slowly moves up), but once it gets to a certain level of power, G has a worthy adversary and begins to improve. Really improve.

Over 20,000 training rounds, the mean of G's output overshoots 4.0 but then comes back in a fairly stable, correct range (left). Likewise, the standard deviation initially drops in the wrong direction but then rises up to the desired 1.25 range (right), matching R.

[Figure: mean (left) and standard deviation (right) of G's output over 20,000 training rounds]

OK, so the basic stats match R, eventually. How about the higher moments? Does the shape of the distribution look right? After all, you could certainly have a uniform distribution with a mean of 4.0 and a standard deviation of 1.25, but that wouldn't really match R. Let's show the final distribution emitted by G.

[Figure: the final distribution of samples emitted by G]

Not bad. The left tail is a bit longer than the right, but the skew and kurtosis are, shall we say, evocative of the original Gaussian.
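
Those higher moments can also be checked numerically. The snippet below is not part of the original 50 lines; it assumes the trained G, the gi_sampler and the g_input_size defined above, draws a large batch from the generator, and prints the sample skewness and excess kurtosis (both should be near zero for a true Gaussian):

import numpy as np
import torch

with torch.no_grad():
    samples = G(gi_sampler(10000, g_input_size)).numpy().flatten()

mean, std = samples.mean(), samples.std()
skew = np.mean((samples - mean) ** 3) / std ** 3            # sample skewness
kurtosis = np.mean((samples - mean) ** 4) / std ** 4 - 3    # excess kurtosis
print("mean=%.3f std=%.3f skew=%.3f kurtosis=%.3f" % (mean, std, skew, kurtosis))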

G recovers the original distribution R nearly perfectly, and D is left cowering in the corner, mumbling to itself, unable to tell fact from fiction. This is precisely the behavior we want (see Figure 1 in Goodfellow). From fewer than 50 lines of code.

Goodfellow would go on to publish many other papers on GANs, including a 2016 gem describing some practical improvements, among them the mini-batch discrimination method adapted here. And here's a 2-hour tutorial he presented at NIPS 2016. For TensorFlow users, here's a parallel post from Aylien on GANs.

References:

1. Blog

2. Code

3. http://www.sohu.com/a/126742829_473283

4. http://www.pytorchtutorial.com/50-lines-of-codes-for-gan/#_GAN

5. https://blog.csdn.net/xjc864588399/article/details/56289591
