天天看点

GAN完整理论推导与实现,Perfect!

本文是机器之心第二个 GitHub 实现项目,上一个 GitHub 实现项目为从头开始构建卷积神经网络。在本文中,我们将从原论文出发,借助 Goodfellow 在 NIPS 2016 的演讲和台大李弘毅的解释,完成原 GAN 的推导、证明与实现。

本文主要分四部分,第一部分描述 GAN 的直观概念,第二部分描述概念与优化的形式化表达,第三部分将对 GAN 进行详细的理论推导与分析,最后我们将实现前面的理论分析。

  • GitHub项目地址:https://github.com/jiqizhixin/ML-Tutorial-Experiment

本文更注重理论与推导,更多生成对抗网络的概念与应用请参阅:

生成对抗网络基本概念

要理解生成对抗模型(GAN),首先要了解生成对抗模型可以拆分为两个模块:一个是判别模型,另一个是生成模型。简单来说就是:两个人比赛,看是 A 的矛厉害,还是 B 的盾厉害。比如,我们有一些真实数据,同时也有一把随机生成的假数据。A 拼命地把随手拿过来的假数据模仿成真实数据,并揉进真实数据里。B 则拼命地想把真实数据和假数据区分开。

这里,A 就是一个生成模型,类似于造假币的,一个劲地学习如何骗过 B。而 B 则是一个判别模型,类似于稽查警察,一个劲地学习如何分辨出 A 的造假技巧。

如此这般,随着 B 的鉴别技巧越来越厉害,A 的造假技巧也是越来越纯熟,而一个一流的假币制造者就是我们所需要的。虽然 GAN 背后的思想十分直观与朴素,但我们需要更进一步了解该理论背后的证明与推导。

总的来说,Goodfellow 等人提出来的 GAN 是通过对抗过程估计生成模型的新框架。在这种框架下,我们需要同时训练两个模型,即一个能捕获数据分布的生成模型 G 和一个能估计数据来源于真实样本概率的判别模型 D。生成器 G 的训练过程是最大化判别器犯错误的概率,即判别器误以为数据是真实样本而不是生成器生成的假样本。因此,这一框架就对应于两个参与者的极小极大博弈(minimax game)。在所有可能的函数 G 和 D 中,我们可以求出唯一均衡解,即 G 可以生成与训练样本相同的分布,而 D 判断的概率处处为 1/2,这一过程的推导与证明将在后文详细解释。

当模型都为多层感知机时,对抗性建模框架可以最直接地应用。为了学习到生成器在数据 x 上的分布 P_g,我们先定义一个先验的输入噪声变量 P_z(z),然后根据 G(z;θ_g) 将其映射到数据空间中,其中 G 为多层感知机所表征的可微函数。我们同样需要定义第二个多层感知机 D(s;θ_d),它的输出为单个标量。D(x) 表示 x 来源于真实数据而不是 P_g 的概率。我们训练 D 以最大化正确分配真实样本和生成样本的概率,因此我们就可以通过最小化 log(1-D(G(z))) 而同时训练 G。也就是说判别器 D 和生成器对价值函数 V(G,D) 进行了极小极大化博弈:

GAN完整理论推导与实现,Perfect!

我们后一部分会对对抗网络进行理论上的分析,该理论分析本质上可以表明如果 G 和 D 的模型复杂度足够(即在非参数限制下),那么对抗网络就能生成数据分布。此外,Goodfellow 等人在论文中使用如下案例为我们简要介绍了基本概念。

GAN完整理论推导与实现,Perfect!

如上图所示,生成对抗网络会训练并更新判别分布(即 D,蓝色的虚线),更新判别器后就能将数据真实分布(黑点组成的线)从生成分布 P_g(G)(绿色实线)中判别出来。下方的水平线代表采样域 Z,其中等距线表示 Z 中的样本为均匀分布,上方的水平线代表真实数据 X 中的一部分。向上的箭头表示映射 x=G(z) 如何对噪声样本(均匀采样)施加一个不均匀的分布 P_g。(a)考虑在收敛点附近的对抗训练:P_g 和 P_data 已经十分相似,D 是一个局部准确的分类器。(b)在算法内部循环中训练 D 以从数据中判别出真实样本,该循环最终会收敛到 D(x)=p_data(x)/(p_data(x)+p_g(x))。(c)随后固定判别器并训练生成器,在更新 G 之后,D 的梯度会引导 G(z)流向更可能被 D 分类为真实数据的方向。(d)经过若干次训练后,如果 G 和 D 有足够的复杂度,那么它们就会到达一个均衡点。这个时候 p_g=p_data,即生成器的概率密度函数等于真实数据的概率密度函数,也即生成的数据和真实数据是一样的。在均衡点上 D 和 G 都不能得到进一步提升,并且判别器无法判断数据到底是来自真实样本还是伪造的数据,即 D(x)= 1/2。

上面比较精简地介绍了生成对抗网络的基本概念,下一节将会把这些概念形式化,并描述优化的大致过程。

概念与过程的形式化

理论完美的生成器

