多任务学习漫谈:以损失之名

2022 年 1 月 26 日 PaperWeekly


©PaperWeekly 原创 · 作者 | 苏剑林

单位 | 追一科技

研究方向 | NLP、神经网络


能提升模型性能的方法有很多,多任务学习(Multi-Task Learning)也是其中一种。简单来说,多任务学习是希望将多个相关的任务共同训练,希望不同任务之间能够相互补充和促进,从而获得单任务上更好的效果(准确率、鲁棒性等)。然而,多任务学习并不是所有任务堆起来就能生效那么简单,如何平衡每个任务的训练,使得各个任务都尽量获得有益的提升,依然是值得研究的课题。

最近,笔者机缘巧合之下,也进行了一些多任务学习的尝试,借机也学习了相关内容,在此挑部分结果与大家交流和讨论。


加权求和

从损失函数的层面看,多任务学习就是有多个损失函数 ,一般情况下它们有大量的共享参数、少量的独立参数,而我们的目标是让每个损失函数都尽可能地小。为此,我们引入权重 ,通过加权求和的方式将它转化为如下损失函数的单任务学习:
在这个视角下,多任务学习的主要难点就是如何确定各个 了。


初始状态
按道理,在没有任务先验和偏见的情况下,最自然的选择就是平等对待每个任务,即 。然而,事实上每个任务可能有很大差别,比如不同类别数的分类任务混合、分类与回归任务混合、分类与生成任务混合等等,从物理的角度看,每个损失函数的量纲和量级都不一样,直接相加是没有意义的。
如果我们将每个损失函数看成具有不同量纲的物理量,那么从“无量纲化”的思想出发,我们可以用损失函数的初始值倒数作为权重,即

其中 表示任务 的初始损失值。该式关于每个 是“齐次”的,所以它的一个明显优点是缩放不变性,即如果让任务 的损失乘上一个常数,那么结果不会变化。此外,由于每个损失都除以了自身的初始值,较大的损失会缩小,较小的损失会放大,从而使得每个损失能够大致得到平衡。
那么,怎么估计 呢?最直接的方法当然是直接拿几个 batch 的数据来估算一下。除此之外,我们可以基于一些假设得到一个理论值。比如,在主流的初始化之下,我们可以认为初始模型(加激活函数之前)的输出是一个零向量,如果加上 softmax 则是均匀分布,那么对于一个“ 分类+交叉熵”问题,它的初始损失就是 ;对于“回归+ L2 损失”问题,则可以用零向量来估计初始损失,即 是训练集的全体标签。


先验状态
用初始损失的一个问题是初始状态不一定能很好地反应当前任务的学习难度,更好的方案应该是将“初始状态”改为“先验状态”:

比如,如果 分类中每个类的频率分别是 (先验分布),那么虽然初始状态的预测分布为均匀分布,但我们可以合理地认为模型可以很容易学会将每个样本的结果都预测为 ,此时模型的损失为熵

某种意义上来说,“先验分布”比“初始分布”更能体现出“初始”的本质,它是“就算模型啥都学不会,也知道按照先验分布来随机出结果”的体现,所以此时的损失值更能代表当前任务的初始难度,因此用 代替 应该更加合理;类似地,对于“回归+L2损失”问题,它的先验结果应该是全体标签的期望 ,所以我们用 代替 ,有望取得更合理的结果。


动态调节
不管是用初始状态的式(2)还是先验状态的式(3),它们的任务权重在确定之后就保持不变了,并且它们确定权重的方法不依赖于学习过程。然而,尽管我们可以通过先验分布等信息简单感知一下学习难度,但究竟有多难其实要真正去学习才知道,所以更合理的方案应该是根据训练进程动态地调整权重。


实时状态
纵观前文,式(2)和式(3)的核心思想都是用损失值的倒数来作为任务权重,那么能不能干脆用“实时”的损失值倒数来实现动态调整权重?即:

这里的 的简写。在这个方案中,每个任务的损失函数都被调整恒为 1,所以不管是量纲还是量级上都是一致的。由于 算子的存在,虽然损失恒为 1,但梯度并非恒为 0:

