We introduce JAX MD, a software package for performing differentiable physics simulations with a focus on molecular dynamics. JAX MD includes a number of physics simulation environments, as well as interaction potentials and neural networks that can be integrated into these environments without writing any additional code. Since the simulations themselves are differentiable functions, entire trajectories can be differentiated to perform meta-optimization. These features are built on primitive operations, such as spatial partitioning, that allow simulations to scale to hundreds-of-thousands of particles on a single GPU. These primitives are flexible enough that they can be used to scale up workloads outside of molecular dynamics. We present several examples that highlight the features of JAX MD including: integration of graph neural networks into traditional simulations, meta-optimization through minimization of particle packings, and a multi-agent flocking simulation. JAX MD is available at www.github.com/google/jax-md.
翻译:我们引入了JAX MD, 这是一种以分子动态为焦点进行不同物理模拟的软件包。 JAX MD 包含许多物理模拟环境,以及互动潜力和神经网络,这些潜力和神经网络可以在不写任何额外的代码的情况下融入这些环境中。由于这些模拟本身是不同的功能,因此整个轨迹可以有区别地进行元优化。这些特征建在原始操作上,例如空间分隔,使模拟能够在单一的GPU上达到千千粒子。这些原始物质具有足够的灵活性,可以用来在分子动态之外扩大工作量。我们举几个例子突出JAX MD的特征,包括:将图形神经网络纳入传统的模拟,通过尽量减少粒子包装实现元化,以及多试剂的传动模拟。JAX MD可以在www.github.com/google/jax-md上查阅。