DeepMind发布神经网络、强化学习库,网友:推动JAX发展

2020 年 2 月 21 日 量子位
十三 发自 凹非寺
量子位 报道 | 公众号 QbitAI

DeepMind今日发布了HaikuRLax两个库,都是基于JAX。

JAX由谷歌提出,是TensorFlow的简化库。结合了针对线性代数的编译器XLA,和自动区分本地 Python 和 Numpy 代码的库Autograd,在高性能的机器学习研究中使用。

而此次发布的两个库,分别针对神经网络强化学习,大幅简化了JAX的使用。

Haiku是基于JAX的神经网络库,允许用户使用熟悉的面向对象程序设计模型,可完全访问 JAX 的纯函数变换。

RLax是JAX顶层的库,它提供了用于实现增强学习代理的有用构件。

有意思的是,Reddit网友惊奇的发现Haiku这个库的名字,竟然不以“ax”结尾。

当然,也有网友对这两个库表示了肯定:

毫无疑问,对JAX起到了推动作用。

那么,我们就来看下Haiku和RLex的庐山真面目吧。

Haiku

Haiku是JAX的神经网络库,它允许用户使用熟悉的面向对象编程模型,同时允许完全访问JAX的纯函数转换。

它提供了两个核心工具:模块抽象hk.Module,和一个简单的函数转换hk.transform

hk.Module是Python对象,包含对其自身参数、其他模块和对用户输入应用函数方法的引用。

hk.transform允许完全访问JAX的纯函数转换。

其实,在JAX中有许多神经网络库,那么Haiku有什么特别之处呢?有5点。

1、Haiku已经由DeepMind的研究人员进行了大规模测试

DeepMind相对容易地在Haiku和JAX中复制了许多实验。其中包括图像和语言处理的大规模结果、生成模型和强化学习。

2、Haiku是一个库,而不是一个框架

它的设计是为了简化一些具体的事情,包括管理模型参数和其他模型状态。可以与其他库一起编写,并与JAX的其他部分一起工作。

3、Haiku并不是另起炉灶

它建立在Sonnet的编程模型和API之上,Sonnet是DeepMind几乎普遍采用的神经网络库。它保留了Sonnet用于状态管理的基于模块的编程模型,同时保留了对JAX函数转换的访问。

4、过渡到Haiku是比较容易的

通过精心的设计,从TensorFlow和Sonnet,过渡到JAX和Haiku是比较容易的。除了新的函数(如hk.transform),Haiku的目的是Sonnet 2的API。

5、Haiku简化了JAX

它提供了一个处理随机数的简单模型。在转换后的函数中,hk.next_rng_key()返回一个唯一的rng键。

那么,该如何安装Haiku呢?

Haiku是用纯Python编写的,但是通过JAX依赖于c++代码。

首先,按照下方链接中的说明,安装带有相关加速器支持的JAX。
https://github.com/google/jax#installation

然后,只需要一句简单的pip命令就可以完成安装。

$ pip install git+https://github.com/deepmind/haiku

接下来,是一个神经网络和损失函数的例子。

import haiku as hk
import jax.numpy as jnp

def softmax_cross_entropy(logits, labels):
  one_hot = hk.one_hot(labels, logits.shape[-1])
  return -jnp.sum(jax.nn.log_softmax(logits) * one_hot, axis=-1)

def loss_fn(images, labels):
  model = hk.Sequential([
      hk.Linear(1000),
      jax.nn.relu,
      hk.Linear(100),
      jax.nn.relu,
      hk.Linear(10),
  ])
  logits = model(images)
  return jnp.mean(softmax_cross_entropy(logits, labels))

loss_obj = hk.transform(loss_fn)

RLax

RLax是JAX顶层的库,它提供了用于实现增强学习代理的有用构件。

它所提供的操作和函数不是完整的算法,而是强化学习特定数学操作的实现。

RLax的安装也非常简单,一个pip命令就可以搞定。

pip install git+git://github.com/deepmind/rlax.git

使用JAX的jax.jit函数,所有的RLax代码可以不同的硬件上编译。

RLax需要注意的是它的命名规则。

