JAX is widely used in machine learning and scientific computing. Scientific computing in particular often relies on existing high-performance code that we would ideally like to incorporate into JAX. Reimplementing this code in JAX is often impractical, and JAX's existing interface for binding custom code requires deep knowledge of JAX and its C++ backend. The goal of JAXbind is to drastically reduce the effort required to bind custom functions implemented in other programming languages to JAX. Specifically, JAXbind provides an easy-to-use Python interface for defining custom JAX primitives that support arbitrary JAX transformations.
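To make the binding problem concrete, the following minimal sketch wraps a function that lives outside JAX and supplies its derivative by hand. It deliberately uses only built-in JAX mechanisms (`jax.pure_callback` and `jax.custom_vjp`) and is not JAXbind's interface. The names `external_square` and `external_square_vjp` are hypothetical stand-ins for an existing routine, for example compiled code exposed to Python. Under these assumptions the wrapped function supports `jax.jit` and reverse-mode differentiation, but not forward-mode derivatives or other transformations without further work.

```python
import jax
import jax.numpy as jnp
import numpy as np


# Hypothetical stand-in for an existing routine implemented outside JAX.
def external_square(x):
    return np.asarray(x) ** 2


# Hand-written vector-Jacobian product of the routine above (also hypothetical).
def external_square_vjp(x, cotangent):
    return 2.0 * np.asarray(x) * np.asarray(cotangent)


@jax.custom_vjp
def square(x):
    # Tell JAX the output shape and dtype, since it cannot trace the callback.
    out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(external_square, out_type, x)


def square_fwd(x):
    # Save the input as the residual needed by the backward pass.
    return square(x), x


def square_bwd(x, cotangent):
    out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return (jax.pure_callback(external_square_vjp, out_type, x, cotangent),)


square.defvjp(square_fwd, square_bwd)

x = jnp.arange(3.0)
value = jax.jit(square)(x)                      # external code called under jit
grad = jax.grad(lambda v: square(v).sum())(x)   # uses the hand-written VJP
```

Even for this toy case, the shape and dtype bookkeeping and the derivative rule must be wired up manually, and only one kind of differentiation is covered. This is the kind of effort the abstract describes JAXbind as aiming to reduce.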