该算法的目标是令生成器生成与真实数据几乎没有区别的样本,即一个造假一流的 A,就是我们想要的生成模型。数学上,即将随机变量生成为某一种概率分布,也可以说概率密度函数为相等的:P_G(x)=P_data(x)。这正是数学上证明生成器高效性的策略:即定义一个最优化问题,其中最优生成器 G 满足 P_G(x)=P_data(x)。如果我们知道求解的 G 最后会满足该关系,那么我们就可以合理地期望神经网络通过典型的 SGD 训练就能得到最优的 G。

最优化问题

正如最开始我们了解的警察与造假者案例,定义最优化问题的方法就可以由以下两部分组成。首先我们需要定义一个判别器 D 以判别样本是不是从 P_data(x) 分布中取出来的,因此有:

GAN完整理论推导与实现,Perfect!

其中 E 指代取期望。这一项是根据「正类」(即辨别出 x 属于真实数据 data)的对数损失函数而构建的。最大化这一项相当于令判别器 D 在 x 服从于 data 的概率密度时能准确地预测 D(x)=1,即:

另外一项是企图欺骗判别器的生成器 G。该项根据「负类」的对数损失函数而构建,即:

因为 x<1 的对数为负,那么如果最大化该项的值,则需要令均值 D(G(z))≈0,因此 G 并没有欺骗 D。为了结合这两个概念,判别器的目标为最大化:

GAN完整理论推导与实现,Perfect!

给定生成器 G,其代表了判别器 D 正确地识别了真实和伪造数据点。给定一个生成器 G,上式所得出来的最优判别器可以表示为 (下文用 D_G*表示)。定义价值函数为:

GAN完整理论推导与实现,Perfect!

然后我们可以将最优化问题表述为:

现在 G 的目标已经相反了,当 D=D_G*时,最优的 G 为最小化前面的等式。在论文中,作者更喜欢求解最优化价值函的 G 和 D 以求解极小极大博弈:

GAN完整理论推导与实现,Perfect!

对于 D 而言要尽量使公式最大化(识别能力强),而对于 G 又想使之最小(生成的数据接近实际数据)。整个训练是一个迭代过程。其实极小极大化博弈可以分开理解,即在给定 G 的情况下先最大化 V(D,G) 而取 D,然后固定 D,并最小化 V(D,G) 而得到 G。其中,给定 G,最大化 V(D,G) 评估了 P_G 和 P_data 之间的差异或距离。

最后,我们可以将最优化问题表达为:

上文给出了 GAN 概念和优化过程的形式化表达。通过这些表达,我们可以理解整个生成对抗网络的基本过程与优化方法。当然,有了这些概念我们完全可以直接在 GitHub 上找一段 GAN 代码稍加修改并很好地运行它。但如果我们希望更加透彻地理解 GAN,更加全面地理解实现代码,那么我们还需要知道很多推导过程。比如什么时候 D 能令价值函数 V(D,G) 取最大值、G 能令 V(D,G) 取最小值,而 D 和 G 该用什么样的神经网络(或函数),它们的损失函数又需要用什么等等。总之,还有很多理论细节与推导过程需要我们进一步挖掘。

理论推导

在原 GAN 论文中,度量生成分布与真实分布之间差异或距离的方法是 JS 散度,而 JS 散度是我们在推导训练过程中使用 KL 散度所构建出来的。所以这一部分将从理论基础出发再进一步推导最优判别器和生成器所需要满足的条件,最后我们将利用推导结果在数学上重述训练过程。这一部分为我们下一部分理解具体实现提供了强大的理论支持。

KL 散度

在信息论中,我们可以使用香农熵(Shannon entropy)来对整个概率分布中的不确定性总量进行量化:

GAN完整理论推导与实现,Perfect!

如果我们对于同一个随机变量 x 有两个单独的概率分布 P(x) 和 Q(x),我们可以使用 KL 散度(Kullback-Leibler divergence)来衡量这两个分布的差异:

GAN完整理论推导与实现,Perfect!

在离散型变量的情况下,KL 散度衡量的是,当我们使用一种被设计成能够使得概率分布 Q 产生的消息的长度最小的编码,发送包含由概率分布 P 产生的符号的消息时,所需要的额外信息量。

KL 散度有很多有用的性质,最重要的是它是非负的。KL 散度为 0 当且仅当 P 和 Q 在离散型变量的情况下是相同的分布,或者在连续型变量的情况下是 『几乎处处』 相同的。因为 KL 散度是非负的并且衡量的是两个分布之间的差异,它经常 被用作分布之间的某种距离。然而,它并不是真的距离因为它不是对称的:对于某些 P 和 Q,D_KL(P||Q) 不等于 D_KL(Q||P)。这种非对称性意味着选择 D_KL(P||Q) 还是 D_KL(Q||P) 影响很大。

在李弘毅的讲解中,KL 散度可以从极大似然估计中推导而出。若给定一个样本数据的分布 P_data(x) 和生成的数据分布 P_G(x;θ),那么 GAN 希望能找到一组参数θ使分布 P_g(x;θ) 和 P_data(x) 之间的距离最短,也就是找到一组生成器参数而使得生成器能生成十分逼真的图片。

现在我们可以从训练集抽取一组真实图片来训练 P_G(x;θ) 分布中的参数 θ 使其能逼近于真实分布。因此,现在从 P_data(x) 中抽取 m 个真实样本 {x^1,x^2,…,x^