This paper introduces JaxPruner, an open-source JAX-based library for pruning and sparse training in machine learning research. JaxPruner aims to accelerate research on sparse neural networks by providing concise implementations of popular pruning and sparse-training algorithms with minimal memory and latency overhead. Algorithms implemented in JaxPruner share a common API and work seamlessly with the popular optimization library Optax, which in turn enables easy integration with existing JAX-based libraries. We demonstrate this ease of integration with examples in four different codebases (Scenic, t5x, Dopamine, and FedJAX) and provide baseline experiments on popular benchmarks.
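To make the class of algorithms concrete, the following is a toy sketch of magnitude pruning, the simplest of the pruning algorithms a library like JaxPruner implements. It is written in plain Python for illustration only; JaxPruner itself operates on JAX parameter trees and integrates with Optax gradient transformations, and the function below is a hypothetical stand-in, not part of the library's API.

```python
def magnitude_prune(weights, sparsity):
    """Zero out the `sparsity` fraction of weights with the smallest magnitudes.

    Toy illustration of magnitude pruning; not JaxPruner's actual API.
    """
    n_prune = int(len(weights) * sparsity)
    if n_prune == 0:
        return list(weights)
    # Threshold: the magnitude of the n_prune-th smallest entry.
    threshold = sorted(abs(w) for w in weights)[n_prune - 1]
    # Keep weights above the threshold; zero the rest (ties may prune extra).
    return [0.0 if abs(w) <= threshold else w for w in weights]

# Pruning half of six weights zeroes the three smallest-magnitude entries.
pruned = magnitude_prune([0.1, -0.5, 0.05, 2.0, -0.02, 0.8], sparsity=0.5)
print(pruned)  # → [0.0, -0.5, 0.0, 2.0, 0.0, 0.8]
```

In a sparse-training setting, a step like this runs periodically during optimization rather than once after training, which is why coupling the pruning logic to the optimizer (as JaxPruner does via Optax) is a natural design.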