模型优化漫谈:BERT的初始标准差为什么是0.02?

2021 年 11 月 26 日 PaperWeekly


©PaperWeekly 原创 · 作者 | 苏剑林

单位 | 追一科技

研究方向 | NLP、神经网络


前几天在群里大家讨论到了“Transformer 如何解决梯度消失”这个问题,答案有提到残差的,也有提到 LN(Layer Norm)的。这些是否都是正确答案呢?事实上这是一个非常有趣而综合的问题,它其实关联到挺多模型细节,比如“BERT 为什么要 warmup?”、“BERT 的初始化标准差为什么是 0.02?”、“BERT 做 MLM预测之前为什么还要多加一层 Dense?”,等等。本文就来集中讨论一下这些问题。


梯度消失说的是什么意思?

在文章《也来谈谈 RNN 的梯度消失/爆炸问题》中,我们曾讨论过 RNN 的梯度消失问题。事实上,一般模型的梯度消失现象也是类似,它指的是(主要是在模型的初始阶段)越靠近输入的层梯度越小,趋于零甚至等于零,而我们主要用的是基于梯度的优化器,所以梯度消失意味着我们没有很好的信号去调整优化前面的层。

换句话说,前面的层也许几乎没有得到更新,一直保持随机初始化的状态;只有比较靠近输出的层才更新得比较好,但这些层的输入是前面没有更新好的层的输出,所以输入质量可能会很糟糕(因为经过了一个近乎随机的变换),因此哪怕后面的层更新好了,总体效果也不好。最终,我们会观察到很反直觉的现象:模型越深,效果越差,哪怕训练集都如此。

解决梯度消失的一个标准方法就是残差链接,正式提出于 ResNet [1] 中。残差的思想非常简单直接:你不是担心输入的梯度会消失吗?那我直接给它补上一个梯度为常数的项不就行了?最简单地,将模型变成

这样一来,由于多了一条“直通”路 ,就算 中的 梯度消失了, 的梯度基本上也能得以保留,从而使得深层模型得到有效的训练。



LN真的能缓解梯度消失?
然而,在 BERT 和最初的 Transformer 里边,使用的是 Post Norm 设计,它把 Norm 操作加在了残差之后:

其实具体的 Norm 方法不大重要,不管是 Batch Norm 还是 Layer Norm,结论都类似。在文章《浅谈 Transformer 的初始化、参数化与标准化》 [2] 中,我们已经分析过这种 Norm 结构,这里再来重复一下。
在初始化阶段,由于所有参数都是随机初始化的,所以我们可以认为 是两个相互独立的随机向量,如果假设它们各自的方差是 1,那么 的方差就是 2,而 操作负责将方差重新变为 1,那么在初始化阶段, 操作就相当于“除以 ”:

递归下去就是:

我们知道,残差有利于解决梯度消失,但是在 Post Norm 中,残差这条通道被严重削弱了,越靠近输入,削弱得越严重,残差“名存实亡”。所以说,在 Post Norm 的 BERT 模型中,LN 不仅不能缓解梯度消失,它还是梯度消失的“元凶”之一。


那我们为什么还要加LN

那么,问题自然就来了:既然 LN 还加剧了梯度消失,那直接去掉它不好吗?

是可以去掉,但是前面说了, 的方差就是 2 了,残差越多方差就越大了,所以还是要加一个 Norm 操作,我们可以把它加到每个模块的输入,即变为 ,最后的总输出再加个 就行,这就是 Pre Norm 结构,这时候每个残差分支是平权的,而不是像 Post Norm 那样有指数衰减趋势。
当然,也有完全不加 Norm 的,但需要对 进行特殊的初始化,让它初始输出更接近于 0,比如 ReZero、Skip Init、Fixup 等,这些在《浅谈 Transformer 的初始化、参数化与标准化》 [2] 也都已经介绍过了。

但是,抛开这些改进不说,Post Norm 就没有可取之处吗?难道 Transformer 和 BERT 开始就带了一个完全失败的设计?

显然不大可能。虽然 Post Norm 会带来一定的梯度消失问题,但其实它也有其他方面的好处。最明显的是,它稳定了前向传播的数值,并且保持了每个模块的一致性。比如 BERT base,我们可以在最后一层接一个 Dense 来分类,也可以取第 6 层接一个 Dense 来分类;但如果你是 Pre Norm 的话,取出中间层之后,你需要自己接一个 LN 然后再接 Dense,否则越靠后的层方差越大,不利于优化。