许多函数在连续的时间步长中考虑策略、操作、奖励和值,以便计算它们的输出。在这种情况下,后缀_ttm1通常是为了说明每个输入是在哪个步骤上生成的,例如:

q_tm1:转换的源状态中的操作值。
a_tm1:在源状态下选择的操作。
r_t:在目标状态下收集的结果奖励。
q_t:目标状态下的操作值。

Haiku和RLax都已在GitHub上开源,有兴趣的读者可从“传送门”的链接访问。

传送门

Haiku:
https://github.com/deepmind/haiku

RLax:
https://github.com/deepmind/rlax

3期图像处理系列课程开始报名了~ 
2.27第一期课程,来自NVIDIA开发者社区的何琨老师,将带领大家学习如何利用NVIDIA迁移式学习工具包实现实时目标检测。
欢迎大家扫下图二维码报名,记得备注“英伟达”哦~

直播报名 | 图像与视频处理系列课程

新年福利 | 关注AI发展新动态 

内参新升级!拓展优质人脉,获取最新AI资讯&论文教程,欢迎加入AI内参社群一起学习~


量子位 QbitAI · 头条号签约作者


վ'ᴗ' ի 追踪AI技术和产品新动态


喜欢就点「在看」吧 !



登录查看更多
1

相关内容

【牛津大学&DeepMind】自监督学习教程,141页ppt
专知会员服务
179+阅读 · 2020年5月29日
【Google】利用AUTOML实现加速感知神经网络设计
专知会员服务
29+阅读 · 2020年3月5日
2019必读的十大深度强化学习论文
专知会员服务
58+阅读 · 2020年1月16日
【强化学习】深度强化学习初学者指南
专知会员服务
180+阅读 · 2019年12月14日
TensorFlow Lite 2019 年发展蓝图
谷歌开发者
6+阅读 · 2019年3月12日
官方解读:TensorFlow 2.0 新的功能特性
云头条
3+阅读 · 2019年1月23日
要替代 TensorFlow?谷歌开源机器学习库 JAX
新智元
3+阅读 · 2018年12月14日
TensorFlow 2.0和PyTorch谁更好?大牛们争了好几天
TensorFlow神经网络教程
Python程序员
4+阅读 · 2017年12月4日
DeepMind发布《星际争霸 II》深度学习环境
人工智能学家
8+阅读 · 2017年9月22日
详解TensorForce: 基于TensorFlow建立强化学习API
机械鸡
5+阅读 · 2017年7月22日
Question Generation by Transformers
Arxiv
5+阅读 · 2019年9月14日
One-Shot Federated Learning
Arxiv
9+阅读 · 2019年3月5日
Nocaps: novel object captioning at scale
Arxiv
6+阅读 · 2018年12月20日
Arxiv
7+阅读 · 2018年8月28日
Arxiv
10+阅读 · 2018年2月17日
Arxiv
12+阅读 · 2018年1月28日
VIP会员
相关资讯
TensorFlow Lite 2019 年发展蓝图
谷歌开发者
6+阅读 · 2019年3月12日
官方解读:TensorFlow 2.0 新的功能特性
云头条
3+阅读 · 2019年1月23日
要替代 TensorFlow?谷歌开源机器学习库 JAX
新智元
3+阅读 · 2018年12月14日
TensorFlow 2.0和PyTorch谁更好?大牛们争了好几天
TensorFlow神经网络教程
Python程序员
4+阅读 · 2017年12月4日
DeepMind发布《星际争霸 II》深度学习环境
人工智能学家
8+阅读 · 2017年9月22日
详解TensorForce: 基于TensorFlow建立强化学习API
机械鸡
5+阅读 · 2017年7月22日
相关论文
Question Generation by Transformers
Arxiv
5+阅读 · 2019年9月14日
One-Shot Federated Learning
Arxiv
9+阅读 · 2019年3月5日
Nocaps: novel object captioning at scale
Arxiv
6+阅读 · 2018年12月20日
Arxiv
7+阅读 · 2018年8月28日
Arxiv
10+阅读 · 2018年2月17日
Arxiv
12+阅读 · 2018年1月28日
Top
微信扫码咨询专知VIP会员