【干货】动手实践:理解和优化GAN(附代码)

【导读】本文是机器学习研究员Mirantha Jayathilaka撰写的一篇技术博文,主要讲解了生成对抗网络(GAN)。本文分别从理论和代码实践两方面来介绍GAN,首先介绍了生成器和判别器的概念及其工作原理,然后分别从构建生成器模型,构建判别器模型,选择损失并训练,选择不同的优化算法进行训练等方面讲解代码,文末附有作者的完整代码和数据集代码,感兴趣的读者可以学习一下。


Understanding and optimizing GANs

 (Going back to first principles)

理解和优化GAN

自从Ian Goodfellow首次引入架构以来,生成对抗网络(GAN)的研究一直在不断增加,许多相关的进步和应用变得越来越引人入胜。但对于任何想要开始使用GAN的人来说,如何开始是非常棘手的。这篇文章将引导你如何开始使用。


与许多事情一样,充分理解它的概念的最好方法就是溯源。对于GAN,这里是原论文 - >(https://arxiv.org/abs/1406.2661)。现在理解这类论文可以有两种方法,理论和实践。我通常喜欢后者,但是如果你想深入挖掘数学理解,这是一篇很棒的文章。同时,这篇文章将以Keras的最纯粹的形式介绍GAN的一种简单的算法实现。让我们开始吧。


数学理解GAN: 

https://medium.com/@samramasinghe/generative-adversarial-networks-a-theoretical-walk-through-5889d5a8f2bb


在GAN的基础设置中,有两个模型,即生成器和判别器,其中生成器不断与判别器竞争,判别器区分模型分布(例如生成的假图像)和数据的分布(例如真实图像)的差别。这个概念可以通过著名的伪造者与警察场景来形象化,其中生成模型被认为是伪造者生产假币,判别器模型作为试图找出假币的警察。这个想法是,由于彼此之间不断的竞争,造假者和警察都提升了自己的业务水平,但最终造假者实现了生产假币和真币一样的水平。原理很简短,现在让我们把它放到代码中。


本文提供的示例脚本用于生成伪造的脸部图像。图1显示了我们试图用算法实现的最终结果。

构建生成模型




生成模型应该会吸收一些噪音并输出令人满意的外观图像。在这里,我们使用Keras Sequential模型以及Dense(全连接)和Batch Normalization层。使用的Activation(激活函数)是Leaky Relu。请参阅下面的代码片段。生成模型可以分成几个区块。一个块由Dense层 - >激活 - >Batch Normalization组成。添加了三个这样的块,最后一个块将像素转换为我们期望的图像的期望形状作为输出。模型的输入将是形状(100,)的噪声矢量,并在最后输出模型。注意每个Dense层中的节点随着模型的进展而增加。

def build_generator(self):
noise_shape = (100,)

model = Sequential()

model.add(Dense(256, input_shape=noise_shape))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(1024))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(np.prod(self.img_shape), activation='tanh'))
model.add(Reshape(self.img_shape))

model.summary()

noise = Input(shape=noise_shape)
img = model(noise)

return Model(noise, img)


建立判别器模型



    

判别器接收图像的输入,将其平滑并通过两个Dense- >Activation块,最终输出1和0之间的标量。输出1应表示输入图像是真实的,否则为0。 请参阅下面的代码。

def build_discriminator(self):
img_shape = (self.img_rows, self.img_cols, self.channels)

model = Sequential()

model.add(Flatten(input_shape=img_shape))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(256))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(1, activation='sigmoid'))
model.summary()

img = Input(shape=img_shape)
validity = model(img)

return Model(img, validity)


注意 您可以修改这些模型,以获得更多的块,更多的Batch Normalization层,不同的激活函数等。按照这个例子,这些模型足以理解GAN背后的概念。

找出损失并训练




我们计算三中损失,在这个例子中全部使用二分类交叉熵来训练这两个模型。

