TensorFlow动态图5行代码实现迁移学习 - 识别转变风格的MNIST

2019 年 4 月 26 日 专知

【导读】迁移学习的目的之一是提升机器学习模型的泛化能力。其中,Domain Adaption使得我们可以仅在源域标注数据,就可以获得在目标域上表现较好的模型。本文用TF动态图实现一种经典的Domain Adaption机制 — 梯度反转层。


本文介绍的算法来自ICML 2015的一篇经典论文《Unsupervised Domain Adaptation by Backpropagation》,目前被引用接近300次。其核心是特征领域分类器和梯度反转层。早期的TensorFlow静态图实现梯度反转层非常麻烦,而现在的TensorFlow动态图只需要几行代码就可以构建高可读性的梯度反转层。本文首先简单介绍该算法的原理,然后介绍如何用TensorFlow动态图(TensorFlow Eager模式)来实现该算法。


数据集和任务


Github上的许多实现版本以MNIST数据集为源域,通过为MNIST数据集添加随机背景的方式来自动构建目标域数据集:

为了简化,本文的代码直接使用MNIST数据集的反色版作为目标域数据集:


算法的任务,是希望仅利用有标注的MNIST数据集,和未标注的反转MNIST数据集(虽然有标签,但是没有使用),一起训练模型,使得在反转MNIST数据集上获得较好的分类效果。


算法原理


算法的框架图如下所示:


算法包含三个模块:

  • 编码器(绿色):从原始输入获得高阶特征

  • 分类器(蓝色):最终的分类任务

  • 领域分类器(红色):识别数据来自源域还是目标域


与生成对抗网络(GAN)的思想类似,编码器类似于GAN中的生成器,领域分类器类似于GAN中的判别器。编码器为了生成能够欺骗领域分类器的特征,会将两个领域的数据都编码为相似分布的数据。


与GAN中分阶段分别训练生成器和判别器的方式不同,该算法使用了梯度反转层(Gradient Reversal Layer,GRL)直接训练模型,GRL层在前向传播时等价于Identity层,即输出等于输入,而在反向传播时,GRL的梯度等于传入梯度的负,再乘上一个正系数。因此,GRL层之后(按照前向传播的方向看前后)的梯度是真实的梯度,而GRL之前的梯度是反转的梯度。如果用梯度下降优化模型,GRL之后的参数(领域分类器参数)会朝着降低领域分类损失的方向去优化,而GRL之前的参数(特征编码器参数)会朝着提升领域分类损失的方向去优化,即混淆两个领域编码后的特征分布。


TensorFlow动态图实现


我们将完整源码上传到了Github:

https://github.com/hujunxianligong/TF-GRL


这里我们介绍核心组件GRL的实现,在TensorFlow动态图(Eager)模式中,如果要定义一个可以自定义梯度的实现,要定义一个函数,用@tf.custom_gradient注解。函数的返回值格式为:[前向传播结果,梯度计算方法],其中,梯度计算方法也是一个函数,其接收反向传播过来的梯度,返回前向传播每个输入的梯度。因此,GRL层的实现只需要5行代码:

@tf.custom_gradient
def grl(x, alpha):
def grad(dy):
return -dy * alpha, None
return
x, grad


网络的编码器、分类器、领域分类器等的实现与普通MNIST分类器基本没有区别,完整代码请查看GITHUB链接:

https://github.com/hujunxianligong/TF-GRL


参考链接:

  • http://proceedings.mlr.press/v37/ganin15.pdf


-END-

专 · 知

专知,专业可信的人工智能知识分发,让认知协作更快更好!欢迎登录www.zhuanzhi.ai,注册登录专知,获取更多AI知识资料!

欢迎微信扫一扫加入专知人工智能知识星球群,获取最新AI专业干货知识教程视频资料和与专家交流咨询!

请加专知小助手微信(扫一扫如下二维码添加),加入专知人工智能主题群,咨询技术商务合作~

专知《深度学习:算法到实战》课程全部完成!530+位同学在学习,现在报名,限时优惠!网易云课堂人工智能畅销榜首位!

点击“阅读原文”,了解报名专知《深度学习:算法到实战》课程

登录查看更多
18

相关内容

MNIST 数据集来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST). 训练集 (training set) 由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50% 来自人口普查局 (the Census Bureau) 的工作人员. 测试集(test set) 也是同样比例的手写数字数据。
Sklearn 与 TensorFlow 机器学习实用指南,385页pdf
专知会员服务
129+阅读 · 2020年3月15日
《强化学习—使用 Open AI、TensorFlow和Keras实现》174页pdf
专知会员服务
136+阅读 · 2020年3月1日
【经典书】精通机器学习特征工程,中文版,178页pdf
专知会员服务
356+阅读 · 2020年2月15日
【模型泛化教程】标签平滑与Keras, TensorFlow,和深度学习
专知会员服务
20+阅读 · 2019年12月31日
零样本图像分类综述 : 十年进展
专知会员服务
126+阅读 · 2019年11月16日
基于TensorFlow和Keras的图像识别
Python程序员
16+阅读 · 2019年6月24日
博客 | 代码+论文+解析 | 7种常见的迁移学习
AI研习社
8+阅读 · 2019年4月25日
[机器学习] 用KNN识别MNIST手写字符实战
机器学习和数学
4+阅读 · 2018年5月13日
【教程】 在Keras上实现GAN:构建消除图片模糊的应用
GAN生成式对抗网络
4+阅读 · 2018年4月2日
TensorFlow图像分类教程
云栖社区
9+阅读 · 2017年12月29日
tensorflow系列笔记:流程,概念和代码解析
北京思腾合力科技有限公司
30+阅读 · 2017年11月11日
Arxiv
5+阅读 · 2020年3月17日
Multi-Grained Named Entity Recognition
Arxiv
6+阅读 · 2019年6月20日
Arxiv
5+阅读 · 2018年5月1日
Arxiv
6+阅读 · 2018年3月29日
VIP会员
相关VIP内容
Sklearn 与 TensorFlow 机器学习实用指南,385页pdf
专知会员服务
129+阅读 · 2020年3月15日
《强化学习—使用 Open AI、TensorFlow和Keras实现》174页pdf
专知会员服务
136+阅读 · 2020年3月1日
【经典书】精通机器学习特征工程,中文版,178页pdf
专知会员服务
356+阅读 · 2020年2月15日
【模型泛化教程】标签平滑与Keras, TensorFlow,和深度学习
专知会员服务
20+阅读 · 2019年12月31日
零样本图像分类综述 : 十年进展
专知会员服务
126+阅读 · 2019年11月16日
相关资讯
基于TensorFlow和Keras的图像识别
Python程序员
16+阅读 · 2019年6月24日
博客 | 代码+论文+解析 | 7种常见的迁移学习
AI研习社
8+阅读 · 2019年4月25日
[机器学习] 用KNN识别MNIST手写字符实战
机器学习和数学
4+阅读 · 2018年5月13日
【教程】 在Keras上实现GAN:构建消除图片模糊的应用
GAN生成式对抗网络
4+阅读 · 2018年4月2日
TensorFlow图像分类教程
云栖社区
9+阅读 · 2017年12月29日
tensorflow系列笔记:流程,概念和代码解析
北京思腾合力科技有限公司
30+阅读 · 2017年11月11日
Top
微信扫码咨询专知VIP会员