模型压缩实践收尾篇——模型蒸馏以及其他一些技巧实践小结

2020 年 5 月 28 日 AINLP

作者:邱震宇(华泰证券股份有限公司 算法工程师)

知乎专栏:我的ai之路





最近一段时间,看了之前几篇文章的同学都知道我的研究重心在模型压缩这块。因为目前以bert为典型的大规模预训练模型在很多NLP任务上的效果都很香,我们团队这边也有很多的应用,然而其在线上的inference性能(尤其是在非GPU服务器上)在一定程度上降低了用户体验。因此,我研究了当前模型压缩领域中比较work的一些方法,并通过实验对比,总结了一些方法的有效性。之前两篇文章已经给出了一些成果,本文以模型蒸馏为主题,总结一些在实践中有效的压缩方法,作为这段时间对模型压缩研究的结尾。由于最近github网速很不稳定,后续会把相关代码上传到我NER的github repo上。

本文重点不会放在深入描述某个模型蒸馏论文或者方法上,各位同学可以参考其他高质量的综述论文或者总结博客,有些文章写的真的很好,比如mitchgordon.me/machine/ ,另外还有zhihu.com/people/rumor- 这位博主写的一些文章以及等等。相对的,本文将重点放在对这些方法的实践,使用以及结果比对上,希望能作为一些业界的同学在尝试模型蒸馏时的参考。

声明:本文的所有实验结果仅针对之前介绍过的Chinese NER任务,同时蒸馏作用的阶段只是在下游任务finetune阶段,不会涉及预训练(--__-- 没钱没时间)。后续有时间的话还会对其他类型的任务(如文本分类)等做一些实验,本文不会包含这些内容。

模型蒸馏

模型蒸馏的主要流程是先用完整复杂模型使用训练集训练出来一个teacher模型,然后设计一个小规模的student模型,再固定teacher模型的权重参数,然后设计一系列loss,让student模型在蒸馏学习的过程中逐渐向teacher模型的表现特性靠拢,使得student模型的预测精度逐渐逼近teacher模型。

其中,专门针对Bert模型的蒸馏方法有很多,如tinybert,distillBert,pkd-bert等等。虽然有这么多蒸馏方法,但是仔细研究也能发现它们或多或少都有一些共同点,例如:

1、在预训练阶段使用蒸馏方法通常能够取得较好的效果。

2、设计的loss都有一些共通性。

3、会将模型架构模块化,然后对模型不同的模块设计不同的loss。

下面我就从上述几个角度分别总结一下。

预训练阶段&finetune阶段

关于这块内容,我没有做过多的实践,因为目前不具备做预训练的条件。不过基本上所有bert蒸馏方法在预训练阶段使用都能获取不错的效果,有的甚至能在裁剪一定规模的情况下,保持或者超越原始的模型。我主要关注的是将蒸馏方法仅作用在finetune阶段。经过实验,发现仅在对finetune后的原始模型进行蒸馏,很难保持原始的精度,或多或少都会有一定程度的精度损失。我们能做的就是在inference性能和inference精度两边做一定的平衡。比如要考虑裁剪的bert层数,裁剪的中间层神经元数、注意力头数等,通常裁剪得越多,inference的精度损失就越大。按照我之前两篇文章中的方法,使用layerdrop裁剪一半的层数会有8-10个百分点的下降,而使用bert-theseus方法裁剪一半层数,会有2个百分点的下降。对于bert-theseus来说,完全可以应用到实际的项目服务中。

本次使用蒸馏方式在finetune阶段裁剪模型,在裁剪一半层数的情况下,精度下降的幅度从1个百分点到5个百分点之间浮动,下面会分别具体介绍不同方法带来的结果。

蒸馏的loss设计

bert蒸馏中的loss设计可以说是其精髓,这里就结合上面2,3两点一起来介绍一下。总结一下,对于当前所有的bert模型,主要设计的loss的模块集中在output层的logits输出(或者softmax概率化后的输出),中间层的hidden_output,attention_output,embedding神经元等。

logits的loss设计

对于logits来说,通常使用mean squared error来计算两个logits之间的差异性。但是对于模型的不同组件,计算两个logits的loss方式也有细微的不同。以NER这种序列标注任务为例,在对中间隐层和output层的logits计算mse时,只需要考虑正常的batch中的序列mask,不要将所有序列step中的padding部分都计算mse就可以了。但是对于attention的output(在bert代码中为attention_scores,即softmax概率化之前的attention计算结果),其shape为[batch_size,head,seq_len,seq_len],需要考虑最后两个维度上的mask,这里参考了TextBrewer(TextBrewer)中的实现,其官方的pytorch代码如下:

