GANs很难?这篇文章教你50行代码搞定(PyTorch)

2018 年 5 月 12 日 全球人工智能 量子位

高薪招聘兼职AI讲师和AI助教!

加入高端数字货币投资者群!

量子位编译自Medium,作者Dev Nag,数据可视化分析平台Wavefront创始人、CTO,曾是Google、PayPal工程师。

2014年,Ian Goodfellow和他在蒙特利尔大学的同事们发表了一篇令人惊叹的论文,正式把生成对抗网络(GANs)介绍给全世界。通过把计算图和博弈论创新性的结合起来,GANs有能力让两个互相对抗的模型通过反向传播共同训练。

模型中有两个相互对抗的角色,我们分别称为G和D,简单解释如下:G是一个生成器,它试图通过学习真实数据集R,来创建逼真的假数据;D是鉴别器,从R和G处获得数据并标记差异。

Goodfellow有个很好的比喻:G是一个造假团队,试图造出跟真画一样的赝品;D是鉴定专家,试图找出真画和赝品的差异。当然在GANs的设定里,G是一群永远见不到真画的造假团队,他们能够获得的反馈只有D的鉴定意见。

在理想情况下,D和G都会随着时间的推移变得更好,直到G变成一个造假大师,最终让D无法区分出真画和赝品。实际上,Goodfellow已经表明G能够对原始数据集进行无监督学习,并且找到这些数据的低维表达方式。

这么厉害的技术,代码怎么也得一大堆吧?

并不是。使用刚刚发布的PyTorch,实际上可以只用不到50行代码,就能创建一个GAN。我们需要考虑的组件只有下面五个:

R:原始的真实数据集

I:作为熵源输入生成器的随机噪声

G:尝试复制/模仿原始数据集的生成器

D:尝试分辨G输出的鉴别器

训练循环:我们教G造假,再教D来鉴定

1)R: 我们将从最简单的R,一个钟形曲线开始。这个函数以平均值和标准偏差为参数,然后返回一个函数。在我们的示例代码中,使用了平均值4.0和标准差1.25。

2)I: 输入生成器的噪声也是随机的,但是为了增加点难度,我们使用了一个均匀分布,而不是正态分布。这意味着模型G不能简单地通过移动/缩放复制R,而必须以非线性的方式重塑数据。

3)G: 生成器是一个标准的前馈图,包含两个隐藏层,三个线性映射。在这里,我们使用了ELU(指数线性单位)。G将从I获得均匀分布的数据样本,并以某种方式模仿来自R的正态分布样本。

4)D: 鉴别器与生成器G的代码非常相似,都是有两个隐藏层和三个线性映射的前馈图。它将从R或G获取样本,并输出介于0和1之间的单个标量,0和1分别表示“假”和“真”。

5)训练循环 最后,训练循环在两种模式之间交替:首先,用带有准确标签的真实数据和假数据来训练D;然后,训练G来愚弄D。即使你从没用过PyTorch,也大致能看出发生了什么。在上图标为绿色的第一部分,我们将不同类型的数据输入D,并对D的猜测结果和实际的标签进行评判。这一步是“正向”的,然后我们用“反向”来计算梯度,并用它来更新d_optimizer step()调用的D参数。

上面,我们用到了G,但没有训练它。

在标为红色的下半部分中,我们对G做了同样的事情,注意:我们还会通过D来运行G的输出,相当于给了造假者一个侦探练习。但是在这一步中,我们不会对D进行优化或更改,因为我们不希望D学到错误的标签。因此,我们只调用g_optimizer.step()。

就这些啦,还有一些其他的样本代码,但是针对GAN的只有这五个组件。

对D和G进行几千轮训练之后,我们能得到什么?鉴别器D优化得很快,而G一开始优化得比较慢,不过,一旦到达了特定水平,G就开始迅速成长。

两万轮训练过后,G的输出的平均值超过4.0,但随后回到一个相当稳定,正确的范围(如左图)。同样,标准偏差最初在错误的方向下降,但随后上升到所要求的1.25范围(右图),与R相当。

所以,基本的统计最终与R相当,那么高阶矩如何呢?分布的形状是否正确?毕竟,你当然可以有一个平均值为4.0、标准差为1.25的均匀分布,但这不会真正与R相匹配。让我们看看G形成的最终分布。

还不错。左尾比右边稍微长了一点,但是我们可以说,它的偏斜和峰态符合原始的高斯函数。

G几乎完美还原了R的原始分布,而D独自在角落徘徊,无法分清真伪。这正是我们想要的结果。用不到50行的代码,就能实现。

来自:量子位

原文链接:http://baijiahao.baidu.com/s?id=1559201871701762&wfr=spider&for=pc

- 加入AI学院学习 -

点击“ 阅读原文 ”进入学习

登录查看更多
1

相关内容

最新《自动微分手册》77页pdf
专知会员服务
102+阅读 · 2020年6月6日
【CVPR2020】MSG-GAN:用于稳定图像合成的多尺度梯度GAN
专知会员服务
29+阅读 · 2020年4月6日
Transformer文本分类代码
专知会员服务
117+阅读 · 2020年2月3日
必读的10篇 CVPR 2019【生成对抗网络】相关论文和代码
专知会员服务
33+阅读 · 2020年1月10日
【书籍】深度学习框架:PyTorch入门与实践(附代码)
专知会员服务
167+阅读 · 2019年10月28日
万字综述之生成对抗网络(GAN)
PaperWeekly
43+阅读 · 2019年3月19日
一文读懂PyTorch张量基础(附代码)
数据派THU
6+阅读 · 2018年6月12日
【干货】深入理解自编码器(附代码实现)
GANs之父Ian Goodfellow力荐:GANs的谱归一化
论智
8+阅读 · 2017年11月25日
深度卷积对抗生成网络(DCGAN)实战
全球人工智能
14+阅读 · 2017年11月7日
GAN完整理论推导、证明与实现(附代码)
数据派THU
5+阅读 · 2017年10月6日
【原理】十个生成模型(GANs)的最佳案例和原理 | 代码+论文
GAN生成式对抗网络
8+阅读 · 2017年8月14日
Augmentation for small object detection
Arxiv
11+阅读 · 2019年2月19日
A Probe into Understanding GAN and VAE models
Arxiv
9+阅读 · 2018年12月13日
Arxiv
11+阅读 · 2018年3月23日
Arxiv
4+阅读 · 2018年3月23日
Arxiv
12+阅读 · 2018年1月12日
VIP会员
相关资讯
万字综述之生成对抗网络(GAN)
PaperWeekly
43+阅读 · 2019年3月19日
一文读懂PyTorch张量基础(附代码)
数据派THU
6+阅读 · 2018年6月12日
【干货】深入理解自编码器(附代码实现)
GANs之父Ian Goodfellow力荐:GANs的谱归一化
论智
8+阅读 · 2017年11月25日
深度卷积对抗生成网络(DCGAN)实战
全球人工智能
14+阅读 · 2017年11月7日
GAN完整理论推导、证明与实现(附代码)
数据派THU
5+阅读 · 2017年10月6日
【原理】十个生成模型(GANs)的最佳案例和原理 | 代码+论文
GAN生成式对抗网络
8+阅读 · 2017年8月14日
相关论文
Augmentation for small object detection
Arxiv
11+阅读 · 2019年2月19日
A Probe into Understanding GAN and VAE models
Arxiv
9+阅读 · 2018年12月13日
Arxiv
11+阅读 · 2018年3月23日
Arxiv
4+阅读 · 2018年3月23日
Arxiv
12+阅读 · 2018年1月12日
Top
微信扫码咨询专知VIP会员