首先是判别器。 如下面的代码所示,它训练了两种方式。 首先为真实图像输出1(数组'img'),然后为生成的图像输出0(数组'gen_img')。 随着训练的进展,辨别器在此任务中得到改进。 但是我们的最终目标是在鉴别器对两种输入类型输出0.5的理论点(无法判断真假)。

d_loss_real = self.discriminator.train_on_batch(imgs, np.ones((half_batch, 1)))
d_loss_fake = self.discriminator.train_on_batch(gen_imgs, np.zeros((half_batch, 1)))


接下来是训练生成模型,这是棘手的一点。 要做到这一点,我们首先将生成器模型和判别器模型组合起来,用判别器的输出处理生成器模型的输出。 记得! 理想情况下,我们希望这是1,这意味着鉴别器将假造图像识别为真实图像。请参阅下面的代码。

z = Input(shape=(100,))
img = self.generator(z)

valid = self.discriminator(img)

self.combined = Model(z, valid)
self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

g_loss = self.combined.train_on_batch(noise, valid_y)


现在让我们来调教代码吧




这是代码的要点,以简单理解GAN的运作。


完整的代码可以在GitHub上找到。

https://github.com/miranthajayatilake/GANwKeras


您可以参考所有用于导入RGB图像的附加代码,初始化模型并将结果记录在代码中。请注意,在训练期间,为了能够在CPU上运行,将Mini batches设置为Hi32映像。 此外,本例中使用的真实图像是来自CelebA数据集的5000张图像。这是一个开源数据集,我已经将它上传到Floydhub,以便下载,您可以在这里找到。 

https://www.floydhub.com/mirantha/datasets/celeba


有很多方法可以优化代码以获得更好的结果,并且这样可以帮助你了解算法的不同组件如何影响模型。在调整优化器,激活函数,归一化,损失损失函数,超参数等不同组件的同时观察结果是增强对算法理解的最佳方法。这里 我选择改变优化器。


因此,用32 batches训练5000 epochs,我使用三种优化算法进行了测试。使用Keras这个过程就像导入和替换优化器函数的名称一样简单。 Keras内置的所有优化器都可以在这里找到。 


此外,在每个实例中绘制的损失用于理解模型的行为。


1. 使用SGD(随机梯度下降优化器)。输出和损失变化分别如图2和3所示。

注意:虽然收敛是不平稳的,但我们可以在这里看到,生成器损失在epochs时期减少,这意味着鉴别器倾向于将假图像检测为真实。


2.使用RMSProp优化器。 输出和损失变化分别如图4和5所示。

损失:

注意:在这里,我们也看到生成模型损失在减少,这是一件好事。 令人惊讶的是,真实图像上的判别器损失增加,这非常有趣。


3. 使用Adam优化器。 输出和损耗变化分别如图6和图7所示。


注意: adam优化器产生迄今为止最好的结果。 请注意,假图像上的鉴别器损失保留较大的值,这意味着鉴别器倾向于将假图像检测为真实。

 

完整代码:

https://github.com/miranthajayatilake/GANwKeras


图像数据:

https://www.floydhub.com/mirantha/datasets/celeba


原文链接:

https://towardsdatascience.com/understanding-and-optimizing-gans-going-back-to-first-principles-e5df8835ae18

-END-

专 · 知

人工智能领域主题知识资料查看获取【专知荟萃】人工智能领域26个主题知识资料全集(入门/进阶/论文/综述/视频/专家等)

请PC登录www.zhuanzhi.ai或者点击阅读原文,注册登录专知,获取更多AI知识资料

请扫一扫如下二维码关注我们的公众号,获取人工智能的专业知识!

请加专知小助手微信(Rancho_Fang),加入专知主题人工智能群交流!加入专知主题群(请备注主题类型:AI、NLP、CV、 KG等)交流~


点击“阅读原文”,使用专知

展开全文
Top
微信扫码咨询专知VIP会员