def att_mse_loss(attention_S, attention_T, mask=None):
'''
* Calculates the mse loss between `attention_S` and `attention_T`.
* If the `inputs_mask` is given, masks the positions where ``input_mask==0``.
:param torch.Tensor logits_S: tensor of shape (*batch_size*, *num_heads*, *length*, *length*)
:param torch.Tensor logits_T: tensor of shape (*batch_size*, *num_heads*, *length*, *length*)
:param torch.Tensor mask: tensor of shape (*batch_size*, *length*)
'''

if mask is None:
attention_S_select = torch.where(attention_S <= -1e-3, torch.zeros_like(attention_S), attention_S)
attention_T_select = torch.where(attention_T <= -1e-3, torch.zeros_like(attention_T), attention_T)
loss = F.mse_loss(attention_S_select, attention_T_select)
else:
mask = mask.to(attention_S).unsqueeze(1).expand(-1, attention_S.size(1), -1) # (bs, num_of_heads, len)
valid_count = torch.pow(mask.sum(dim=2),2).sum()
loss = (F.mse_loss(attention_S, attention_T, reduction='none') * mask.unsqueeze(-1) * mask.unsqueeze(2)).sum() / valid_count
return loss

大家注意一个细节,当给定mask时,在计算valid_count时,作者使用的是  。而在tinybert中的实现则是这样的:

for student_att, teacher_att in zip(student_atts, new_teacher_atts):
student_att = torch.where(student_att <= -1e2, torch.zeros_like(student_att).to(device),
student_att)
teacher_att = torch.where(teacher_att <= -1e2, torch.zeros_like(teacher_att).to(device),
teacher_att)

tmp_loss = loss_mse(student_att, teacher_att)

这个与TextBrewer中未给出mask时的计算方式相同。两种方式我都尝试过,对于最终模型的精度基本上没有太大区别,两种实现方式效果都差不多。

概率分布的loss设计

也有一些方法专门针对概率化后的信息输出计算其loss,比如每一层attention_score概率化后的alignment、最终模型输出的概率化结果。通常来说,会使用交叉熵或者KL-divergence等方法计算两个概率分布之间的差异。

对于output的概率输出来说,在计算两个模型输出之间的交叉熵之前,需要先对模型的概率分布进行一个flat操作。原因在于我们常规的模型学习完成后,它学习到的概率分布都是比较陡的,即某一个或者极少一部分类别的概率会非常大,其余类别的会非常小,因为模型已经学到了一些成熟的信息。在蒸馏时,我们要让student模型学习到teacher模型的概率输出,如果还保持之前的概率分布,那么会让大部分的概率信息无法被学习。因此,通常在对logits进行概率化之前,要先对logits除以一个temperature,让不同类别的概率差异稍微变小一点。同样参考TextBrewer中的代码:

def kd_ce_loss(logits_S, logits_T, temperature=1):
'''
Calculate the cross entropy between logits_S and logits_T
:param logits_S: Tensor of shape (batch_size, length, num_labels) or (batch_size, num_labels)
:param logits_T: Tensor of shape (batch_size, length, num_labels) or (batch_size, num_labels)
:param temperature: A float or a tensor of shape (batch_size, length) or (batch_size,)
'''

if isinstance(temperature, torch.Tensor) and temperature.dim() > 0:
temperature = temperature.unsqueeze(-1)
beta_logits_T = logits_T / temperature
beta_logits_S = logits_S / temperature
p_T = F.softmax(beta_logits_T, dim=-1)
loss = -(p_T * F.log_softmax(beta_logits_S, dim=-1)).sum(dim=-1).mean()
return loss

对于attention概率输出来说,通常不需要对概率分布进行平滑操作,只需要进行正常的交叉熵或者KL-divergence操作,同时要考虑到mask,对于logits来说很小的负值代表是一个被mask的维度,而对于概率分布来说就不是这个情况了,这个时候最好是能够提供序列的mask。

finetune任务自身loss

除了几个蒸馏的loss之外,将下游任务的loss也加入到模型蒸馏的整体任务中,也能让student模型学习到下游任务的信息。

实验

我设计了一个消融对比实验,实验因子包括两大类,一类是对不同模型组件的蒸馏,一类是具体的loss方法,总共包含如下几种情况:

1、是否使用finetune任务自身loss。

