BERT用的LayerNorm可能不是你认为的那个Layer Norm?

2022 年 9 月 5 日 PaperWeekly


©作者 | 王坤泽

单位 | 悉尼大学

研究方向 | NLP


有关 Batch norm 和 Layer norm 的比较可以算上是算法领域的八股文了,为什么 BERT 不用 batch norm 而用 layer norm 的问题都被问烂了,知乎上随便一搜都有很多人讲解 BN 和 LN 的区别。通常来说大家都会给这张图:

▲ BN vs LN

大家会说,针对 CV 和 NLP 两种问题,这里的三个维度表示的信息不同:


如果只看 NLP 问题,假设我们的 batch 是(2,3,4)的,也就是 batch_size = 2, seq_length = 3, dim = 4 的,假设第一个句子是 w1 w2 w3,第二个句子是 w4 w5 w6,那么这个 tensor 可以写为:

[[w11, w12, w13, w14], [w21, w22, w23, w24], [w31, w32, w33, w34]
[w41, w42, w43, w44], [w51, w52, w53, w54], [w61, w62, w63, w64]]


我们发现,如果是 BN 的话,会对同一个 batch 里对应位置上的 token 求平均值,也就是说 (w11+w12+w13+w14+w41+w42+w43+w44)/8是其中一个 mean,一共会求出 3 个 mean,也就是上图里 C 个(seq_length)个 mean。

但是如果是 LN 的话,看起来是对每个 sample 里的所有 feature 求 mean,也就是(w11+w12+w13+w14+w21+w22+w23+w24+w31+w32+w33+w34)/12,可以求出一共 2 个 mean,也就是图里 N(batch_size)个 mean。

我一直对这个计算深信不疑,认为 BERT 里也是这样的实现,但是有一天我在这个回答看到了 @猛猿  的这个回答:为什么 Transformer 要用 LayerNorm? [1]  其中作者给出了两张图:


▲ 都是 Layer norm 但是却不一样

左图和我们认为的 LN 一致,也是我一直认为的 LN,但是右图却是在一个 token 上求平均,带回我们原来的问题,对于一个(2,3,4)的 tensor,(w11+w12+w13+w14)/4 是一个 mean,一共会有 2*3=6 个 mean。

那到底,BERT 里是 batch_size个mean(左图的计算方法),还是 batch_size*seq_length 个 mean(右图的计算方法)呢?我们得看看源码。

BERT 或者说 transformer encoder 的 pytorch 源码比较著名的应该是 torch 自带的 transformer encoder 和 hugging face 自己写的,我们一个个看。

# torch.nn.TransformerEncoderLayer
# https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/transformer.py
# 412行
self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)

# huggingface bert_model
# https://github.com/huggingface/transformers/blob/3223d49354e41dfa44649a9829c7b09013ad096e/src/transformers/models/bert/modeling_bert.py#L378
# 382行
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

可以看到,无论是火炬自带还是捧着脸复现的 transformer encoder 或者叫 bert layer,里面用的都是 torch 自己的 nn.LayerNorm,并且参数都是对应为 768 的 hidden dimension(变形金刚把它叫做 d_model,波特把它叫做 hidden_size)。

那我们看看 nn.LayerNorm(dim) 是一个什么效果,以下代码修改自 Understanding torch.nn.LayerNorm in nlp [2]


import torch

batch_size, seq_size, dim = 234
embedding = torch.randn(batch_size, seq_size, dim)

layer_norm = torch.nn.LayerNorm(dim, elementwise_affine = False)
print("y: ", layer_norm(embedding))

eps: float = 0.00001
mean = torch.mean(embedding[:, :, :], dim=(-1), keepdim=True)
var = torch.square(embedding[:, :, :] - mean).mean(dim=(-1), keepdim=True)

print("mean: ", mean.shape)
print("y_custom: ", (embedding[:, :, :] - mean) / torch.sqrt(var + eps))


在以上代码中,我先生成了一个 emb,然后使用 nn.LayerNorm(dim) 计算它 layer nrom 后的结果,同时,我手动计算了一个在最后一维上的 mean(也就是说我的 mean 的维度是 2*3,也就是一共 6 个 mean),如果这样算出来的结果和我调 nn.LayerNorm(dim) 一致,那就说明,nn.LayerNorm(dim) 会给我们 (batch_size*seq_length) 个 mean,也就是刚才上图里右边的方法。计算后结果如下:

