output = net(input)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
optimizer.zero_grad()
with torch.cuda.amp.autocast():
output = net(input)
loss = loss_fn(output, target)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
>>> # FP32相加不会有问题。
>>> torch.tensor(2**-3) + torch.tensor(2**-14)
tensor(0.1251)
>>> # FP16相加,较小的数字会被忽略掉。因为在[2**-3, 2**-2]内,FP16表示的固定间隔为2**-13。
>>> # 也就是说比2**-3大的下一个数字为2**-3 + 2**-13,因此2**-14加了跟没加一样。
>>> # half()的作用是将FP32转化为FP16。
>>> torch.tensor(2**-3).half() + torch.tensor(2**-14).half()
tensor(0.1250, dtype=torch.float16)
>>> # 将2**-14换成2**-13就可以了。
>>> torch.tensor(2**-3).half() + torch.tensor(2**-13).half()
tensor(0.1251, dtype=torch.float16)
注意这里一定要先转成 FP32,不然 unscale 的时候还是会下溢出。
如何保证黑名单模块在 FP32 环境中运行:以 BN 层为例,将其权重转为 FP32,并且将输入从 FP16 转成 FP32,这样就可以保证整个模块是在 FP32 下运行的。
搞不懂 Tensor Core 是如何应用到 AMP 中的。有人说 Tensor Core 可以帮助我们利用 FP16 的梯度来更新 FP32 的模型权重。但是阅读了 apex 的源码之后,我发现 FP16 的梯度会先转化为 FP32,再做更新,所以权重更新和 Tensor Core 并无关系。以后弄明白了再回来补充吧。
图片来自:全网最全-混合精度训练原理
https://zhuanlan.zhihu.com/p/441591808
个人猜测 PyTorch 会让每个 Tensor 本身的数据类型和梯度的数据类型保持一致,虽然产生了 FP16 的梯度,但是因为权重本身是 FP32,所以框架会将梯度也转化为 FP32。
参考链接
[1] https://arxiv.org/abs/1710.03740
[2] https://github.com/NVIDIA/apex
[3] https://en.wikipedia.org/wiki/Round-off_error#Addition
[4] https://github.com/NVIDIA/apex/blob/1403c21acf87b0f2245278309071aef17d80c13b/apex/amp/amp.py#L68
[5] https://github.com/NVIDIA/apex/blob/1403c21acf87b0f2245278309071aef17d80c13b/apex/amp/scaler.py#L40
[6] https://github.com/NVIDIA/apex/blob/1403c21acf87b0f2245278309071aef17d80c13b/apex/amp/handle.py#L113
[7] https://github.com/NVIDIA/apex/blob/1403c21acf87b0f2245278309071aef17d80c13b/apex/amp/_process_optimizer.py#L123
[8] https://github.com/NVIDIA/apex/blob/1403c21acf87b0f2245278309071aef17d80c13b/apex/amp/scaler.py#L202
[9] https://github.com/NVIDIA/apex/blob/1403c21acf87b0f2245278309071aef17d80c13b/apex/amp/scaler.py#L207
[10] https://github.com/NVIDIA/apex/blob/1403c21acf87b0f2245278309071aef17d80c13b/apex/amp/handle.py#L128
[11] https://github.com/NVIDIA/apex/blob/1403c21acf87b0f2245278309071aef17d80c13b/apex/amp/scaler.py#L213
[12] https://github.com/NVIDIA/apex/blob/1403c21acf87b0f2245278309071aef17d80c13b/apex/amp/_initialize.py#L179
[13] https://github.com/NVIDIA/apex/blob/1403c21acf87b0f2245278309071aef17d80c13b/apex/amp/_initialize.py#L194
[14] https://github.com/NVIDIA/apex/blob/1403c21acf87b0f2245278309071aef17d80c13b/apex/amp/_process_optimizer.py#L44
[15] https://github.com/NVIDIA/apex/blob/1403c21acf87b0f2245278309071aef17d80c13b/apex/amp/scaler.py#L40
[16] https://github.com/NVIDIA/apex/blob/1403c21acf87b0f2245278309071aef17d80c13b/apex/amp/handle.py#L113
[17] https://github.com/NVIDIA/apex/blob/1403c21acf87b0f2245278309071aef17d80c13b/apex/amp/scaler.py#L94
[18] https://github.com/NVIDIA/apex/blob/1403c21acf87b0f2245278309071aef17d80c13b/apex/amp/scaler.py#L202
[19] https://github.com/NVIDIA/apex/blob/1403c21acf87b0f2245278309071aef17d80c13b/apex/amp/scaler.py#L207
[20] https://github.com/NVIDIA/apex/blob/1403c21acf87b0f2245278309071aef17d80c13b/apex/amp/handle.py#L128
[21] https://github.com/NVIDIA/apex/blob/1403c21acf87b0f2245278309071aef17d80c13b/apex/amp/scaler.py#L213
[22] https://github.com/open-mmlab/mmcv/blob/f5425ab7611ab2376ddb478b57cb2f46f6054e13/mmcv/runner/hooks/optimizer.py#L344
[23] https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html
[24] https://pytorch.org/docs/stable/amp.html#autocast-op-reference
[25] PyTorch必备神器 | 唯快不破:基于Apex的混合精度加速
[26] https://zhuanlan.zhihu.com/p/103685761
[27] https://zhuanlan.zhihu.com/p/441591808
更多阅读
#投 稿 通 道#
让你的文字被更多人看到
如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。
总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。
PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析、科研心得或竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。
📝 稿件基本要求:
• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注
• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题
• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算
📬 投稿通道:
• 投稿邮箱:hr@paperweekly.site
• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者
• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿
△长按添加PaperWeekly小编
🔍
现在,在「知乎」也能找到我们了
进入知乎首页搜索「PaperWeekly」
点击「关注」订阅我们的专栏吧