实践教程 | 解决pytorch半精度amp训练nan问题

2021 年 12 月 15 日 极市平台
↑ 点击 蓝字  关注极市平台

作者 | 可可哒@知乎(已授权) 
来源 | https://zhuanlan.zhihu.com/p/443166496 
编辑 | 极市平台

极市导读

 

本文主要是收集了一些在使用pytorch自带的amp下loss nan的情况及对应处理方案。 >>加入极市CV技术交流群,走在计算机视觉的最前沿

Why?

如果要解决问题,首先就要明确原因:为什么全精度训练时不会nan,但是半精度就开始nan?这其实分了三种情况:

  1. 计算loss 时,出现了除以0的情况
  2. loss过大,被半精度判断为inf
  3. 网络参数中有nan,那么运算结果也会输出nan

1&2我想放到后面讨论,因为其实大部分报nan都是第三种情况。这里来先看看3。什么情况下会出现情况3?这个讨论给出了不错的解释:

【Nan Loss with torch.cuda.amp and CrossEntropyLoss】https://link.zhihu.com/?target=https%3A//discuss.pytorch.org/t/nan-loss-with-torch-cuda-amp-and-crossentropyloss/108554/17

给大家翻译翻译:在使用ce loss 或者 bceloss的时候,会有log的操作,在半精度情况下,一些非常小的数值会被直接舍入到0,log(0)等于啥?——等于nan啊!

