生成对抗网络原理与实战

原创:李孟启

1. 前言

在生成对抗网络(Generative Adversarial Network,简称 GAN)发明之前,变分自编码器被认为是理论完备,实现简单,使用神经网络训练起来很稳定,生成的图片逼近度也较高,但是人眼还是可以很轻易地分辨出真实图片与机器生成的图片。

2014 年,Université de Montréal 大学 Yoshua Bengio(2019 年图灵奖获得者)的学生 Ian Goodfellow 提出了生成对抗网络 GAN,从而开辟了深度学习最炙手可热的研究方向之一。从 2014 年到 2019 年,GAN 的研究稳步推进,研究捷报频传,最新的 GAN 算法在图片生成上的效果甚至达到了肉眼难辨的程度,着实令人振奋。由于 GAN 的发明,Ian Goodfellow 荣获 GAN 之父称号,并获得 2017 年麻省理工科技评论颁发的 35 Innovators Under 35 奖项。图 1 展示了从 2014 年到 2018 年,GAN 模型取得了图书生成的效果,可以看到不管是图片大小,还是图片逼真度,都有了巨大的提升。

图1 GAN模型2014~2018年的图片生成效果

2. 博弈学实例

接下来,我们将从生活中博弈学习的实例出发,一步步引出 GAN 算法的设计思想和模型结构。我们用一个漫画家的成长轨迹来形象介绍生成对抗网络的思想。考虑一对双胞胎兄弟,分别称为老二 G 和老大 D,G 学习如何绘制漫画,D 学习如何鉴赏画作。还在娃娃时代的两兄弟,尚且只学会了如何使用画笔和纸张,G 绘制了一张不明所以的画作,如图2(a)所示,由于此时 D 鉴别能力不高,觉得 G 的作品还行,但是人物主体不够鲜明。在 D 的指引和鼓励下,G 开始尝试学习如何绘制主体轮廓和使用简单的色彩搭配。一年后,G 提升了绘画的基本功,D 也通过分析名作和初学者 G 的作品,初步掌握了鉴别作品的能力。此时 D 觉得 G 的作品人物主体有了,如图 2(b),但是色彩的运用还不够成熟。数年后,G 的绘画基本功已经很扎实了,可以轻松绘制出主体鲜明、颜色搭配合适和逼真度较高的画作,如图 2(c),但是 D 同样通过观察 G 和其它名作的差别,提升了画作鉴别能力,觉得 G 的画作技艺已经趋于成熟,但是对生活的观察尚且不够,作品没有传达神情且部分细节不够完美。又过了数年,G 的绘画功力达到了炉火纯青的地步,绘制的作品细节完美、风格迥异、惟妙惟肖,宛如大师级水准,如图 2(d),即便此时的D 鉴别功力也相当出色,亦很难将 G 和其他大师级的作品区分开来。

上述画家的成长历程其实是一个生活中普遍存在的学习过程,通过双方的博弈学习,相互提高,最终达到一个平衡点。GAN 网络借鉴了博弈学习的思想,分别设立了两个子网络:负责生成样本的生成器 G 和负责鉴别真伪的鉴别器 D。类比到画家的例子,生成器 G就是老二,鉴别器 D 就是老大。鉴别器 D 通过观察真实的样本和生成器 G 产生的样本之间的区别,学会如何鉴别真假,其中真实的样本为真,生成器 G 产生的样本为假。而生成器 G 同样也在学习,它希望产生的样本能够获得鉴别器 D 的认可,即在鉴别器 D 中鉴别为真,因此生成器 G 通过优化自身的参数,尝试使得自己产生的样本在鉴别器 D 中判别为真。生成器 G 和鉴别器 D 相互博弈,共同提升,直至达到平衡点。此时生成器 G 生成的样本非常逼真,使得鉴别器 D 真假难分。

图2 画家的成长轨迹示意图

在原始的 GAN 论文中,Ian Goodfellow 使用了另一个形象的比喻来介绍 GAN 模型:生成器网络 G 的功能就是产生一系列非常逼真的假钞试图欺骗鉴别器 D,而鉴别器 D 通过学习真钞和生成器 G 生成的假钞来掌握钞票的鉴别方法。这两个网络在相互博弈的过程中间同步提升,直到生成器 G 产生的假钞非常的逼真,连鉴别器 D 都真假难辨。

这种博弈学习的思想使得 GAN 的网络结构和训练过程与之前的网络模型略有不同,下面我们来详细介绍 GAN 的网络结构和算法原理。

3. GAN原理

一个典型的生成对抗网络模型大概如图3所示。

图3 对抗生成网络模型

