【前言】:在前面的内容里,我们已经学习了循环神经网络的基本结构和运算过程,这一小节里,我们将用TensorFlow实现简单的RNN,并且用来解决时序数据的预测问题,看一看RNN究竟能达到什么样的效果,具体又是如何实现的。
在这个演示项目里,我们使用随机生成的方式生成一个数据集(由0和1组成的二进制序列),然后人为的增加一些数据间的关系。最后我们把这个数据集放进RNN里,让RNN去学习其中的关系,实现二进制序列的预测1。数据生成的方式如下:
循环生成规模为五十万的数据集,每次产生的数据为0或1的概率均为0.5。如果连续生成了两个1(或两个0)的话,则下一个数据强制为0(或1)。
1. 我们首先导入需要的Python模块:
2. 定义一个Data类,用来产生数据:
3. 在构造方法“__init__”中,我们初始化了数据集的大小“data_size”、一个batch的大小“batch_size”、一个epoch中的batch数目“num_batch”以及RNN的时间步“time_step”。接下来我们定义一个“generate_data”方法:
在第11行代码中,我们用了“np.random.choice”函数生成的由0和1组成的长串数据。接下来我们用了一个for循环,在“data_without_rel”保存的数据的基础上重新生成了一组数据,并保存在“data_with_rel”数组中。为了使生成的数据间具有一定的序列关系,我们使用了前面介绍的很简单的数据生成方式:以“data_without_rel”中的数据为参照,如果出现了连续两个1(或0)则生成一个0(或1),其它情况则以相等概率随机生成0或1。
有了数据我们接下来要用RNN去学习这些数据,看看它能不能学习到我们产生这些数据时使用的策略,即数据间的联系。评判RNN是否学习到规律以及学习的效果如何的依据,是我们在第三章里介绍过的交叉熵损失函数。根据我们生成数据的规则,如果RNN没有学习到规则,那么它预测正确的概率就是0.5,否则它预测正确的概率为:0.5*0.5+0.5*1=0.75(在“data_without_rel”中,连续出现的两个数字的组合为:00、01、10和11。00和11出现的总概率占0.5,在这种情况下,如果RNN学习到了规律,那么一定能预测出下一个数字,00对应1,11对应0。而如果出现的是01或10的话,RNN预测正确的概率就只有0.5,所以综合起来就是0.75)。
根据交叉熵损失函数,在没有学习到规律的时候,其交叉熵损失为:
loss = - (0.5 * np.log(0.5) + 0.5 * np.log(0.5)) = 0.6931471805599453
在学习到规律的时候,其交叉熵损失为:
Loss = -0.5*(0.5 * np.log(0.5) + np.log(0.5))
=-0.25 * (1 * np.log(1) ) - 0.25 * (1 *np.log(1))
=0.34657359027997264
4. 我们定义“generate_epochs”方法处理生成的数据:
5. 接下来实现RNN部分:
6. 定义RNN模型:
这里我们使用了“dynamic_rnn”,因此每次会同时处理所有batch的第一组数据,总共处理的次数为:batch_size / time_step。
7. 到这里,我们已经实现了整个RNN模型,接下来初始化相关数据,看看RNN的学习效果如何:
定义数据集的大小为500000,每个batch的大小为2000,RNN的“时间步”设为5,隐藏层的神经元数目为6。将训练过程中的loss可视化,结果如下图中的左侧图像所示:
图1 二进制序列数据训练的loss曲线
从左侧loss曲线可以看到,loss最终稳定在了0.35左右,这与我们之前的计算结果一致,说明RNN学习到了序列数据中的规则。右侧的loss曲线是在调整了序列关系的时间间隔后(此时的time_step过小,导致RNN无法学习到序列数据的规则)的结果,此时loss稳定在0.69左右,与之前的计算也吻合。
下一篇,我们将介绍几种常见的RNN循环神经网络结构以及部分代码示例。
本专题更多相关文章,请查看:
博客 | Tensorflow系列专题(七):一文综述RNN循环神经网络
博客 | Tensorflow系列专题(六):实战项目Mnist手写数据集识别
博客 | Tensorflow系列专题(四):神经网络篇之前馈神经网络综述
欢迎扫码关注磐创AI微信公众号
独家中文版 CMU 秋季深度学习课程免费开学!
CMU 2018 秋季《深度学习导论》为官方开源最新版本,由卡耐基梅隆大学教授 Bhiksha Raj 授权 AI 研习社翻译。学员将在本课程中学习深度神经网络的基础知识,以及它们在众多 AI 任务中的应用。课程结束后,期望学生能对深度学习有足够的了解,并且能够在众多的实际任务中应用深度学习。
↗扫码即可免费学习↖
点击 阅读原文 查看本文更多内容↙