TensorFlow动态图5行代码实现迁移学习 - 识别转变风格的MNIST

【导读】迁移学习的目的之一是提升机器学习模型的泛化能力。其中,Domain Adaption使得我们可以仅在源域标注数据,就可以获得在目标域上表现较好的模型。本文用TF动态图实现一种经典的Domain Adaption机制 — 梯度反转层。


本文介绍的算法来自ICML 2015的一篇经典论文《Unsupervised Domain Adaptation by Backpropagation》,目前被引用接近300次。其核心是特征领域分类器和梯度反转层。早期的TensorFlow静态图实现梯度反转层非常麻烦,而现在的TensorFlow动态图只需要几行代码就可以构建高可读性的梯度反转层。本文首先简单介绍该算法的原理,然后介绍如何用TensorFlow动态图(TensorFlow Eager模式)来实现该算法。


数据集和任务


Github上的许多实现版本以MNIST数据集为源域,通过为MNIST数据集添加随机背景的方式来自动构建目标域数据集:

为了简化,本文的代码直接使用MNIST数据集的反色版作为目标域数据集:


算法的任务,是希望仅利用有标注的MNIST数据集,和未标注的反转MNIST数据集(虽然有标签,但是没有使用),一起训练模型,使得在反转MNIST数据集上获得较好的分类效果。


算法原理


算法的框架图如下所示:


算法包含三个模块:

  • 编码器(绿色):从原始输入获得高阶特征

  • 分类器(蓝色):最终的分类任务

  • 领域分类器(红色):识别数据来自源域还是目标域


与生成对抗网络(GAN)的思想类似,编码器类似于GAN中的生成器,领域分类器类似于GAN中的判别器。编码器为了生成能够欺骗领域分类器的特征,会将两个领域的数据都编码为相似分布的数据。


与GAN中分阶段分别训练生成器和判别器的方式不同,该算法使用了梯度反转层(Gradient Reversal Layer,GRL)直接训练模型,GRL层在前向传播时等价于Identity层,即输出等于输入,而在反向传播时,GRL的梯度等于传入梯度的负,再乘上一个正系数。因此,GRL层之后(按照前向传播的方向看前后)的梯度是真实的梯度,而GRL之前的梯度是反转的梯度。如果用梯度下降优化模型,GRL之后的参数(领域分类器参数)会朝着降低领域分类损失的方向去优化,而GRL之前的参数(特征编码器参数)会朝着提升领域分类损失的方向去优化,即混淆两个领域编码后的特征分布。


TensorFlow动态图实现


我们将完整源码上传到了Github:

https://github.com/hujunxianligong/TF-GRL


这里我们介绍核心组件GRL的实现,在TensorFlow动态图(Eager)模式中,如果要定义一个可以自定义梯度的实现,要定义一个函数,用@tf.custom_gradient注解。函数的返回值格式为:[前向传播结果,梯度计算方法],其中,梯度计算方法也是一个函数,其接收反向传播过来的梯度,返回前向传播每个输入的梯度。因此,GRL层的实现只需要5行代码:

@tf.custom_gradient
def grl(x, alpha):
def grad(dy):
return -dy * alpha, None
return
x, grad


网络的编码器、分类器、领域分类器等的实现与普通MNIST分类器基本没有区别,完整代码请查看GITHUB链接:

https://github.com/hujunxianligong/TF-GRL


参考链接:

  • http://proceedings.mlr.press/v37/ganin15.pdf


-END-

专 · 知

专知,专业可信的人工智能知识分发,让认知协作更快更好!欢迎登录www.zhuanzhi.ai,注册登录专知,获取更多AI知识资料!

欢迎微信扫一扫加入专知人工智能知识星球群,获取最新AI专业干货知识教程视频资料和与专家交流咨询!

请加专知小助手微信(扫一扫如下二维码添加),加入专知人工智能主题群,咨询技术商务合作~

专知《深度学习:算法到实战》课程全部完成!530+位同学在学习,现在报名,限时优惠!网易云课堂人工智能畅销榜首位!

点击“阅读原文”,了解报名专知《深度学习:算法到实战》课程

展开全文
Top
微信扫码咨询专知VIP会员