我们先来理解下GAN的两个模型要做什么。首先判别模型(鉴别器),就是图3中右半部分的网络,直观来看就是一个简单的神经网络结构,输入就是一副图像,输出就是一个概率值(其实是个二分类问题),用于判断真假使用(概率值大于0.5那就是真,小于0.5那就是假),真假也不过是人们定义的概率而已。其次是生成模型,生成模型要做什么呢,同样也可以看成是一个神经网络模型,输入是一组随机数Z,输出是一个图像,不再是一个数值而已。从图3中可以看到,会存在两个数据集,一个是真实数据集,另一个是假的数据集,那这个数据集就是有生成网络造出来的数据集。根据图3我们再来理解一下GAN的目标是要做什么:

判别网络的目的:就是能判别出来输入的一张图它是来自真实样本集还是假样本集。假如输入的是真样本,网络输出就接近1,输入的是假样本,网络输出接近0,那么很完美,达到了很好判别的目的。

生成网络的目的:生成网络是造样本的,它的目的就是使得自己造样本的能力尽可能强,强到什么程度呢,你判别网络没法判断我是真样本还是假样本。因此辨别网络的作用就是对噪音生成的数据辨别他为假的,对真实的数据辨别他为真的。而生成网络的损失函数就是使得对于噪音数据,经过辨别网络之后的辨别结果是真的,这样就能达到生成真实图像的目的。这里会感觉比较饶,这也是生成对抗网络的难点所在,理解了这点,整个生成对抗网络模型也就理解了。

4. DCGAN实战

这里我们拿DCGAN来举例子,DCGAN是GAN的一个变体,DCGAN就是将CNN和原始的GAN结合到一起,生成网络和鉴别网络都运用到了深度卷积神经网络。DCGAN提高了基础GAN的稳定性和生成结果质量。

该项目使用的是mnist手写字数据集,深度学习框架为tensorflow。你也可以直接跳过下面代码直接git clone本项目,项目的github链接https://github.com/limengqigithub/DCGAN-mnist-master.git

4.1 DCGAN模型代码

import tensorflow as tf
from tensorflow import keras

# 生成网络
class Generator(keras.Model):

    def __init__(self):
        super(Generator, self).__init__()

        self.n_f = 512
        self.n_k = 4

        # input z vector is [None, 100]
        self.dense1 = keras.layers.Dense(3 * 3 * self.n_f)
        self.conv2 = keras.layers.Conv2DTranspose(self.n_f // 2, 3, 2, 'valid')
        self.bn2 = keras.layers.BatchNormalization()
        self.conv3 = keras.layers.Conv2DTranspose(self.n_f // 4, self.n_k, 2, 'same')
        self.bn3 = keras.layers.BatchNormalization()
        self.conv4 = keras.layers.Conv2DTranspose(1, self.n_k, 2, 'same')
        return

    def call(self, inputs, training=None):
        # [b, 100] => [b, 3, 3, 512]
        x = tf.nn.leaky_relu(tf.reshape(self.dense1(inputs), shape=[-1, 3, 3, self.n_f]))
        x = tf.nn.leaky_relu(self.bn2(self.conv2(x), training=training))
        x = tf.nn.leaky_relu(self.bn3(self.conv3(x), training=training))
        x = tf.tanh(self.conv4(x))
        return x

# 判别网络
class Discriminator(keras.Model):

    def __init__(self):
        super(Discriminator, self).__init__()

        self.n_f = 64
        self.n_k = 4

        # input image is [-1, 28, 28, 1]
        self.conv1 = keras.layers.Conv2D(self.n_f, self.n_k, 2, 'same')
        self.conv2 = keras.layers.Conv2D(self.n_f * 2, self.n_k, 2, 'same')
        self.bn2 = keras.layers.BatchNormalization()
        self.conv3 = keras.layers.Conv2D(self.n_f * 4, self.n_k, 2, 'same')
        self.bn3 = keras.layers.BatchNormalization()
        self.flatten4 = keras.layers.Flatten()
        self.dense4 = keras.layers.Dense(1)
        return

    def call(self, inputs, training=None):
        x = tf.nn.leaky_relu(self.conv1(inputs))
        x = tf.nn.leaky_relu(self.bn2(self.conv2(x), training=training))
        x = tf.nn.leaky_relu(self.bn3(self.conv3(x), training=training))
        x = self.dense4(self.flatten4(x))
        return x

4.2 损失函数实现

# shorten sigmoid cross entropy loss calculation
def celoss_ones(logits, smooth=0.0):
    return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits,
                                                                  labels=tf.ones_like(logits) * (1.0 - smooth)))


def celoss_zeros(logits, smooth=0.0):

    return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits,
                                                                  labels=tf.zeros_like(logits) * (1.0 - smooth)))


def d_loss_fn(generator, discriminator, input_noise, real_image, is_trainig):
    # 生成模型根据噪声输入生成图片,把根据噪声生成的图片与真实的图片输进纠错模型中,然后做交叉熵的loss计算
    fake_image = generator(input_noise, is_trainig)
    d_real_logits = discriminator(real_image, is_trainig)
    d_fake_logits = discriminator(fake_image, is_trainig)

    d_loss_real = celoss_ones(d_real_logits, smooth=0.1)
    d_loss_fake = celoss_zeros(d_fake_logits, smooth=0.0)
    loss = d_loss_real + d_loss_fake
    return loss


