We implement a trust region method on the GPU for nonlinear least squares curve fitting problems using a new deep learning Python library called JAX. Our open source package, JAXFit, works for both unconstrained and constrained curve fitting problems and allows the fit functions to be defined in Python alone -- without any specialized knowledge of either the GPU or CUDA programming. Since JAXFit runs on the GPU, it is much faster than CPU based libraries and even other GPU based libraries, despite being very easy to use. Additionally, due to JAX's deep learning foundations, the Jacobian in JAXFit's trust region algorithm is calculated with automatic differentiation, rather than than using derivative approximations or requiring the user to define the fit function's partial derivatives.
翻译:我们使用名为 JAX 的新的深层学习 Python 库,在 GPU 上对非线性最小平方曲线安装问题实施信任区域方法。 我们的开放源码软件包JAXFit 用于解决不受约束和受制约的曲线安装问题, 并允许仅在Python 中定义适合的功能 -- -- 没有关于 GPU 或 CUDA 程序的专门知识。 由于 JAXFit 运行在 GPU 上, 它比基于 CPU 的图书馆甚至其他基于 GPU 的图书馆要快得多, 尽管使用起来非常容易。 此外, 由于 JAX 的深学习基础, Jacobian 在 JAXFit 的托管区域算法是自动区分的, 而不是使用衍生的近似值或要求用户定义适合功能的部分衍生物。