y:  tensor([[[-0.2500,  1.0848,  0.6808-1.5156],
         [-1.1630-0.7052,  1.3840,  0.4843],
         [-1.3510,  0.4520-0.4354,  1.3345]],

        [[ 0.4372-0.4610,  1.3527-1.3290],
         [ 0.2282,  1.3853-0.2037-1.4097],
         [-0.9960-0.6184-0.0059,  1.6203]]])
mean:  torch.Size([231])
y_custom:  tensor([[[-0.2500,  1.0848,  0.6808-1.5156],
         [-1.1630-0.7052,  1.3840,  0.4843],
         [-1.3510,  0.4520-0.4354,  1.3345]],

        [[ 0.4372-0.4610,  1.3527-1.3290],
         [ 0.2282,  1.3853-0.2037-1.4097],
         [-0.9960-0.6184-0.0059,  1.6203]]])


确实一致,也就是说, 至少在 torch 自带和 hugging face 复现的 bert 里,layernorm 实际上和右图一致是对每个 token 的 feature 单独求 mean

那么如果我们想像左图里求出 batch_size 个 mean,怎么用 nn.LayerNorm 实现呢?只需要修改 nn.LayerNorm 的参数为 nn.LayerNorm([seq_size,dim]) 即可,代码如下,大家可以跑一下,发现这样和求 batch_size 个 mean 是一致的:


import torch

batch_size, seq_size, dim = 234
embedding = torch.randn(batch_size, seq_size, dim)

layer_norm = torch.nn.LayerNorm([seq_size,dim], elementwise_affine = False)
print("y: ", layer_norm(embedding))

eps: float = 0.00001
mean = torch.mean(embedding[:, :, :], dim=(-2,-1), keepdim=True)
var = torch.square(embedding[:, :, :] - mean).mean(dim=(-2,-1), keepdim=True)

print("mean: ", mean.shape)
print("y_custom: ", (embedding[:, :, :] - mean) / torch.sqrt(var + eps))

最后一个问题,按图右这么求,那岂不是和 InstanceNorm 一样了吗?同样我做了一个代码实验:


from torch.nn import InstanceNorm2d
instance_norm = InstanceNorm2d(3, affine=False)
x = torch.randn(234)
output = instance_norm(x.reshape(2,3,4,1)) #InstanceNorm2D需要(N,C,H,W)的shape作为输入
print(output.reshape(2,3,4))

layer_norm = torch.nn.LayerNorm(4, elementwise_affine = False)
print(layer_norm(x))

可以跑一下,发现确实是一致的。

结论:BERT 里的 layernorm 在 torch 自带的  transformer encoder 和 hugging face 复现的 bert 里,实际上都是在做 InstanceNorm。

那么,最开始 Vaswani 在 attention is all you need 里提出的使用 layernorm 是什么呢?tf.tensor2tensor 的作者也是 Vaswani,那么我认为 tf.tensor2tensor 应该是符合作者最初的源码设计的,通过翻阅源码(看了无数的文件,大家可以试试,真的很多,各种 function 封装...),我确认了作者自己的代码里的 layernorm 使用的参数也是最后一维的 dimension,那么也就是说, 原作者本质上也是用的 InstanceNorm

最后想问问,InstanceNorm 是 LayerNorm 的一种吗?为啥我没看到相关的说法?


参考文献

[1]  https://www.zhihu.com/question/487766088/answer/2309239401
[2]  h ttps://stackoverflow.com/questions/70065235/understanding-torch-nn-layernorm-in-nlp


更多阅读



#投 稿 通 道#

 让你的文字被更多人看到 



如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。


总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 


PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。


📝 稿件基本要求:

• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注 

• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题

• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算


📬 投稿通道:

• 投稿邮箱:hr@paperweekly.site 

• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者

• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿


△长按添加PaperWeekly小编




🔍


现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧


·

登录查看更多
0

相关内容

