Exploiting low-precision computations has become a standard strategy in deep learning to address the growing computational costs imposed by ever larger models and datasets. However, naively performing all computations in low precision can lead to roundoff errors and instabilities. Therefore, mixed precision training schemes usually store the weights in high precision and use low-precision computations only for whitelisted operations. Despite their success, these principles are currently not reliable for training continuous-time architectures such as neural ordinary differential equations (Neural ODEs). This paper presents a mixed precision training framework for neural ODEs, combining explicit ODE solvers with a custom backpropagation scheme, and demonstrates its effectiveness across a range of learning tasks. Our scheme uses low-precision computations for evaluating the velocity, parameterized by the neural network, and for storing intermediate states, while stability is provided by a custom dynamic adjoint scaling and by accumulating the solution and gradients in higher precision. These contributions address two key challenges in training neural ODE: the computational cost of repeated network evaluations and the growth of memory requirements with the number of time steps or layers. Along with the paper, we publish our extendable, open-source PyTorch package rampde, whose syntax resembles that of leading packages to provide a drop-in replacement in existing codes. We demonstrate the reliability and effectiveness of our scheme using challenging test cases and on neural ODE applications in image classification and generative models, achieving approximately 50% memory reduction and up to 2x speedup while maintaining accuracy comparable to single-precision training.


翻译:利用低精度计算已成为深度学习中的标准策略,以应对日益增大的模型和数据集所带来的计算成本。然而,简单地将所有计算以低精度执行可能导致舍入误差和不稳定性。因此,混合精度训练方案通常将权重存储在高精度,并仅对白名单操作使用低精度计算。尽管这些方案已取得成功,但目前对于训练连续时间架构(如神经常微分方程)尚不可靠。本文提出了一种神经常微分方程的混合精度训练框架,结合显式ODE求解器与自定义反向传播方案,并在多种学习任务中验证其有效性。我们的方案使用低精度计算来评估由神经网络参数化的速度场并存储中间状态,而稳定性通过自定义动态伴随缩放以及在高精度下累积解和梯度来保证。这些贡献解决了训练神经常微分方程的两个关键挑战:重复网络评估的计算成本以及随时间步或层数增加的内存需求增长。随本文一同发布的是我们可扩展的开源PyTorch包rampde,其语法与主流包相似,可作为现有代码的直接替代方案。我们通过具有挑战性的测试案例以及在图像分类和生成模型中的神经常微分方程应用,证明了该方案的可靠性和有效性,在保持与单精度训练相当的精度同时,实现了约50%的内存减少和最高2倍的加速。

0
下载
关闭预览

相关内容

FlowQA: Grasping Flow in History for Conversational Machine Comprehension
专知会员服务
34+阅读 · 2019年10月18日
Keras François Chollet 《Deep Learning with Python 》, 386页pdf
专知会员服务
163+阅读 · 2019年10月12日
Transferring Knowledge across Learning Processes
CreateAMind
29+阅读 · 2019年5月18日
Unsupervised Learning via Meta-Learning
CreateAMind
44+阅读 · 2019年1月3日
meta learning 17年:MAML SNAIL
CreateAMind
11+阅读 · 2019年1月2日
STRCF for Visual Object Tracking
统计学习与视觉计算组
15+阅读 · 2018年5月29日
Focal Loss for Dense Object Detection
统计学习与视觉计算组
12+阅读 · 2018年3月15日
国家自然科学基金
2+阅读 · 2015年12月31日
国家自然科学基金
2+阅读 · 2015年12月31日
国家自然科学基金
3+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2014年12月31日
国家自然科学基金
2+阅读 · 2014年12月31日
Arxiv
14+阅读 · 2024年5月28日
Arxiv
49+阅读 · 2021年5月9日
Domain Representation for Knowledge Graph Embedding
Arxiv
14+阅读 · 2019年9月11日
Deep Anomaly Detection with Outlier Exposure
Arxiv
17+阅读 · 2018年12月21日
Arxiv
15+阅读 · 2018年2月4日
VIP会员
相关资讯
Transferring Knowledge across Learning Processes
CreateAMind
29+阅读 · 2019年5月18日
Unsupervised Learning via Meta-Learning
CreateAMind
44+阅读 · 2019年1月3日
meta learning 17年:MAML SNAIL
CreateAMind
11+阅读 · 2019年1月2日
STRCF for Visual Object Tracking
统计学习与视觉计算组
15+阅读 · 2018年5月29日
Focal Loss for Dense Object Detection
统计学习与视觉计算组
12+阅读 · 2018年3月15日
相关论文
Arxiv
14+阅读 · 2024年5月28日
Arxiv
49+阅读 · 2021年5月9日
Domain Representation for Knowledge Graph Embedding
Arxiv
14+阅读 · 2019年9月11日
Deep Anomaly Detection with Outlier Exposure
Arxiv
17+阅读 · 2018年12月21日
Arxiv
15+阅读 · 2018年2月4日
相关基金
国家自然科学基金
2+阅读 · 2015年12月31日
国家自然科学基金
2+阅读 · 2015年12月31日
国家自然科学基金
3+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2014年12月31日
国家自然科学基金
2+阅读 · 2014年12月31日
Top
微信扫码咨询专知VIP会员