def g_loss_fn(generator, discriminator, input_noise, is_trainig):
    fake_image = generator(input_noise, is_trainig)
    d_fake_logits = discriminator(fake_image, is_trainig)
    loss = celoss_ones(d_fake_logits, smooth=0.1)
    return loss

4.3 保存生成网络的生成结果

def save_result(val_out, val_block_size, image_fn, color_mode):
    def preprocess(img):
        img = ((img + 1.0) * 127.5).astype(np.uint8)
        return img

    preprocesed = preprocess(val_out)
    final_image = np.array([])
    single_row = np.array([])
    for b in range(val_out.shape[0]):
        # concat image into a row
        if single_row.size == 0:
            single_row = preprocesed[b, :, :, :]
        else:
            single_row = np.concatenate((single_row, preprocesed[b, :, :, :]), axis=1)

        # concat image row to final_image
        if (b + 1) % val_block_size == 0:
            if final_image.size == 0:
                final_image = single_row
            else:
                final_image = np.concatenate((final_image, single_row), axis=0)

            # reset single row
            single_row = np.array([])

    if final_image.shape[2] == 1:
        final_image = np.squeeze(final_image, axis=2)
    Image.fromarray(final_image, mode=color_mode).save(image_fn)

4.4 主函数部分

def main():
    tf.random.set_seed(22)
    np.random.seed(22)
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    assert tf.__version__.startswith('2.')

    # hyper parameters  超参数
    z_dim = 100
    epochs = 3000000
    # epochs = 30
    batch_size = 128
    learning_rate = 0.0002
    is_training = True

    # for validation purpose
    assets_dir = './images'
    if not os.path.isdir(assets_dir):
        os.makedirs(assets_dir)
    val_block_size = 10
    val_size = val_block_size * val_block_size

    # load mnist data
    # x_train shape (60000, 28, 28) numpy.ndarray
    # x_test shape (10000, 28, 28) numpy.ndarray
    (x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
    x_train = x_train.astype(np.float32) / 255.  # 归一到(0,1)区间内
    db = tf.data.Dataset.from_tensor_slices(x_train).shuffle(batch_size * 4).batch(batch_size).repeat()
    db_iter = iter(db)
    inputs_shape = [-1, 28, 28, 1]

    # create generator & discriminator
    generator = Generator()
    generator.build(input_shape=(batch_size, z_dim))
    generator.summary()
    discriminator = Discriminator()
    discriminator.build(input_shape=(batch_size, 28, 28, 1))
    discriminator.summary()
    # prepare optimizer
    d_optimizer = keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)
    g_optimizer = keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)

    for epoch in range(epochs):

        # no need labels
        batch_x = next(db_iter)

        # rescale images to -1 ~ 1
        batch_x = tf.reshape(batch_x, shape=inputs_shape)
        # -1 - 1
        batch_x = batch_x * 2.0 - 1.0

        # Sample random noise for G
        batch_z = tf.random.uniform(shape=[batch_size, z_dim], minval=-1., maxval=1.)

        with tf.GradientTape() as tape:
            d_loss = d_loss_fn(generator, discriminator, batch_z, batch_x, is_training)
        grads = tape.gradient(d_loss, discriminator.trainable_variables)
        d_optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))

        with tf.GradientTape() as tape:
            g_loss = g_loss_fn(generator, discriminator, batch_z, is_training)
        grads = tape.gradient(g_loss, generator.trainable_variables)
        g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))

        if epoch % 100 == 0:
            print(epoch, 'd loss:', float(d_loss), 'g loss:', float(g_loss))

            # validation results at every epoch
            val_z = np.random.uniform(-1, 1, size=(val_size, z_dim))
            fake_image = generator(val_z, training=False)
            image_fn = os.path.join('images', 'gan-val-{:03d}.png'.format(epoch + 1))
            save_result(fake_image.numpy(), val_block_size, image_fn, color_mode='L')

4.5 生成图片结果展示

图4

参考文献:

1.《TensorFlow深度学习》——深入理解人工智能算法设计

2.https://my.oschina.net/u/778683/blog/3100336

3.https://github.com/limengqigithub/DCGAN-mnist-master.git

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 216,496评论 6 501
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 92,407评论 3 392
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 162,632评论 0 353
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,180评论 1 292
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,198评论 6 388
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,165评论 1 299
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,052评论 3 418
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,910评论 0 274
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,324评论 1 310
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,542评论 2 332
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,711评论 1 348
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,424评论 5 343
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 41,017评论 3 326
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,668评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,823评论 1 269
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,722评论 2 368
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,611评论 2 353

推荐阅读更多精彩内容