【干货】初学者指南|实现LSTM

【导读】时间序列建模目前已广泛应用在机器翻译、语音识别等相关领域,是目前AI领域不可或缺的重要技术。本文预测比特币价格为例,从头教你搭建长短期记忆网络LSTM。


作者 | Brian Mwangi

编译 | 专知

整理 | Xiaowen


 

A Beginner’s Guide to Implementing Long Short-Term Memory Networks (LSTM)


人类的思想是持久的,这使我们能够理解模式,这反过来又使我们有能力预测下一步的行动。你对这篇文章的理解将建立在你读过的前几个单词的基础上。反复出现的神经网络复制了这个概念。

 

RNNs是一种人工神经网络,它能够识别和预测数据序列,如文本、基因组、手写、口语或数字时间序列数据。它们的循环允许一致的信息流,可以处理任意长度的序列。

 

使用内部状态(内存)来处理一系列输入,RNNs已经被用于解决一些问题:

  • 语言翻译与建模

  • 语音识别

  • 图像标题

  • 时间序列数据,如股票价格,告诉你什么时候买入或卖出

  • 自主驾驶系统,以预见汽车轨迹,并帮助避免事故

 

我写这篇文章的前提是,你对神经网络有了基本的理解。如果你需要复习,请参考【1】。

 


理解循环神经网络


为了理解RNNs,让我们使用一个带有一个隐藏层的简单感知器网络,这样的网络可以很好地处理简单的分类问题。当增加更多的隐藏层时,我们的网络将能够在输入数据中推断出更复杂的序列,并提高预测精度。

 

RNN网络结构:

A:神经网络

Xt:输入

ht:输出

 

循环确保信息流的一致性。A(神经网络块)基于输入的Xt产生输出ht。

 

RNN也可以被视为同一网络的多个副本,每个副本将信息传递给它的后续网络。



在每个时间步骤(t)中,递归神经元接收来自前一时间步骤ht-1的输入Xt以及它自己的输出。

 

如果你想深入研究RNN,我强烈建议一些很好的资源,它们包括:

  • Introduction to Recurrent Neural Networks.【2】

  • Recurrent Neural Networks for Beginners.【3】

  • Introduction to RNNs 【4】

 

RNNs有一个很大的缺陷,叫做消失梯度;也就是说,它们在学习远程依赖关系方面有困难(实体之间的关系有几步之隔)。

 

假设2014年12月的比特币价格是350美元,我们希望正确预测2018年4月和5月的比特币价格。使用RNNs,由于长期记忆不足,我们的模型无法准确预测这几个月的价格。为了解决这个问题,我们开发了一种特殊的RNN,称为长短期记忆单元(LSTM)。

 

什么是长短期记忆单元?


这是一个用来记忆长期依赖的特殊神经元。LSTM包含一个内部状态变量,它从一个单元传递到另一个单元,并由操作门(Operation Gates)修改(我们将在我们的示例中讨论这个问题)。

 

LSTM非常聪明,它可以决定保存旧信息的时间、记忆和遗忘的时间,以及如何在旧记忆和新输入之间建立联系。要深入了解LSTMs,这里有一个很好的资源:Understanding LSTM networks【5】。

 

实现LSTM


在我们的例子中,我们将使用LSTMs实现时间序列分析,预测从2014年12月到2018年5月比特币的价格。我一直使用CryptoDataDownload【6】,因为它简单直观。我使用了Google的CoLab开发环境,因为它设置环境简单,并且加速免费GPU,这减少了训练时间。如果你是CoLab的新手,这里有一个初学者指南【7】。比特币.csv文件和这个例子的全部代码可以从我的GitHub配置文件【8】中获得。

 

什么是时间序列分析?


在这里,历史数据被用来识别现有的数据模式,并使用它们来预测未来会发生什么。要获得详细的理解,请参阅本指南【9】。

 

导入库


我们将与各种库一起工作,这些库必须首先安装在CoLab笔记本中,然后导入到我们的环境中。

 

加载数据


btc.csv数据集包含比特币的价格和数量,我们使用以下命令将其加载到工作环境中:


目标变量


我们将选择比特币收盘价作为我们的目标变量来预测。

 

数据预处理


Sklearn包含预处理模块,它允许我们缩放数据,然后将其放入我们的模型中。

 

绘制数据


现在让我们来看看比特币在特定时期内的收盘价走势。


特征和标签数据集


此函数用于创建数据集的特性和标签。


Input: data ——我们正在使用的数据集


Window_size ——我们将使用多少个数据点来预测序列中的下一个数据点(例如,如果Window_size=7,我们将使用前7天来预测今天的比特币价格)。


Outputs: X ——将特征拆分为数据点的窗口(如果windows_size=1,x=[len(Data)-1,1])。


y—labels ——这是我们试图预测的序列中的下一个数字。



