BERT重计算:用22.5%的训练时间节省5倍的显存开销(附代码)

2020 年 4 月 21 日 AINLP

作者: 夕小瑶、rumor酱

前言

虽然TPU的显存令人羡慕,但是由于众所周知的原因,绝大部分人还是很难日常化使用的。英伟达又一直在挤牙膏,至今单卡的最大显存也仅仅到32G(参考V100、DGX-2)。然而,训练一个24层的BERT Large模型的时候,如果sequence length开满512,那么batch size仅仅开到8(有时候能到10)就把这寥寥32G的显存打满了。如果想训练一个48层乃至100层的BERT Large,那完全是土豪们的游戏了,需要疯狂的模型并行+分布式多机训练。

但!是!万能的小夕前不久在Daxiang Dong大佬的安利下,发现了@陈天奇 大佬2016年的一篇宝藏paper!

简单的划一下重点:

这篇paper用时间换空间的思想,在前向时只保存部分中间节点,在反向时重新计算没保存的部分。论文通过这种机制,在每个batch只多计算一次前向的情况下,把n层网络的占用显存优化到了 。在极端情况下,仍可用 的计算时间换取到 的显存占用。在论文的实验中,他们成功将将1000层的残差网络从48G优化到了7G。且,这种方法同样可以直接应用于RNN结构中。

看完摘要,瞬间感觉在小破卡上训练BERT Large有救了!!!

此外,来快速过一遍paper中最重要的三点结论:

  1. 梯度计算等价,理论上没有精度损失

  2. 可以节省4倍+的显存开销

  1. 训练速度仅仅会被拖慢30%

    image-20200420140806122

不过论文发表在2016年,当时还没有BERT,不过Baidu Paddle团队补了一个BERT的实验结果,发现在BERT上面只用22.5%的训练速度损失就能换来5倍+的显存开销节省!相关实验在本文末尾,不着急,接下来我们先一起分析一下在训练阶段时显存为什么容易不足。

感谢Baidu Paddle团队提供本节图文素材和测试数据

训练阶段显存为何不足

深度学习中,网络的一次训练包含前向计算、后向计算和优化三个步骤。

在这个过程中,前向计算会输出大量的隐层变量Tensor,当模型层数加深时,Tensor数量可达成千上万个。如Bert Large模型,单个Tensor可达到1GB,这些Tensor在显存中累积,显存很快就爆掉了╮( ̄▽ ̄"")╭

下图是Bert Large模型在一次训练过程中的显存使用情况,可以明显看到在前向计算过程中,显存累积趋势是一个陡峭的上升直线。而在反向计算过程中,这些隐层Tensor又会很快地被消耗掉,又是一个陡峭的下降曲线,显存直接降到低位。

那么问题来了,为什么不直接删除这些前向计算的Tensor呢?