2、是否使用attention output输出logits的mse

3、是否使用hidden output输出logits的mse

4、对比使用output的输出概率的cross entropy和logits的mse

5、对比使用attention output输出概率的ce和logits的mse

其中,我将output输出logits的mse作为基础的蒸馏方法,该loss一直存在。

在我实验的任务上,结果如下:

1、使用finetune任务自身的loss是有效的,但是效果不大,大概能够提升0.2个百分点。

2、使用attention output输出logits的mse效果甚微,基本没有太大提升。我推测可能是当前对于序列标注任务来说,attention的学习提升不大。建议使用更多不同的任务来实验。

3、使用hidden output输出logits的mse是非常有效的,能够提升1个百分点。

4、使用概率输出做蒸馏和使用logits输出做蒸馏差距不大,并不能看到显著的区别,建议用更多不同的任务来实验。

其他技巧

在蒸馏实验中,我还尝试了很多其他的trick实验,总结了两条有用的最佳实践候选,供同学们参考使用。

bert层的映射设计

在设计attention output的loss时,由于会对bert的层数进行裁剪,所以需要对student的encode层和原始模型中的encode层进行映射。之前有同学介绍过微软的一篇论文:miniLM ,它只使用bert最后一层的value概率输出和attention概率输出做蒸馏,省去了设计映射的工作。我实验过,并没有得到很好的效果,其有效性还待验证,后续会用其他类型的任务来验证。至于如何设计层的映射,目前还没有一个方法论,通常和任务还是相关的,但是有一些指导意见还是可以参考。如bert的每一层所学习存储的信息重点都是不一样的,越接近embedding的底层会倾向于学习通用基础的语言学知识。而接近下游任务分类的上层,则会倾向于学习下游任务中的具体信息。另外间隔的层之间的连通性比较好,因此通常会以间隔的方式建立层映射,如student中的0-5层可以分别对应原模型中的1,3,5,7,9,11层。

尽量沿用teacher模型的权重

在进行模型蒸馏时,通常会初始化student模型的权重从头开始训练。但是,如果能让student模型在一开始就用teacher模型的部分权重进行初始化,不仅能够提升学习效率,最后得到的精度也是不错的。通过实验发现,使用teacher模型权重初始化student模型,至少能够带来5个百分点的性能提升。

然而,这种方法也为蒸馏带来的局限性,即我们只能对模型进行模块化的裁剪,如只裁剪整个层或者整个注意力头。如果要裁剪隐层神经元个数,就不能使用这个方法了。如果实际项目服务对于精度要求还是比较高的,那么建议使用这种方式。

一步到位不一定有效

这个技巧是看了论文miniLM发现的,论文中它的最终目标是将模型裁剪到4层,hidden_size裁剪一半。实际操作时,它并非直接使用蒸馏训练一个最小模型,而是先用原始模型蒸馏一个中介模型,其层数为4层,但是hidden_size不变,然后使用这个中介模型作为teacher模型来蒸馏得到最终的模型。我尝试了这种方式,发现有一定的效果,为了蒸馏得到4层的模型,我先将原始模型蒸馏到6层,然后再蒸馏到4层。这种方式比直接蒸馏小模型能够有3-4个百分点的提升。当然,我这里要说明一点,我比较的是训练相同epoch数下的两个模型的精度,也有可能是一步到位蒸馏小模型需要更多的训练步数才能达到收敛,并不能直接断定一步到位为训练法一定就比较差,但至少在相同的训练成本下,采用中介过渡是更有效的。

小结

这段时间通过对模型压缩相关技术的研究,获取了很多有效的模型压缩方法,并且在项目实际运用中产生了一定的效果,使得bert模型在性能提升40%左右的情况下,其精度能够保持在较高的水准。接下来会使用更多的下游任务进行验证尝试不同的蒸馏方法,也会持续关注更多的模型压缩方法,比如最近新出的fastBert。当然,我也会开辟新的实验工作,尝试使用BERT做NLG相关的任务,敬请期待。




本文由作者授权AINLP原创发布于公众号平台,欢迎投稿,AI、NLP均可。原文链接,点击"阅读原文"直达:


https://zhuanlan.zhihu.com/p/124215760


推荐阅读

模型压缩实践系列之——layer dropout

模型压缩实践系列之——bert-of-theseus,一个非常亲民的bert压缩方法

AINLP年度阅读收藏清单

中文命名实体识别工具(NER)哪家强?

学自然语言处理,其实更应该学好英语

斯坦福大学NLP组Python深度学习自然语言处理工具Stanza试用

