GAN,生成对抗网络,了解这种方法需要了解以下概念。
浅层的理解:
生成对抗网络一般由G和D两个网络组成,有点像几年前的一个目标跟踪算法TLD算法,以生成图片为例,G网络负责生成新的图片,D网络负责分辨D生成的图片和真正dataset里面的图片,训练的目的是让G生成的图片足够逼真以至于D网络不能分辨真假,即分辨正确率接近50%。论文中给出的例子类似于一个印刷假钞团伙和打假团队互相训练指导印假钞团伙印出的假钞足以以假乱真。
一个生动的应用例子:
这篇知乎文章通俗易懂地解释了GAN并给出了一个自己生成卡通人物头像的demo,从爬数据到生成结果比较完善。
具体来说就是先输入一张随机生成的初始图,一般是高斯噪声,然后经过多个iteration,比如经过一个epoch之后的结果:
最后得到一个令D难以分辨真假的图,如:
这里输入的初始图像是噪声图,那么如果输入的低像素(LR)图,输出高分辨率(HR)图,也是类似的道理。比如SRGAN。
深入探究原理
abstract :
和上面的内容主要一致,需要注意的是There is no need for any Markov chains or unrolled approximate inference networks during either training or generation of samples.按照是否需要定义概率密度函数可以分为Implicit density和Explicit density。前者表示不需要定义概率密度函数如马尔科夫链模型和这里的GAN。
不同的sampling:
原始采样法ancestral sampling:
①将样本按照从小到大的顺序排序,保证每一个样本的父样本不会比当前样本更大
②从最小的样本开始从分布p(x)中提取该样本的概率p(x1), 然后一次取后面样本的条件概率,最后相乘。如上图所示。
采用原始采样法的算法如:GAN和VAE
迭代采样iterative sampling,包括多次pass over所有的变量,迭代地提高样本的质量。通常这种处理包括在整个空间内模拟一个马尔科夫链,通常是非定向模型的选择。
Adversarial nets
定义一个先验的噪声变量pz(z),然后映射到数据空间通过G(z;θg),这里G是一个通过多重感知机通过参数theta g。然后定义另一个多层感知机D(x;θd)输出一个单一的尺度(我理解为得到一个常数也就是判断的结果)。D(x)代表x是从数据集而不是G(z;θg)也就是pz(z)的概率。我们训练D来最大化分辨正确的概率,同步训练G来最小化log(1 -D(G(z)))。也就是我们常见的下面的公式:
如上图所示:x表示数据集空间分布,z表示模型给出的分布,黑点表示数据集中采样数据,绿线表示G得到的参数给出的结果分布,蓝点线表示D来区分黑点和绿线之间的分类,a中是一开始的时候x和z差很大,D可以很好地将两个空间分开,为了降低D的正确率G逐渐靠近黑点,导致后面蓝色的点线不能很好地分类,变成一条直线。这是一种形象的对于对抗网络过程的表示。
这个是文章中给出的算法伪代码,可以看到在训练的迭代过程中,从噪声先验Pg(z)中采样m个再从数据集采样m个(这里采样用到的也就是上面的ancestral sampling)上面红色框出来的部分可以看到,D在更新的时候方向是“ascending”梯度,也就是向着loss function上升变大的方向,增大判别的正确率;而G则“descending”并且不是前面的function,而是只考虑了使D的正确率下降就好了。
后面是算法的细节
比如知道了D优化的方向之后怎么求更新后的D:
证明略,也就是求一个超越方程的极值。确定了·D的最优解之后优化G:
当且仅当Pg=Pdata的时候有极值,为-log4,推导过程略,最后得到
Experiments
原作者的实验中主要用到了MNIST、TFD、CIFAR-10数据集,G网络使用混合的整流线性激活函数和sigmoid激活函数;G网络是用macout激活。
训练D网络的过程中用到了Dropout。
测试的时候通过在测试机里面加上一个Gaussian Parzen window。细节实现比如sigma的确认和likelihood的计算略。
Advantages and disadvantages
劣:没有pg(x)的明确定义;D和G在训练的时候必须是同步的,而且为了避免“the Helvetica scenario” G不能比D训练过头太多。
优:对抗模型可能会从生成网络里面获取一些统计上的又是,因为生成网络不是直接由数据及给出的,而是通过D网络得到的。也就是说输入的参数没有直接被复制到生成网络的参数中。
另外作者还给出了生成对抗网络可能的扩展,对后面的工作有指导作用。