答案很简单,因为这些隐层的Tensor在反向的时会被用到(手动狗头

来个简单的证明。

假设前向计算中有一个矩阵乘法计算:

Y = W × X

对W求梯度:

很容易发现,对W求梯度的公式里有X,而X就是那个巨能吃显存的隐层Tensor╮( ̄▽ ̄"")╭

那我们是否可以暂时扔掉这些隐层Tensor,在反向计算时再把它们重新生成出来呢?当然可以,这正是上面这篇paper的思想。

重计算

顾名思义,"重计算"就是让每个训练迭代过程做两次前向计算,看起来有点奇怪,实际上却非常有效!对于刚刚那个吃显存的Bert Large,支持重计算机制后,显存占用直接从175GB降低到20GB,陡峭的显存上升直线变成了缓慢增长的Z形曲线,如下图所示。

核心思想是将前向计算分割成多个段,将每个段的起始Tensor作为这个段的检查点(checkpoints)。前向计算时,除了检查点以外的其他隐层Tensor占有的显存可以及时释放。反向计算用到这些隐层Tensor时,从前一个检查点开始,重新进行这个段的前向计算,就可以重新获得隐层Tensor。

重计算机制有点像玩单机游戏。每过一个关卡就会保存一个检查点,而隐层Tensor就相当于游戏中任何一个时刻的图像。普通的训练方式是打通一遍游戏,并且将游戏中所有时刻的图像保存下来;而重计算机制的思路是先把游戏通关,保存检查点,后面当收到某一时刻图像的请求时,再重打一遍这一关卡就可以了。

如下图,举一个简单的例子,添加重计算机制前,前向计算中需要存储的隐层是4个红点;添加重计算机制后,需要存储的隐层变为2个蓝点, 从而节省了这部分内存。

虽然时间也是宝贵的,但重计算方法的性价比很高。在论文的实验中,作者用30%的计算时间换取了4倍的内存空间。并且重计算只是重复了一次前向的过程,理论上精度没有任何损失

那么这么宝藏的算法有没有开源实现呢?

开源实现

调研了一波,似乎TF没有原生支持,但是生态里有第三方实现;pytorch和paddlepaddle中都有原生API支持

  • Pytorch:
    • torch.utils.checkpoint
  • PaddlePaddle:
    • optimizer.RecomputeOptimizer

不过pytorch的文档比较略,也没有提供更细致的示例和相关数据,有兴趣的小伙伴自行试一下。paddle框架中提供了详细到哭的文档,甚至还有一个现成的BERT+重计算的例子,以及非常详细的实验测试结果。这里直接贴过来(真香系列

Paddle中实现显存重计算大体分为三步:

  1. 定义一个经典的优化器,如SGD优化器;
  2. 在外面包一层重计算优化器;
  3. 设置检查点。

以MLP为例,只需要增加两行代码就可以进入重计算模式

import paddle.fluid as fluid# 定义MLPdef mlp(input_x, input_y, hid_dim=128, label_dim=2):    print(input_x)    fc_1 = fluid.layers.fc(input=input_x, size=hid_dim)    prediction = fluid.layers.fc(input=[fc_1], size=label_dim, act='softmax')    cost = fluid.layers.cross_entropy(input=prediction, label=input_y)    sum_cost = fluid.layers.reduce_mean(cost)    return sum_cost, fc_1, prediction

input_x = fluid.layers.data(name="x", shape=[32], dtype='float32')input_y = fluid.layers.data(name="y", shape=[1], dtype='int64')cost, fc_1, pred = mlp(input_x, input_y)
# 定义RecomputeOptimizer
sgd = fluid.optimizer.SGD(learning_rate=0.01)recompute_optimizer = fluid.optimizer.RecomputeOptimizer(sgd)# 设置checkpointsrecompute_optimizer._set_checkpoints([fc_1, pred])# 运行优化算法recompute_optimizer.minimize(cost)

该示例github链接:https://github.com/PaddlePaddle/examples/blob/master/community_examples/recompute/demo.py

此外,官方还给出了一个BERT中做重计算的示例

github链接:https://github.com/PaddlePaddle/Fleet/tree/develop/examples/recompute/bert

BERT实验结论(划重点

根据上面paddle官方提供的BERT示例和实验结果,得出以下几个结论

结论一

在32GB显存的Tesla V100显卡上应用重计算机制,可以训练更大、更深的深度学习模型。当num_tokens为4096(batch size=32,seqlen=128)时,可以训练100层的Bert网络。

从Github的实验结果也可以看出,显存上的收益比速度的损失要大很多:

在batch_size上提升了5倍,速度只降低了约1/5,且精度没有损失。

结论二

模型训练的batch size最大可提升为原来的5倍+,且只有少量的速度损失。

重计算机制在Bert Large这一模型上收益最大,最大batch size从93提升到562!而在VGG-16这种比较浅的模型上,重计算机制的收益则比较小。这充分符合重计算机制的设计理念:为了训练更大、更深的模型。

结论三

在古董显卡Tesla K40显卡(12G显存)上,训练BERT Large时batch size可以开到130

最后,希望本文可以帮助大家在小破卡上尽情训练BERT Large~



推荐阅读

AINLP年度阅读收藏清单

世界读书日,我来凑个单,推荐几本NLP/推荐算法/广告系统/其他相关的新书

百度PaddleHub NLP模型全面升级,推理性能提升50%以上

斯坦福大学NLP组Python深度学习自然语言处理工具Stanza试用

当当的羊毛,快薅,这一次要拼手速!

数学之美中盛赞的 Michael Collins 教授,他的NLP课程要不要收藏?

自动作诗机&藏头诗生成器:五言、七言、绝句、律诗全了

From Word Embeddings To Document Distances 阅读笔记

模型压缩实践系列之——bert-of-theseus,一个非常亲民的bert压缩方法

这门斯坦福大学自然语言处理经典入门课,我放到B站了

可解释性论文阅读笔记1-Tree Regularization

征稿启示 | 稿费+GPU算力+星球嘉宾一个都不少

关于AINLP

AINLP 是一个有趣有AI的自然语言处理社区,专注于 AI、NLP、机器学习、深度学习、推荐算法等相关技术的分享,主题包括文本摘要、智能问答、聊天机器人、机器翻译、自动生成、知识图谱、预训练模型、推荐系统、计算广告、招聘信息、求职经验分享等,欢迎关注!加技术交流群请添加AINLPer(id:ainlper),备注工作/研究方向+加群目的。



登录查看更多
0

相关内容

【ICML2020】统一预训练伪掩码语言模型
专知会员服务
25+阅读 · 2020年7月23日
[ICML-Google]先宽后窄:对深度薄网络的有效训练
专知会员服务
34+阅读 · 2020年7月5日
专知会员服务
44+阅读 · 2020年3月6日
BERT技术体系综述论文:40项分析探究BERT如何work
专知会员服务
139+阅读 · 2020年3月1日
模型压缩究竟在做什么?我们真的需要模型压缩么?
专知会员服务
27+阅读 · 2020年1月16日
绝对干货!NLP预训练模型:从transformer到albert
新智元
13+阅读 · 2019年11月10日
基于知识蒸馏的BERT模型压缩
大数据文摘
18+阅读 · 2019年10月14日
GPU 显存不足怎么办?
AINLP
13+阅读 · 2019年8月16日
快速上手笔记,PyTorch模型训练实用教程(附代码)
解读谷歌最强NLP模型BERT:模型、数据和训练
未来产业促进会
5+阅读 · 2018年10月20日
keras系列︱深度学习五款常用的已训练模型
数据挖掘入门与实战
10+阅读 · 2018年3月27日
前端高性能计算(4):GPU加速计算
前端大全
7+阅读 · 2017年10月26日
Knowledge Distillation from Internal Representations
Arxiv
4+阅读 · 2019年10月8日
Arxiv
12+阅读 · 2019年2月28日
Arxiv
15+阅读 · 2018年2月4日
VIP会员
相关资讯
绝对干货!NLP预训练模型:从transformer到albert
新智元
13+阅读 · 2019年11月10日
基于知识蒸馏的BERT模型压缩
大数据文摘
18+阅读 · 2019年10月14日
GPU 显存不足怎么办?
AINLP
13+阅读 · 2019年8月16日
快速上手笔记,PyTorch模型训练实用教程(附代码)
解读谷歌最强NLP模型BERT:模型、数据和训练
未来产业促进会
5+阅读 · 2018年10月20日
keras系列︱深度学习五款常用的已训练模型
数据挖掘入门与实战
10+阅读 · 2018年3月27日
前端高性能计算(4):GPU加速计算
前端大全
7+阅读 · 2017年10月26日
Top
微信扫码咨询专知VIP会员