简单来说就是加上 算子后,它的值不变,但是导数为 0,所以最终结果就是以动态权重 来实时调整了梯度的比例。很多“民间实验”表明,式(5)确实在多数情况下都可以作为一个相当不错的 baseline。


等价梯度
我们可以从另一个角度来看该方案。从式(6)我们可以得到:

因此从梯度上看,式(5)与 没有实质区别,而我们进一步有:

由于 是单调递增的,所以式(5)与下式在梯度方向上是一致:




广义平均
显然,上式正是 的“几何平均”,而如果我们约定 恒等于 ,那么原始的式(1)就是 的“代数平均”。也就是说,我们发现这一系列的推导其实隐藏了从代数平均到几何平均的转变,这启发我们或许可以考虑“广义平均”:

也就是将每个损失函数算 次方后再平均最后再开 次方,这里的 可以是任意实数,代数平均对应 ,而几何平均对应 (需要取极限)。可以证明, 是关于 的单调递增函数,并且有:

这就意味着,当 增大时,模型愈发关心损失中的最大值,反之则更关心损失中的最小值。这样一来,虽然依然存在超参数 要调整,但是相比于原始的式(1),超参数的个数已经从 个变为只有 1 个,简化了调参过程。


平移不变
重新回顾式(2)、式(3)和式(5),它们都是通过每个任务损失除以自身的某个状态来调节权重,并且获得了缩放不变性。然而,尽管它们都具备了缩放不变性,但却失去了更基本的“平移不变性”,也就是说,如果每个损失都加上一个常数,(2)、式(3)和式(5)的梯度方向是有可能改变的,这对于优化来说并不是一个好消息,因为原则上来说常数没有带来任何有意义的信息,优化结果不应该随之改变。


理想目标
一方面,我们用损失函数(的某个状态)的倒数作为当前任务的权重,但损失函数的导数不具备平移不变性;另一方面,损失函数可以理解为当前模型与目标状态的距离,而梯度下降本质上是在寻找梯度为 0 的点,所以梯度的模长其实也能起到类似作用,因此我们可以用梯度的模长来替换掉损失函数,从而将式(5)变成:

跟损失函数的一个明显区别是,梯度模长显然具备平移不变性,并且分子分母关于 依然是齐次的,所以上式还保留了缩放不变性。因此,这是一个能同时具备平移和缩放不变性的理想目标。


梯度归一
对式(12)求梯度,我们得到:

可以看到,式(12)本质上是将每个任务损失的梯度进行归一化后再把梯度累加起来。它同时也告诉了我们一种实现方案,即可以让每个任务依次训练,每次只训练一个任务,然后将每个任务的梯度归一化后累积起来再更新,这样就免除了在定义损失函数的时候就要算梯度的麻烦了。

关于梯度归一化,笔者能找到相关工作是《GradNorm: Gradient Normalization for Adaptive Loss Balancing in Deep Multitask Networks》 [1] ,它本质上是式(2)和式(13)的混合,里边也包含了对梯度模长重新标定的思想,但却要通过额外的优化来确定任务权重,个人认为显得繁琐和冗余了。


本文小结
在损失函数的视角下,多任务学习的关键问题是如何调节每个任务的权重来平衡各自的损失,本文从缩放不变和平移不变两个角度介绍了一些参考做法,并补充了“广义平均”的概念,将多个任务的权重调节转化为单个参数的调节问题,可以简化调参难度。

参考文献

[1] https://arxiv.org/abs/1711.02257




特别鸣谢

感谢 TCCI 天桥脑科学研究院对于 PaperWeekly 的支持。TCCI 关注大脑探知、大脑功能和大脑健康。




更多阅读




#投 稿 通 道#

 让你的文字被更多人看到 



如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。


总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 


PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。


📝 稿件基本要求:

• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注 

• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题

• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算


📬 投稿通道:

• 投稿邮箱:hr@paperweekly.site 

• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者

• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿


△长按添加PaperWeekly小编




🔍


现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧



·

登录查看更多
1

相关内容

