Exploiting low-precision computations has become a standard strategy in deep learning to address the growing computational costs imposed by ever-larger models and datasets. However, naively performing all computations in low precision can lead to roundoff errors and instabilities. Therefore, mixed precision training schemes usually store the weights in high precision and use low-precision computations only for whitelisted operations. Despite their success, these techniques are currently not reliable for training continuous-time architectures such as neural ordinary differential equations (neural ODEs). This paper presents a mixed precision training framework for neural ODEs, combining explicit ODE solvers with a custom backpropagation scheme, and demonstrates its effectiveness across a range of learning tasks. Our scheme uses low-precision computations for evaluating the velocity, parameterized by the neural network, and for storing intermediate states, while stability is provided by a custom dynamic adjoint scaling and by accumulating the solution and gradients in higher precision. These contributions address two key challenges in training neural ODEs: the computational cost of repeated network evaluations and the growth of memory requirements with the number of time steps or layers. Along with the paper, we publish our extendable, open-source PyTorch package rampde, whose syntax resembles that of leading packages so that it serves as a drop-in replacement in existing codes. We demonstrate the reliability and effectiveness of our scheme on challenging test cases and on neural ODE applications in image classification and generative modeling, achieving approximately 50% memory reduction and up to 2x speedup while maintaining accuracy comparable to single-precision training.
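The core idea can be pictured with a short PyTorch sketch. The snippet below is only an illustration of the precision split described above (low-precision evaluation of the velocity network, higher-precision accumulation of the solution); it does not reproduce the rampde API, and the network, step size, and autocast usage are assumptions made for the example.

```python
# Illustrative sketch only (not the rampde API): explicit Euler steps of a
# neural ODE in which the velocity network is evaluated in low precision
# while the state is accumulated in float32. All names are assumptions.
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
low = torch.float16 if device == "cuda" else torch.bfloat16

# Velocity field f(y; theta), parameterized by a small network.
velocity = nn.Sequential(nn.Linear(2, 64), nn.Tanh(), nn.Linear(64, 2)).to(device)

def euler_step(y, h):
    # Evaluate the parameterized velocity in low precision ...
    with torch.autocast(device_type=device, dtype=low):
        f = velocity(y)
    # ... but accumulate the solution update in float32.
    return y + h * f.float()

y = torch.randn(128, 2, device=device)  # batch of states, kept in float32
h = 0.1
for _ in range(10):                      # ten fixed time steps
    y = euler_step(y, h)
```

In the full framework, the backward (adjoint) pass follows the same pattern: gradients are accumulated in higher precision and stabilized by the dynamic adjoint scaling mentioned above.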