太赞了!Springer面向公众开放电子书籍,附65本数学、编程、机器学习、深度学习、数据挖掘、数据科学等书籍链接及打包下载

数学之美中盛赞的 Michael Collins 教授,他的NLP课程要不要收藏?

自动作诗机&藏头诗生成器:五言、七言、绝句、律诗全了

这门斯坦福大学自然语言处理经典入门课,我放到B站了

征稿启示 | 稿费+GPU算力+星球嘉宾一个都不少

关于AINLP

AINLP 是一个有趣有AI的自然语言处理社区,专注于 AI、NLP、机器学习、深度学习、推荐算法等相关技术的分享,主题包括文本摘要、智能问答、聊天机器人、机器翻译、自动生成、知识图谱、预训练模型、推荐系统、计算广告、招聘信息、求职经验分享等,欢迎关注!加技术交流群请添加AINLPer(id:ainlper),备注工作/研究方向+加群目的。


登录查看更多
2

相关内容

专知会员服务
73+阅读 · 2020年5月21日
图卷积神经网络蒸馏知识,Distillating Knowledge from GCN
专知会员服务
94+阅读 · 2020年3月25日
【Amazon】使用预先训练的Transformer模型进行数据增强
专知会员服务
56+阅读 · 2020年3月6日
模型压缩究竟在做什么?我们真的需要模型压缩么?
专知会员服务
27+阅读 · 2020年1月16日
深度神经网络模型压缩与加速综述
专知会员服务
128+阅读 · 2019年10月12日
基于知识蒸馏的BERT模型压缩
大数据文摘
18+阅读 · 2019年10月14日
7个实用的深度学习技巧
机器学习算法与Python学习
16+阅读 · 2019年3月6日
FAIR&MIT提出知识蒸馏新方法:数据集蒸馏
机器之心
7+阅读 · 2019年2月7日
干货——图像分类(下)
计算机视觉战队
14+阅读 · 2018年8月28日
入门 | 深度学习模型的简单优化技巧
机器之心
9+阅读 · 2018年6月10日
浅入浅出深度学习理论与实践
机器学习研究会
5+阅读 · 2018年2月28日
详解深度学习中的Normalization,不只是BN(1)
PaperWeekly
5+阅读 · 2018年2月6日
[学习] 这些深度学习网络调参技巧,你了解吗?
菜鸟的机器学习
7+阅读 · 2017年7月30日
Teacher-Student Training for Robust Tacotron-based TTS
Meta-Learning to Cluster
Arxiv
17+阅读 · 2019年10月30日
Knowledge Distillation from Internal Representations
Arxiv
4+阅读 · 2019年10月8日
Arxiv
15+阅读 · 2019年9月11日
Sparse Sequence-to-Sequence Models
Arxiv
5+阅读 · 2019年5月14日
Arxiv
9+阅读 · 2018年10月24日
Physical Primitive Decomposition
Arxiv
4+阅读 · 2018年9月13日
VIP会员
相关资讯
基于知识蒸馏的BERT模型压缩
大数据文摘
18+阅读 · 2019年10月14日
7个实用的深度学习技巧
机器学习算法与Python学习
16+阅读 · 2019年3月6日
FAIR&MIT提出知识蒸馏新方法:数据集蒸馏
机器之心
7+阅读 · 2019年2月7日
干货——图像分类(下)
计算机视觉战队
14+阅读 · 2018年8月28日
入门 | 深度学习模型的简单优化技巧
机器之心
9+阅读 · 2018年6月10日
浅入浅出深度学习理论与实践
机器学习研究会
5+阅读 · 2018年2月28日
详解深度学习中的Normalization,不只是BN(1)
PaperWeekly
5+阅读 · 2018年2月6日
[学习] 这些深度学习网络调参技巧,你了解吗?
菜鸟的机器学习
7+阅读 · 2017年7月30日
相关论文
Teacher-Student Training for Robust Tacotron-based TTS
Meta-Learning to Cluster
Arxiv
17+阅读 · 2019年10月30日
Knowledge Distillation from Internal Representations
Arxiv
4+阅读 · 2019年10月8日
Arxiv
15+阅读 · 2019年9月11日
Sparse Sequence-to-Sequence Models
Arxiv
5+阅读 · 2019年5月14日
Arxiv
9+阅读 · 2018年10月24日
Physical Primitive Decomposition
Arxiv
4+阅读 · 2018年9月13日
Top
微信扫码咨询专知VIP会员