多任务学习(MTL)是机器学习的一个子领域,可以同时解决多个学习任务,同时利用各个任务之间的共性和差异。与单独训练模型相比,这可以提高特定任务模型的学习效率和预测准确性。多任务学习是归纳传递的一种方法,它通过将相关任务的训练信号中包含的域信息用作归纳偏差来提高泛化能力。通过使用共享表示形式并行学习任务来实现,每个任务所学的知识可以帮助更好地学习其它任务。
【博士论文】多任务学习视觉场景理解,140页pdf
专知会员服务
90+阅读 · 2022年4月5日
【ICLR2022】基于任务相关性的元学习泛化边界
专知会员服务
18+阅读 · 2022年2月8日
【NeurIPS2021】序一致因果图的多任务学习
专知会员服务
19+阅读 · 2021年11月7日
专知会员服务
14+阅读 · 2021年7月24日
专知会员服务
16+阅读 · 2021年7月13日
专知会员服务
22+阅读 · 2021年6月22日
《多任务学习》最新综述论文,20页pdf
专知会员服务
123+阅读 · 2021年4月6日
最新《多任务学习》综述,39页pdf
专知会员服务
264+阅读 · 2020年7月10日
多任务学习漫谈:分主次之序
PaperWeekly
0+阅读 · 2022年3月7日
多任务学习漫谈:行梯度之事
PaperWeekly
0+阅读 · 2022年2月18日
模型优化漫谈:BERT的初始标准差为什么是0.02?
PaperWeekly
0+阅读 · 2021年11月26日
标签间相关性在多标签分类问题中的应用
人工智能前沿讲习班
22+阅读 · 2019年6月5日
详解常见的损失函数
七月在线实验室
20+阅读 · 2018年7月12日
迁移学习之Domain Adaptation
全球人工智能
18+阅读 · 2018年4月11日
从最大似然到EM算法:一致的理解方式
PaperWeekly
18+阅读 · 2018年3月19日
国家自然科学基金
1+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2014年12月31日
国家自然科学基金
5+阅读 · 2013年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
3+阅读 · 2012年12月31日
国家自然科学基金
2+阅读 · 2011年12月31日
Arxiv
21+阅读 · 2020年10月11日
Arxiv
11+阅读 · 2018年1月18日
VIP会员
相关VIP内容
【博士论文】多任务学习视觉场景理解,140页pdf
专知会员服务
90+阅读 · 2022年4月5日
【ICLR2022】基于任务相关性的元学习泛化边界
专知会员服务
18+阅读 · 2022年2月8日
【NeurIPS2021】序一致因果图的多任务学习
专知会员服务
19+阅读 · 2021年11月7日
专知会员服务
14+阅读 · 2021年7月24日
专知会员服务
16+阅读 · 2021年7月13日
专知会员服务
22+阅读 · 2021年6月22日
《多任务学习》最新综述论文,20页pdf
专知会员服务
123+阅读 · 2021年4月6日
最新《多任务学习》综述,39页pdf
专知会员服务
264+阅读 · 2020年7月10日
相关资讯
多任务学习漫谈:分主次之序
PaperWeekly
0+阅读 · 2022年3月7日
多任务学习漫谈:行梯度之事
PaperWeekly
0+阅读 · 2022年2月18日
模型优化漫谈:BERT的初始标准差为什么是0.02?
PaperWeekly
0+阅读 · 2021年11月26日
标签间相关性在多标签分类问题中的应用
人工智能前沿讲习班
22+阅读 · 2019年6月5日
详解常见的损失函数
七月在线实验室
20+阅读 · 2018年7月12日
迁移学习之Domain Adaptation
全球人工智能
18+阅读 · 2018年4月11日
从最大似然到EM算法:一致的理解方式
PaperWeekly
18+阅读 · 2018年3月19日
相关基金
国家自然科学基金
1+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2014年12月31日
国家自然科学基金
5+阅读 · 2013年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
3+阅读 · 2012年12月31日
国家自然科学基金
2+阅读 · 2011年12月31日
Top
微信扫码咨询专知VIP会员