其次,梯度消失也不全是“坏处”,其实对于 Finetune 阶段来说,它反而是好处。在 Finetune 的时候,我们通常希望优先调整靠近输出层的参数,不要过度调整靠近输入层的参数,以免严重破坏预训练效果。而梯度消失意味着越靠近输入层,其结果对最终输出的影响越弱,这正好是 Finetune 时所希望的。所以,预训练好的 Post Norm 模型,往往比 Pre Norm 模型有更好的 Finetune 效果,这我们在《RealFormer:把残差转移到 Attention 矩阵上面去》也提到过。


我们真的担心梯度消失吗?

其实,最关键的原因是,在当前的各种自适应优化技术下,我们已经不大担心梯度消失问题了。

这是因为,当前 NLP 中主流的优化器是 Adam 及其变种。对于 Adam 来说,由于包含了动量和二阶矩校正,所以近似来看,它的更新量大致上为
可以看到,分子分母是都是同量纲的,因此分数结果其实就是 的量级,而更新量就是 量级。也就是说,理论上只要梯度的绝对值大于随机误差,那么对应的参数都会有常数量级的更新量;这跟 SGD 不一样,SGD 的更新量是正比于梯度的,只要梯度小,更新量也会很小,如果梯度过小,那么参数几乎会没被更新。
所以,Post Norm 的残差虽然被严重削弱,但是在 base、large 级别的模型中,它还不至于削弱到小于随机误差的地步,因此配合 Adam 等优化器,它还是可以得到有效更新的,也就有可能成功训练了。当然,只是有可能,事实上越深的 Post Norm 模型确实越难训练,比如要仔细调节学习率和 Warmup 等。


Warmup是怎样起作用的?

大家可能已经听说过,Warmup 是Transformer训练的关键步骤,没有它可能不收敛,或者收敛到比较糟糕的位置。为什么会这样呢?不是说有了Adam就不怕梯度消失了吗?

要注意的是,Adam 解决的是梯度消失带来的参数更新量过小问题,也就是说,不管梯度消失与否,更新量都不会过小。但对于 Post Norm 结构的模型来说,梯度消失依然存在,只不过它的意义变了。根据泰勒展开式:
也就是说增量 是正比于梯度的,换句话说,梯度衡量了输出对输入的依赖程度。如果梯度消失,那么意味着模型的输出对输入的依赖变弱了。

Warmup 是在训练开始阶段,将学习率从 0 缓增到指定大小,而不是一开始从指定大小训练。如果不进行 Wamrup,那么模型一开始就快速地学习,由于梯度消失,模型对越靠后的层越敏感,也就是越靠后的层学习得越快,然后后面的层是以前面的层的输出为输入的,前面的层根本就没学好,所以后面的层虽然学得快,但却是建立在糟糕的输入基础上的。

很快地,后面的层以糟糕的输入为基础到达了一个糟糕的局部最优点,此时它的学习开始放缓(因为已经到达了它认为的最优点附近),同时反向传播给前面层的梯度信号进一步变弱,这就导致了前面的层的梯度变得不准。但我们说过,Adam 的更新量是常数量级的,梯度不准,但更新量依然是数量级,意味着可能就是一个常数量级的随机噪声了,于是学习方向开始不合理,前面的输出开始崩盘,导致后面的层也一并崩盘。

所以,如果 Post Norm 结构的模型不进行 Wamrup,我们能观察到的现象往往是:loss 快速收敛到一个常数附近,然后再训练一段时间,loss 开始发散,直至 NAN。如果进行 Wamrup,那么留给模型足够多的时间进行“预热”,在这个过程中,主要是抑制了后面的层的学习速度,并且给了前面的层更多的优化时间,以促进每个层的同步优化。

这里的讨论前提是梯度消失,如果是 Pre Norm 之类的结果,没有明显的梯度消失现象,那么不加 Warmup 往往也可以成功训练。


初始标准差为什么是0.02?
喜欢扣细节的同学会留意到,BERT 默认的初始化方法是标准差为 0.02 的截断正态分布,在《浅谈 Transformer 的初始化、参数化与标准化》 [2] 我们也提过,由于是截断正态分布,所以实际标准差会更小,大约是 。这个标准差是大还是小呢?对于 Xavier 初始化来说,一个 的矩阵应该用 的方差初始化,而 BERT base 的 为 768,算出来的标准差是 。这就意味着,这个初始化标准差是明显偏小的,大约只有常见初始化标准差的一半。
为什么 BERT 要用偏小的标准差初始化呢?事实上,这还是跟 Post Norm 设计有关,偏小的标准差会导致函数的输出整体偏小,从而使得 Post Norm 设计在初始化阶段更接近于恒等函数,从而更利于优化。具体来说,按照前面的假设,如果 的方差是 的方差是 ,那么初始化阶段, 操作就相当于除以 。如果 比较小,那么残差中的“直路”权重就越接近于 1,那么模型初始阶段就越接近一个恒等函数,就越不容易梯度消失。
正所谓“我们不怕梯度消失,但我们也不希望梯度消失”,简单地将初始化标注差设小一点,就可以使得 变小一点,从而在保持 Post Norm 的同时缓解一下梯度消失,何乐而不为?那能不能设置得更小甚至全零?一般来说初始化过小会丧失多样性,缩小了模型的试错空间,也会带来负面效果。综合来看,缩小到标准的 1/2,是一个比较靠谱的选择了。
当然,也确实有人喜欢挑战极限的,最近笔者也看到了一篇文章,试图让整个模型用几乎全零的初始化,还训练出了不错的效果,大家有兴趣可以读读,文章为《ZerO Initialization: Initializing Residual Networks with only Zeros and Ones》 [3]