训练和测试数据集


将数据分解为训练集和测试集对于得到模型性能的真实估计是至关重要的,我们使用了80%(1018)的数据集作为训练集,其余的20%(248)作为验证集。

 

定义网络


超参


超参数解释了模型的高层结构信息。

 

batch_Size-这是我们一次传递的数据窗口的数量。


window_Size-是我们为我们的案例考虑的预测比特币价格的天数。


hidden_layer — 这是我们在LSTM单元中使用的单元数。


clip_margin —— 这是为了防止梯度爆炸,我们使用裁剪器,以剪去梯度以上的这一边缘。learning_rate —— 这是一种旨在减少损失函数的优化方法。


epochs — 这是我们的模型需要进行的迭代次数(前向和反向传播)。

 

你可以为你的模型自定义各种超参数,但对于我们的示例,让我们继续我们定义的那些参数。


占位符


占位符允许我们使用tf.placeholder()命令在网络中发送不同的数据。


LSTM权重


LSTM的权重由操作门决定,包括:遗忘门(Forgetgate)、输入门(Input gate)和输出门(Output gate)。

 

遗忘门(Forget gate)


ft =σ(Wf[ht-1,Xt]+bf)


这是一个sigmoid层,它利用t-1的输出和时间t处的当前输入,将它们组合成一个单一的张量,然后应用线性变换,然后再进行sigmoid操作。

 

由于sigmoid的存在,门的输出介于0到1之间。这个数字与内部状态相乘,这就是为什么门被称为遗忘门。如果ft=0,则以前的内部状态完全被遗忘,而如果ft=1时,它将被不改变地传递。

 

输入门(Input gate)


it=σ(Wi[ht-1,Xt]+bi)


此状态将前一个输出与新输入一起,并将它们传递到另一个sigmoid层。此门返回介于0和1之间的值。然后,输入门的值与候选层的输出相乘。

 

Ct=tanh(Wi[ht-1,Xt]+bi)


该层将双曲切线应用于输入和先前输出的混合,返回候选向量。然后将候选向量添加到内部状态,内部状态将根据以下规则进行更新:

 

Ct=ft *Ct-1+it*Ct


前一状态乘以遗忘门,然后添加到输出门允许的新候选的分数。

 

输出门(Output gate)


Ot=σ(Wo[ht-1,Xt]+bo)


ht=Ot*tanh Ct

 

这个门控制内部状态有多少传递到输出,并以类似于其他门的方式工作。

 


网络循环


为网络创建了一个循环,它遍历批处理中的每个窗口,将batch_states置为全零。输出用于预测比特币价格。

 

定义损失函数


在这里,我们将使用mean_squared_error函数来最小化误差。

 

训练网络


我们现在用我们初始化的数据训练网络,然后观察损失随时间的变化。现在的损耗随着观察到的时间的增加而减小,提高了我们预测比特币价格的模型的准确性。


绘制预测



输出


我们的模型已经能够准确地预测比特币价格的基础上通过实施LSTMs单元的原始数据。通过将窗口长度从7天缩短到3天,可以提高模型性能。你可以调整完整代码以优化模型性能。


结论


我希望这篇文章让你在理解LSTM方面有了一个良好的开端。

 

1.https://www.kdnuggets.com/2016/11/quick-introduction-neural-networks

2.https://www.kdnuggets.com/2015/10/recurrent-neural-networks-tutorial

3.https://medium.com/@camrongodbout/recurrent-neural-networks-for-beginners-7aca4e933b82

4.http://www.wildml.com/2015/09/recurrent-neural-networks-tutorial-part-1-introduction-to-rnns

5.http://colah.github.io/posts/2015-08-Understanding-LSTMs

6.http://www.cryptodatadownload.com/

7.https://medium.com/deep-learning-turkey/google-colab-free-gpu-tutorial-e113627b9f5d

8.https://github.com/brynmwangy/predicting-bitcoin-prices-using-LSTM

9.https://www.kdnuggets.com/2018/03/time-series-dummies-3-step-process.html

 

原文链接:

https://heartbeat.fritz.ai/a-beginners-guide-to-implementing-long-short-term-memory-networks-lstm-eb7a2ff09a27

-END-

专 · 知

人工智能领域26个主题知识资料全集获取加入专知人工智能服务群: 欢迎微信扫一扫加入专知人工智能知识星球群,获取专业知识教程视频资料和与专家交流咨询!


请PC登录www.zhuanzhi.ai或者点击阅读原文,注册登录专知,获取更多AI知识资料

请加专知小助手微信(扫一扫如下二维码添加),加入专知主题群(请备注主题类型:AI、NLP、CV、 KG等)交流~

关注专知公众号,获取人工智能的专业知识!

点击“阅读原文”,使用专知

展开全文
Top
微信扫码咨询专知VIP会员