于是逻辑就理通了:回传的梯度因为log而变为nan->网络参数nan-> 每轮输出都变成nan。(;´Д`)

How?

问题定义清楚,那解决方案就非常简单了,只需要在涉及到log计算时,把输入从half精度转回float32:

x = x.float()
x_sigmoid = torch.sigmoid(x)

一些思考&废话

这里我接着讨论下我第一次看到nan之后,企图直接copy别人的解决方案,但解决不掉时踩过的坑。比如:

  1. 修改优化器的eps

有些blog会建议你从默认的1e-8 改为 1e-3,比如这篇:【pytorch1.1 半精度训练 Adam RMSprop 优化器 Nan 问题】https://link.zhihu.com/?target=https%3A//blog.csdn.net/gwb281386172/article/details/104705195

经过上面的分析,我们就能知道为什么这种方法不行——这个方案是针对优化器的数值稳定性做的修改,而loss计算这一步在优化器之前,如果loss直接nan,优化器的eps是救不回来的(托腮)。

那么这个方案在哪些场景下有效?——在loss输出不是nan时(感觉说了一句废话)。optimizer的eps是保证在进行除法backwards时,分母不出现0时需要加上的微小量。在半精度情况下,分母加上1e-8就仿佛听君一席话,因此,需要把eps调大一点。

  1. 聊聊amp的GradScaler

GradScaler是autocast的好伙伴,在官方教程上就和autocast配套使用:

from torch.cuda.amp import autocast, GradScaler
...
scaler = GradScaler()
for epoch in epochs:
for input, target in data:
optimizer.zero_grad()

with autocast():
output = model(input)
loss = loss_fn(output, target)
scaler.scale(loss).backward()

scaler.step(optimizer)
scaler.update()

具体原理不是我这篇文章讨论的范围,网上很多教程都说得很清楚了,比如这个就不错:

【Gemfield:PyTorch的自动混合精度(AMP)】https://zhuanlan.zhihu.com/p/165152789

但是我这里想讨论另一点:scaler.step(optimizer)的运行原理。

在初始化GradScaler的时候,有一个参数enabled,值默认为True。如果为True,那么在调用scaler方法时会做梯度缩放来调整loss,以防半精度状况下,梯度值过大或者过小从而被nan或者inf。而且,它还会判断本轮loss是否是nan,如果是,那么本轮计算的梯度不会回传,同时,当前的scale系数乘上backoff_factor,缩减scale的大小_。_

那么,为什么这一步已经判断了loss是不是nan,还是会出现网络损失持续nan的情况呢?

这时我们就得再往前思考一步了:为什么loss会变成nan?回到文章一开始说的:

(1)计算loss 时,出现了除以0的情况;

(2)loss过大,被半精度判断为inf;

(3)网络直接输出了nan。

(1)&(2),其实是可以通过scaler.step(optimizer)解决的,分别由optimizer和scaler帮我们捕捉到了nan的异常。但(3)不行,(3)意味着部分甚至全部的网络参数已经变成nan了。这可能是在更之前的梯度回传过程中除以0导致的——首先【回传的梯度不是nan】,所以scaler不会捕捉异常;其次,由于使用了半精度,optimizer接收到了【已经因为精度损失而变为nan的loss】,nan不管加上多大的eps,都还是nan,所以optimizer也无法处理异常,最终导致网络参数nan。

所以3,只能通过本文一开始提出的方案来解决。其实,大部分分类问题在使用半精度时出现nan的情况都是第3种情况,也只能通过把精度转回为float32,或者在计算log时加上微小量来避免(但这样会损失精度)。

参考

【Nan Loss with torch.cuda.amp and CrossEntropyLoss】https://discuss.pytorch.org/t/nan-loss-with-torch-cuda-amp-and-crossentropyloss/108554/17

如果觉得有用,就请分享到朋友圈吧!

△点击卡片关注极市平台,获取 最新CV干货

公众号后台回复“transformer”获取最新Transformer综述论文下载~


极市干货
课程/比赛: 珠港澳人工智能算法大赛 保姆级零基础人工智能教程
算法trick 目标检测比赛中的tricks集锦 从39个kaggle竞赛中总结出来的图像分割的Tips和Tricks
技术综述: 一文弄懂各种loss function 工业图像异常检测最新研究总结(2019-2020)


CV技术社群邀请函 #

△长按添加极市小助手
添加极市小助手微信(ID : cvmart4)

备注:姓名-学校/公司-研究方向-城市(如:小极-北大-目标检测-深圳)


即可申请加入极市目标检测/图像分割/工业检测/人脸/医学影像/3D/SLAM/自动驾驶/超分辨率/姿态估计/ReID/GAN/图像增强/OCR/视频理解等技术交流群


每月大咖直播分享、真实项目需求对接、求职内推、算法竞赛、干货资讯汇总、与 10000+来自港科大、北大、清华、中科院、CMU、腾讯、百度等名校名企视觉开发者互动交流~



觉得有用麻烦给个在看啦~   
登录查看更多
1

相关内容

《华为智慧农业解决方案》21页PPT
专知会员服务
120+阅读 · 2022年3月23日
《5G+智慧农业解决方案》22页PPT,三昇农业
专知会员服务
51+阅读 · 2022年3月23日
专知会员服务
51+阅读 · 2021年6月17日
【干货书】PyTorch实战-一个解决问题的方法
专知会员服务
144+阅读 · 2021年4月2日
最新《自动微分》综述教程,71页ppt
专知会员服务
21+阅读 · 2020年11月22日
最新《自动微分手册》77页pdf
专知会员服务
99+阅读 · 2020年6月6日
【ICIP2019教程-NVIDIA】图像到图像转换,附7份PPT下载
专知会员服务
53+阅读 · 2019年11月20日
实操教程|Pytorch常用损失函数拆解
极市平台
3+阅读 · 2022年1月6日
解决PyTorch半精度(AMP)训练nan问题
CVer
3+阅读 · 2022年1月4日
实践教程 | 浅谈 PyTorch 中的 tensor 及使用
极市平台
1+阅读 · 2021年12月14日
实践教程|PyTorch训练加速技巧
极市平台
0+阅读 · 2021年11月15日
让PyTorch训练速度更快,你需要掌握这17种方法
机器之心
1+阅读 · 2021年1月17日
一次 PyTorch 的踩坑经历,以及如何避免梯度成为NaN
国家自然科学基金
3+阅读 · 2015年12月31日
国家自然科学基金
1+阅读 · 2014年12月31日
国家自然科学基金
1+阅读 · 2013年12月31日
国家自然科学基金
0+阅读 · 2013年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2011年12月31日
国家自然科学基金
5+阅读 · 2011年12月31日
国家自然科学基金
1+阅读 · 2008年12月31日
Arxiv
0+阅读 · 2022年4月20日
Arxiv
0+阅读 · 2022年4月19日
Arxiv
0+阅读 · 2022年4月15日
Arxiv
0+阅读 · 2022年4月15日
VIP会员
相关VIP内容
《华为智慧农业解决方案》21页PPT
专知会员服务
120+阅读 · 2022年3月23日
《5G+智慧农业解决方案》22页PPT,三昇农业
专知会员服务
51+阅读 · 2022年3月23日
专知会员服务
51+阅读 · 2021年6月17日
【干货书】PyTorch实战-一个解决问题的方法
专知会员服务
144+阅读 · 2021年4月2日
最新《自动微分》综述教程,71页ppt
专知会员服务
21+阅读 · 2020年11月22日
最新《自动微分手册》77页pdf
专知会员服务
99+阅读 · 2020年6月6日
【ICIP2019教程-NVIDIA】图像到图像转换,附7份PPT下载
专知会员服务
53+阅读 · 2019年11月20日
相关资讯
实操教程|Pytorch常用损失函数拆解
极市平台
3+阅读 · 2022年1月6日
解决PyTorch半精度(AMP)训练nan问题
CVer
3+阅读 · 2022年1月4日
实践教程 | 浅谈 PyTorch 中的 tensor 及使用
极市平台
1+阅读 · 2021年12月14日
实践教程|PyTorch训练加速技巧
极市平台
0+阅读 · 2021年11月15日
让PyTorch训练速度更快,你需要掌握这17种方法
机器之心
1+阅读 · 2021年1月17日
一次 PyTorch 的踩坑经历,以及如何避免梯度成为NaN
相关基金
国家自然科学基金
3+阅读 · 2015年12月31日
国家自然科学基金
1+阅读 · 2014年12月31日
国家自然科学基金
1+阅读 · 2013年12月31日
国家自然科学基金
0+阅读 · 2013年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2011年12月31日
国家自然科学基金
5+阅读 · 2011年12月31日
国家自然科学基金
1+阅读 · 2008年12月31日
Top
微信扫码咨询专知VIP会员