为什么MLM要多加Dense?

最后,是关于 BERT 的 MLM 模型的一个细节,就是 BERT 在做 MLM 的概率预测之前,还要多接一个 Dense 层和 LN 层,这是为什么呢?不接不行吗?

之前看到过的答案大致上是觉得,越靠近输出层的,越是依赖任务的(Task-Specified),我们多接一个 Dense 层,希望这个 Dense 层是 MLM-Specified 的,然后下游任务微调的时候就不是 MLM-Specified 的,所以把它去掉。这个解释看上去有点合理,但总感觉有点玄学,毕竟 Task-Specified 这种东西不大好定量分析。

这里笔者给出另外一个更具体的解释,事实上它还是跟 BERT 用了 0.02 的标准差初始化直接相关。刚才我们说了,这个初始化是偏小的,如果我们不额外加 Dense 就乘上 Embedding 预测概率分布,那么得到的分布就过于均匀了(Softmax 之前,每个 logit 都接近于 0),于是模型就想着要把数值放大。

现在模型有两个选择:第一,放大 Embedding 层的数值,但是 Embedding 层的更新是稀疏的,一个个放大太麻烦;第二,就是放大输入,我们知道 BERT 编码器最后一层是 LN,LN 最后有个初始化为 1 的 gamma 参数,直接将那个参数放大就好。

模型优化使用的是梯度下降,我们知道它会选择最快的路径,显然是第二个选择更快,所以模型会优先走第二条路。这就导致了一个现象:最后一个 LN 层的 gamma 值会偏大。如果预测 MLM 概率分布之前不加一个 Dense+LN,那么  BERT 编码器的最后一层的 LN 的 gamma 值会偏大,导致最后一层的方差会比其他层的明显大,显然不够优雅;而多加了一个 Dense+LN 后,偏大的 gamma 就转移到了新增的 LN 上去了,而编码器的每一层则保持了一致性。

事实上,读者可以自己去观察一下 BERT 每个 LN 层的 gamma 值,就会发现确实是最后一个 LN 层的 gamma 值是会明显偏大的,这就验证了我们的猜测~


希望大家多多海涵批评斧正
本文试图回答了 Transformer、BERT 的模型优化相关的几个问题,有一些是笔者在自己的预训练工作中发现的结果,有一些则是结合自己的经验所做的直观想象。不管怎样,算是分享一个参考答案吧,如果有不当的地方,请大家海涵,也请各位批评斧正。


参考文献

[1] https://arxiv.org/abs/1512.03385
[2] https://kexue.fm/archives/8620
[3] https://arxiv.org/abs/2110.12661


特别鸣谢

感谢 TCCI 天桥脑科学研究院对于 PaperWeekly 的支持。TCCI 关注大脑探知、大脑功能和大脑健康。



更多阅读




#投 稿 通 道#

 让你的文字被更多人看到 



如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。


总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 


PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。


📝 稿件基本要求:

• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注 

• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题

• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算


📬 投稿通道:

• 投稿邮箱:hr@paperweekly.site 

• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者

• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿


△长按添加PaperWeekly小编




🔍


现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧



·

登录查看更多
0

相关内容