代码注释最详细的Transformer
专知会员服务
110+阅读 · 2022年6月30日
【NeurIPS 2020】依图推出预训练语言理解模型ConvBERT
专知会员服务
11+阅读 · 2020年11月13日
【NeurIPS 2020】融入BERT到并行序列模型
专知会员服务
25+阅读 · 2020年10月15日
专知会员服务
123+阅读 · 2020年9月8日
专知会员服务
15+阅读 · 2020年7月27日
【伯克利】再思考 Transformer中的Batch Normalization
专知会员服务
40+阅读 · 2020年3月21日
Transformer文本分类代码
专知会员服务
116+阅读 · 2020年2月3日
恕我直言,你们的模型训练都还不够快
极市平台
2+阅读 · 2022年5月2日
为什么Pre Norm的效果不如Post Norm?
PaperWeekly
0+阅读 · 2022年5月1日
从ICLR 2022看什么是好的图神经网络?
PaperWeekly
0+阅读 · 2022年2月18日
模型优化漫谈:BERT的初始标准差为什么是0.02?
PaperWeekly
0+阅读 · 2021年11月26日
无监督分词和句法分析!原来BERT还可以这样用
PaperWeekly
12+阅读 · 2020年6月17日
如何区分并记住常见的几种 Normalization 算法
极市平台
19+阅读 · 2019年7月24日
BERT源码分析PART I
AINLP
38+阅读 · 2019年7月12日
BERT大火却不懂Transformer?读这一篇就够了
大数据文摘
11+阅读 · 2019年1月8日
GAN的数学原理
算法与数学之美
14+阅读 · 2017年9月2日
国家自然科学基金
0+阅读 · 2013年12月31日
国家自然科学基金
0+阅读 · 2013年12月31日
国家自然科学基金
0+阅读 · 2013年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
1+阅读 · 2011年12月31日
国家自然科学基金
0+阅读 · 2009年12月31日
Arxiv
0+阅读 · 2022年11月25日
Arxiv
0+阅读 · 2022年11月23日
Heterogeneous Deep Graph Infomax
Arxiv
12+阅读 · 2019年11月19日
Arxiv
11+阅读 · 2019年6月19日
VIP会员
相关VIP内容
代码注释最详细的Transformer
专知会员服务
110+阅读 · 2022年6月30日
【NeurIPS 2020】依图推出预训练语言理解模型ConvBERT
专知会员服务
11+阅读 · 2020年11月13日
【NeurIPS 2020】融入BERT到并行序列模型
专知会员服务
25+阅读 · 2020年10月15日
专知会员服务
123+阅读 · 2020年9月8日
专知会员服务
15+阅读 · 2020年7月27日
【伯克利】再思考 Transformer中的Batch Normalization
专知会员服务
40+阅读 · 2020年3月21日
Transformer文本分类代码
专知会员服务
116+阅读 · 2020年2月3日
相关资讯
恕我直言,你们的模型训练都还不够快
极市平台
2+阅读 · 2022年5月2日
为什么Pre Norm的效果不如Post Norm?
PaperWeekly
0+阅读 · 2022年5月1日
从ICLR 2022看什么是好的图神经网络?
PaperWeekly
0+阅读 · 2022年2月18日
模型优化漫谈:BERT的初始标准差为什么是0.02?
PaperWeekly
0+阅读 · 2021年11月26日
无监督分词和句法分析!原来BERT还可以这样用
PaperWeekly
12+阅读 · 2020年6月17日
如何区分并记住常见的几种 Normalization 算法
极市平台
19+阅读 · 2019年7月24日
BERT源码分析PART I
AINLP
38+阅读 · 2019年7月12日
BERT大火却不懂Transformer?读这一篇就够了
大数据文摘
11+阅读 · 2019年1月8日
GAN的数学原理
算法与数学之美
14+阅读 · 2017年9月2日
相关基金
国家自然科学基金
0+阅读 · 2013年12月31日
国家自然科学基金
0+阅读 · 2013年12月31日
国家自然科学基金
0+阅读 · 2013年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
1+阅读 · 2011年12月31日
国家自然科学基金
0+阅读 · 2009年12月31日
Top
微信扫码咨询专知VIP会员