TensorFlow 中 RNN 实现的正确打开方式

2017 年 9 月 13 日 AI研习社 何之源


本文作者何之源,原文载于知乎专栏 AI Insight,雷锋网 AI 科技评论获其授权发布。


上周写的文章《完全图解 RNN、RNN 变体、Seq2Seq、Attention 机制》介绍了一下 RNN 的几种结构,今天就来聊一聊如何在 TensorFlow 中实现这些结构。这篇文章的主要内容为:

  • 一个完整的、循序渐进的学习 TensorFlow 中 RNN 实现的方法。这个学习路径的曲线较为平缓,应该可以减少不少学习精力,帮助大家少走弯路。

  • 一些可能会踩的坑

  • TensorFlow 源码分析

  • 一个 Char RNN 实现示例,可以用来写诗,生成歌词,甚至可以用来写网络小说!(项目地址:https://github.com/hzy46/Char-RNN-TensorFlow)

  一、学习单步的 RNN:RNNCell

如果要学习 TensorFlow 中的 RNN,第一站应该就是去了解 “RNNCell”,它是 TensorFlow 中实现 RNN 的基本单元,每个 RNNCell 都有一个 call 方法,使用方式是:(output, next_state) = call(input, state)。

借助图片来说可能更容易理解。假设我们有一个初始状态 h0,还有输入 x1,调用 call(x1, h0) 后就可以得到 (output1, h1):

再调用一次 call(x2, h1) 就可以得到 (output2, h2):

也就是说,每调用一次 RNNCell 的 call 方法,就相当于在时间上 “推进了一步”,这就是 RNNCell 的基本功能。

在代码实现上,RNNCell 只是一个抽象类,我们用的时候都是用的它的两个子类 BasicRNNCell 和 BasicLSTMCell。顾名思义,前者是 RNN 的基础类,后者是 LSTM 的基础类。这里推荐大家阅读其源码实现(地址:http://t.cn/RNJrfMl),一开始并不需要全部看一遍,只需要看下 RNNCell、BasicRNNCell、BasicLSTMCell 这三个类的注释部分,应该就可以理解它们的功能了。

除了 call 方法外,对于 RNNCell,还有两个类属性比较重要:

  • state_size

  • output_size

前者是隐层的大小,后者是输出的大小。比如我们通常是将一个 batch 送入模型计算,设输入数据的形状为 (batch_size, input_size),那么计算时得到的隐层状态就是 (batch_size, state_size),输出就是 (batch_size, output_size)。

可以用下面的代码验证一下(注意,以下代码都基于 TensorFlow 最新的 1.2 版本):

import tensorflow as tf

import numpy as np

cell = tf.nn.rnn_cell.BasicRNNCell(num_units=128) # state_size = 128

print(cell.state_size) # 128

inputs = tf.placeholder(np.float32, shape=(32, 100)) # 32 是 batch_size

h0 = cell.zero_state(32, np.float32) # 通过 zero_state 得到一个全 0 的初始状态,形状为 (batch_size, state_size)

output, h1 = cell.call(inputs, h0) #调用 call 函数

print(h1.shape) # (32, 128)

对于 BasicLSTMCell,情况有些许不同,因为 LSTM 可以看做有两个隐状态 h 和 c,对应的隐层就是一个 Tuple,每个都是 (batch_size, state_size) 的形状:

import tensorflow as tf

import numpy as np

lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=128)

inputs = tf.placeholder(np.float32, shape=(32, 100)) # 32 是 batch_size

h0 = lstm_cell.zero_state(32, np.float32) # 通过 zero_state 得到一个全 0 的初始状态

output, h1 = lstm_cell.call(inputs, h0)

print(h1.h)  # shape=(32, 128)

print(h1.c)  # shape=(32, 128)

  二、学习如何一次执行多步:tf.nn.dynamic_rnn

基础的 RNNCell 有一个很明显的问题:对于单个的 RNNCell,我们使用它的 call 函数进行运算时,只是在序列时间上前进了一步。比如使用 x1、h0 得到 h1,通过 x2、h1 得到 h2 等。这样的 h 话,如果我们的序列长度为 10,就要调用 10 次 call 函数,比较麻烦。对此,TensorFlow 提供了一个 tf.nn.dynamic_rnn 函数,使用该函数就相当于调用了 n 次 call 函数。即通过 {h0,x1, x2, …., xn} 直接得 {h1,h2…,hn}。

具体来说,设我们输入数据的格式为 (batch_size, time_steps, input_size),其中 time_steps 表示序列本身的长度,如在 Char RNN 中,长度为 10 的句子对应的 time_steps 就等于 10。最后的 input_size 就表示输入数据单个序列单个时间维度上固有的长度。另外我们已经定义好了一个 RNNCell,调用该 RNNCell 的 call 函数 time_steps 次,对应的代码就是:

# inputs: shape = (batch_size, time_steps, input_size)

# cell: RNNCell

# initial_state: shape = (batch_size, cell.state_size)。初始状态。一般可以取零矩阵

outputs, state = tf.nn.dynamic_rnn(cell, inputs, initial_state=initial_state)

此时,得到的 outputs 就是 time_steps 步里所有的输出。它的形状为 (batch_size, time_steps, cell.output_size)。state 是最后一步的隐状态,它的形状为 (batch_size, cell.state_size)。

此处建议大家阅读 tf.nn.dynamic_rnn 的文档(地址:https://www.tensorflow.org/api_docs/python/tf/nn/dynamic_rnn)做进一步了解。

  三、学习如何堆叠 RNNCell:MultiRNNCell

很多时候,单层 RNN 的能力有限,我们需要多层的 RNN。将 x 输入第一层 RNN 的后得到隐层状态 h,这个隐层状态就相当于第二层 RNN 的输入,第二层 RNN 的隐层状态又相当于第三层 RNN 的输入,以此类推。在 TensorFlow 中,可以使用 tf.nn.rnn_cell.MultiRNNCell 函数对 RNNCell 进行堆叠,相应的示例程序如下:

import tensorflow as tf

import numpy as np

# 每调用一次这个函数就返回一个 BasicRNNCell

def get_a_cell():
   return tf.nn.rnn_cell.BasicRNNCell(num_units=128)

# 用 tf.nn.rnn_cell MultiRNNCell 创建 3 层 RNN

cell = tf.nn.rnn_cell.MultiRNNCell([get_a_cell() for _ in range(3)]) # 3 层 RNN

# 得到的 cell 实际也是 RNNCell 的子类

# 它的 state_size 是 (128, 128, 128)

# (128, 128, 128) 并不是 128x128x128 的意思

# 而是表示共有 3 个隐层状态,每个隐层状态的大小为 128

print(cell.state_size) # (128, 128, 128)

# 使用对应的 call 函数

inputs = tf.placeholder(np.float32, shape=(32, 100)) # 32 是 batch_size

h0 = cell.zero_state(32, np.float32) # 通过 zero_state 得到一个全 0 的初始状态

output, h1 = cell.call(inputs, h0)

print(h1) # tuple 中含有 3 个 32x128 的向量

通过 MultiRNNCell 得到的 cell 并不是什么新鲜事物,它实际也是 RNNCell 的子类,因此也有 call 方法、state_size 和 output_size 属性。同样可以通过 tf.nn.dynamic_rnn 来一次运行多步。

此处建议阅读 MutiRNNCell 源码(地址:http://t.cn/RNJrfMl)中的注释进一步了解其功能。

  四、可能遇到的坑 1:Output 说明

在经典 RNN 结构中有这样的图:

在上面的代码中,我们好像有意忽略了调用 call 或 dynamic_rnn 函数后得到的 output 的介绍。将上图与 TensorFlow 的 BasicRNNCell 对照来看。h 就对应了 BasicRNNCell 的 state_size。那么,y 是不是就对应了 BasicRNNCell 的 output_size 呢?答案是否定的。

找到源码中 BasicRNNCell 的 call 函数实现:

def call(self, inputs, state):
   """Most basic RNN: output = new_state = act(W * input + U * state + B)."""
   output = self._activation(_linear([inputs, state], self._num_units, True))
   return output, output

这句 “return output, output” 说明在 BasicRNNCell 中,output 其实和隐状态的值是一样的。因此,我们还需要额外对输出定义新的变换,才能得到图中真正的输出 y。由于 output 和隐状态是一回事,所以在 BasicRNNCell 中,state_size 永远等于 output_size。TensorFlow 是出于尽量精简的目的来定义 BasicRNNCell 的,所以省略了输出参数,我们这里一定要弄清楚它和图中原始 RNN 定义的联系与区别。

再来看一下 BasicLSTMCell 的 call 函数定义(函数的最后几行):

new_c = (
   c * sigmoid(f + self._forget_bias) + sigmoid(i) * self._activation(j))

new_h = self._activation(new_c) * sigmoid(o)

if self._state_is_tuple:
 new_state = LSTMStateTuple(new_c, new_h)

else:
 new_state = array_ops.concat([new_c, new_h], 1)

return new_h, new_state

我们只需要关注 self._state_is_tuple == True 的情况,因为 self._state_is_tuple == False 的情况将在未来被弃用。返回的隐状态是 new_c 和 new_h 的组合,而 output 就是单独的 new_h。如果我们处理的是分类问题,那么我们还需要对 new_h 添加单独的 Softmax 层才能得到最后的分类概率输出。

还是建议大家亲自看一下源码实现(地址:http://t.cn/RNJsJoH)来搞明白其中的细节。

  五、可能遇到的坑 2:因版本原因引起的错误

在前面我们讲到堆叠 RNN 时,使用的代码是:

# 每调用一次这个函数就返回一个 BasicRNNCell

def get_a_cell():
   return tf.nn.rnn_cell.BasicRNNCell(num_units=128)

# 用 tf.nn.rnn_cell MultiRNNCell 创建 3 层 RNN

cell = tf.nn.rnn_cell.MultiRNNCell([get_a_cell() for _ in range(3)]) # 3 层 RNN

这个代码在 TensorFlow 1.2 中是可以正确使用的。但在之前的版本中(以及网上很多相关教程),实现方式是这样的:

one_cell =  tf.nn.rnn_cell.BasicRNNCell(num_units=128)

cell = tf.nn.rnn_cell.MultiRNNCell([one_cell] * 3) # 3 层 RNN

如果在 TensorFlow 1.2 中还按照原来的方式定义,就会引起错误!

  六、一个练手项目:Char RNN

上面的内容实际上就是 TensorFlow 中实现 RNN 的基本知识了。这个时候,建议大家用一个项目来练习巩固一下。此处特别推荐 Char RNN 项目,这个项目对应的是经典的 RNN 结构,实现它使用的 TensorFlow 函数就是上面说到的几个,项目本身又比较有趣,可以用来做文本生成,平常大家看到的用深度学习来写诗写歌词的基本用的就是它了。

Char RNN 的实现已经有很多了,可以自己去 Github 上面找,我这里也做了一个实现,供大家参考。项目地址为:hzy46/Char-RNN-TensorFlow(地址:https://github.com/hzy46/Char-RNN-TensorFlow)。代码的部分实现来自于《安娜卡列尼娜文本生成——利用 TensorFlow 构建 LSTM 模型》

这篇专栏,在此感谢 @天雨粟 。

我主要向代码中添加了 embedding 层,以支持中文,另外重新整理了代码结构,将 API 改成了最新的 TensorFlow 1.2 版本。

可以用这个项目来写诗(以下诗句都是自动生成的):

何人无不见,此地自何如。
一夜山边去,江山一夜归。
山风春草色,秋水夜声深。
何事同相见,应知旧子人。
何当不相见,何处见江边。
一叶生云里,春风出竹堂。
何时有相访,不得在君心。

还可以生成代码:

static int page_cpus(struct flags *str)
{
       int rc;
       struct rq *do_init;
};

/*


* Core_trace_periods the time in is is that supsed,
*/
#endif

/*


* Intendifint to state anded.
*/
int print_init(struct priority *rt)
{       /* Comment sighind if see task so and the sections */
       console(string, &can);
}


此外生成英文更不是问题(使用莎士比亚的文本训练):

LAUNCE:
The formity so mistalied on his, thou hast she was
to her hears, what we shall be that say a soun man
Would the lord and all a fouls and too, the say,
That we destent and here with my peace.

PALINA:


Why, are the must thou art breath or thy saming,
I have sate it him with too to have me of
I the camples.


最后,如果你脑洞够大,还可以来做一些更有意思的事情,比如我用了著名的网络小说《斗破苍穹》训练了一个 RNN 模型,可以生成下面的文本:

闻言,萧炎一怔,旋即目光转向一旁的那名灰袍青年,然后目光在那位老者身上扫过,那里,一个巨大的石台上,有着一个巨大的巨坑,一些黑色光柱,正在从中,一道巨大的黑色巨蟒,一股极度恐怖的气息,从天空上暴射而出 ,然后在其中一些一道道目光中,闪电般的出现在了那些人影,在那种灵魂之中,却是有着许些强者的感觉,在他们面前,那一道道身影,却是如同一道黑影一般,在那一道道目光中,在这片天地间,在那巨大的空间中,弥漫而开……

“这是一位斗尊阶别,不过不管你,也不可能会出手,那些家伙,可以为了这里,这里也是能够有着一些异常,而且他,也是不能将其他人给你的灵魂,所以,这些事,我也是不可能将这一个人的强者给吞天蟒,这般一次,我们的实力,便是能够将之击杀……”

“这里的人,也是能够与魂殿强者抗衡。”

萧炎眼眸中也是掠过一抹惊骇,旋即一笑,旋即一声冷喝,身后那些魂殿殿主便是对于萧炎,一道冷喝的身体,在天空之上暴射而出,一股恐怖的劲气,便是从天空倾洒而下。

“嗤!”

还是挺好玩的吧,另外还尝试了生成日文等等。

  七、学习完整版的 LSTMCell

上面只说了基础版的 BasicRNNCell 和 BasicLSTMCell。TensorFlow 中还有一个 “完全体” 的 LSTM:LSTMCell。这个完整版的 LSTM 可以定义 peephole,添加输出的投影层,以及给 LSTM 的遗忘单元设置 bias 等,可以参考其源码(地址:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/rnn_cell_impl.py#L417)了解使用方法。

  八、学习最新的 Seq2Seq API

Google 在 TensorFlow 的 1.2 版本(1.3.0 的 rc 版已经出了,貌似正式版也要出了,更新真是快)中更新了 Seq2Seq API,使用这个 API 我们可以不用手动地去定义 Seq2Seq 模型中的 Encoder 和 Decoder。此外它还和 1.2 版本中的新数据读入方式 Datasets 兼容。可以阅读此处的文档(地址:http://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq)学习它的使用方法。

  九、总结

最后简单地总结一下,这篇文章提供了一个学习 TensorFlow RNN 实现的详细路径,其中包括了学习顺序、可能会踩的坑、源码分析以及一个示例项目 hzy46/Char-RNN-TensorFlow(地址:https://github.com/hzy46/Char-RNN-TensorFlow),希望能对大家有所帮助。

AI 研习社长期接受优秀文章投稿

同时免费为优质企业推广招聘信息

有意者请联系 jiazhilong@leiphone.com



后台回复 “我要进群” 加入 AI 技术讨论群 



新人福利



关注 AI 研习社(okweiwu),回复  1  领取

【超过 1000G 神经网络 / AI / 大数据,教程,论文】



如何用 TensorFlow 生成令人惊艳的分形图案

▼▼▼

登录查看更多
0

相关内容

RNN:循环神经网络,是深度学习的一种模型。
【实用书】学习用Python编写代码进行数据分析,103页pdf
专知会员服务
195+阅读 · 2020年6月29日
一份循环神经网络RNNs简明教程,37页ppt
专知会员服务
173+阅读 · 2020年5月6日
Transformer文本分类代码
专知会员服务
117+阅读 · 2020年2月3日
【电子书】Flutter实战305页PDF免费下载
专知会员服务
23+阅读 · 2019年11月7日
【书籍】深度学习框架:PyTorch入门与实践(附代码)
专知会员服务
165+阅读 · 2019年10月28日
初学者的 Keras:实现卷积神经网络
Python程序员
24+阅读 · 2019年9月8日
CNN图像风格迁移的原理及TensorFlow实现
数据挖掘入门与实战
5+阅读 · 2018年4月18日
tensorflow LSTM + CTC实现端到端OCR
数据挖掘入门与实战
8+阅读 · 2017年11月15日
tensorflow系列笔记:流程,概念和代码解析
北京思腾合力科技有限公司
30+阅读 · 2017年11月11日
十分钟掌握Keras实现RNN的seq2seq学习
机器学习研究会
10+阅读 · 2017年10月13日
Do RNN and LSTM have Long Memory?
Arxiv
19+阅读 · 2020年6月10日
Arxiv
92+阅读 · 2020年2月28日
Arxiv
3+阅读 · 2018年10月25日
Knowledge Based Machine Reading Comprehension
Arxiv
4+阅读 · 2018年9月12日
Arxiv
7+阅读 · 2018年6月1日
Arxiv
3+阅读 · 2018年6月1日
Arxiv
27+阅读 · 2017年12月6日
Arxiv
5+阅读 · 2017年11月13日
VIP会员
相关VIP内容
【实用书】学习用Python编写代码进行数据分析,103页pdf
专知会员服务
195+阅读 · 2020年6月29日
一份循环神经网络RNNs简明教程,37页ppt
专知会员服务
173+阅读 · 2020年5月6日
Transformer文本分类代码
专知会员服务
117+阅读 · 2020年2月3日
【电子书】Flutter实战305页PDF免费下载
专知会员服务
23+阅读 · 2019年11月7日
【书籍】深度学习框架:PyTorch入门与实践(附代码)
专知会员服务
165+阅读 · 2019年10月28日
相关资讯
初学者的 Keras:实现卷积神经网络
Python程序员
24+阅读 · 2019年9月8日
CNN图像风格迁移的原理及TensorFlow实现
数据挖掘入门与实战
5+阅读 · 2018年4月18日
tensorflow LSTM + CTC实现端到端OCR
数据挖掘入门与实战
8+阅读 · 2017年11月15日
tensorflow系列笔记:流程,概念和代码解析
北京思腾合力科技有限公司
30+阅读 · 2017年11月11日
十分钟掌握Keras实现RNN的seq2seq学习
机器学习研究会
10+阅读 · 2017年10月13日
相关论文
Do RNN and LSTM have Long Memory?
Arxiv
19+阅读 · 2020年6月10日
Arxiv
92+阅读 · 2020年2月28日
Arxiv
3+阅读 · 2018年10月25日
Knowledge Based Machine Reading Comprehension
Arxiv
4+阅读 · 2018年9月12日
Arxiv
7+阅读 · 2018年6月1日
Arxiv
3+阅读 · 2018年6月1日
Arxiv
27+阅读 · 2017年12月6日
Arxiv
5+阅读 · 2017年11月13日
Top
微信扫码咨询专知VIP会员