在机器学习中,使用基于梯度的学习方法和反向传播训练人工神经网络时,会遇到梯度消失的问题。在这种方法中,每个神经网络的权值在每次迭代训练时都得到一个与误差函数对当前权值的偏导数成比例的更新。问题是,在某些情况下,梯度会极小,有效地阻止权值的改变。在最坏的情况下,这可能会完全阻止神经网络进一步的训练。作为问题原因的一个例子,传统的激活函数,如双曲正切函数的梯度在范围(0,1),而反向传播通过链式法则计算梯度。这样做的效果是将n个这些小数字相乘来计算n层网络中“前端”层的梯度,这意味着梯度(误差信号)随着n的增加呈指数递减,而前端层的训练非常缓慢。
专知会员服务
17+阅读 · 2021年8月6日
专知会员服务
28+阅读 · 2021年8月2日
专知会员服务
31+阅读 · 2021年7月19日
专知会员服务
61+阅读 · 2021年2月16日
专知会员服务
16+阅读 · 2020年7月27日
【ICML 2020 】小样本学习即领域迁移
专知会员服务
78+阅读 · 2020年6月26日
【伯克利】再思考 Transformer中的Batch Normalization
专知会员服务
41+阅读 · 2020年3月21日
听说Attention与Softmax更配哦~
PaperWeekly
0+阅读 · 2022年4月9日
多任务学习漫谈:分主次之序
PaperWeekly
0+阅读 · 2022年3月7日
多任务学习漫谈:行梯度之事
PaperWeekly
0+阅读 · 2022年2月18日
多任务学习漫谈:以损失之名
PaperWeekly
1+阅读 · 2022年1月26日
输入梯度惩罚与参数梯度惩罚的一个不等式
PaperWeekly
0+阅读 · 2021年12月27日
Dropout视角下的MLM和MAE:一些新的启发
PaperWeekly
1+阅读 · 2021年12月6日
激活函数还是有一点意思的!
计算机视觉战队
12+阅读 · 2019年6月28日
面试题:Word2Vec中为什么使用负采样?
七月在线实验室
46+阅读 · 2019年5月16日
BERT大火却不懂Transformer?读这一篇就够了
大数据文摘
11+阅读 · 2019年1月8日
详解常见的损失函数
七月在线实验室
20+阅读 · 2018年7月12日
国家自然科学基金
2+阅读 · 2015年12月31日
国家自然科学基金
3+阅读 · 2015年12月31日
国家自然科学基金
2+阅读 · 2015年12月31日
国家自然科学基金
1+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2014年12月31日
国家自然科学基金
3+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
1+阅读 · 2011年12月31日
国家自然科学基金
1+阅读 · 2011年12月31日
国家自然科学基金
0+阅读 · 2008年12月31日
Arxiv
1+阅读 · 2022年4月18日
Arxiv
0+阅读 · 2022年4月16日
Arxiv
0+阅读 · 2022年4月15日
Arxiv
21+阅读 · 2019年8月21日
How to Fine-Tune BERT for Text Classification?
Arxiv
13+阅读 · 2019年5月14日
Arxiv
11+阅读 · 2018年1月18日
VIP会员
相关VIP内容
专知会员服务
17+阅读 · 2021年8月6日
专知会员服务
28+阅读 · 2021年8月2日
专知会员服务
31+阅读 · 2021年7月19日
专知会员服务
61+阅读 · 2021年2月16日
专知会员服务
16+阅读 · 2020年7月27日
【ICML 2020 】小样本学习即领域迁移
专知会员服务
78+阅读 · 2020年6月26日
【伯克利】再思考 Transformer中的Batch Normalization
专知会员服务
41+阅读 · 2020年3月21日
相关资讯
听说Attention与Softmax更配哦~
PaperWeekly
0+阅读 · 2022年4月9日
多任务学习漫谈:分主次之序
PaperWeekly
0+阅读 · 2022年3月7日
多任务学习漫谈:行梯度之事
PaperWeekly
0+阅读 · 2022年2月18日
多任务学习漫谈:以损失之名
PaperWeekly
1+阅读 · 2022年1月26日
输入梯度惩罚与参数梯度惩罚的一个不等式
PaperWeekly
0+阅读 · 2021年12月27日
Dropout视角下的MLM和MAE:一些新的启发
PaperWeekly
1+阅读 · 2021年12月6日
激活函数还是有一点意思的!
计算机视觉战队
12+阅读 · 2019年6月28日
面试题:Word2Vec中为什么使用负采样?
七月在线实验室
46+阅读 · 2019年5月16日
BERT大火却不懂Transformer?读这一篇就够了
大数据文摘
11+阅读 · 2019年1月8日
详解常见的损失函数
七月在线实验室
20+阅读 · 2018年7月12日
相关基金
国家自然科学基金
2+阅读 · 2015年12月31日
国家自然科学基金
3+阅读 · 2015年12月31日
国家自然科学基金
2+阅读 · 2015年12月31日
国家自然科学基金
1+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2014年12月31日
国家自然科学基金
3+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
1+阅读 · 2011年12月31日
国家自然科学基金
1+阅读 · 2011年12月31日
国家自然科学基金
0+阅读 · 2008年12月31日
相关论文
Top
微信扫码咨询专知VIP会员