干货|如何保存和恢复TensorFlow训练的模型?

2017 年 10 月 31 日 全球人工智能

——免费加入AI技术专家社群>>

——免费加入AI高管投资者群>>

摘要: 深度学习小技巧掌握:作者通过一个简单的例子详细介绍了如何将训练过程中的深度学习模型保存,然后如何加载有了这个小技巧,再也不用担心在训练模型中出错了


如果深层神经网络模型的复杂度非常高的话,那么训练它可能需要相当长的一段时间,当然这也取决于你拥有的数据量,运行模型的硬件等等。在大多数情况下,你需要通过保存文件来保障你试验的稳定性,防止如果中断(或一个错误),你能够继续从没有错误的地方开始。

更重要的是,对于任何深度学习的框架,像TensorFlow,在成功的训练之后,你需要重新使用模型的学习参数来完成对新数据的预测。

在这篇文章中,我们来看一下如何保存和恢复TensorFlow模型,我们在此介绍一些最有用的方法,并提供一些例子。

1. 首先我们将快速介绍TensorFlow模型

TensorFlow 的主要功能的英文通过张量来传递其基本数据结构类似于与NumPy中的多维数组,而图表则表示数据计算。的英文它一个符号库,这意味着定义图形和张量将仅创建一个模型,而张电子杂志的量具体值状语从句:操作将在会话(会话)中执行,会话(会话)一种在图中执行建模操作的机制。会话关闭时,张量的任何具体值都会丢失,这也是运行会话后将模型保存到文件的另一个原因。

通过示例可以帮助我们更容易理解,所以让我们为二维数据的线性回归创建一个简单的TensorFlow模型。

首先,我们将导入我们的库:

下一步是创建模型我们将生成一个模型,它将以以下的形式估算二次函数的水平和垂直位移:

其中h的英文水平状语从句:v的英文垂直的变化。以下是如何生成模型的过程(有关详细信息,请参阅代码中的注释):

在创建模型的过程中,我们需要有一个在会话中运行的模型,并且传递一些真实的数据。我们生成一些二次数据(二次数据),并给他们添加噪声。


节省课

Saver类是TensorFlow 库提供的类,它是保存图形结构和变量的首选方法

2.1 保存模型

在以下几行代码中,定义我们一个Saver对象,并在train_graph()函数中,经过100次迭代的方法最小化成本函数。然后,在每次迭代中以及优化完成后,将模型保存到磁盘。每个保存在磁盘上创建二进制文件被称为“检查点”。

现在让我们用上述功能训练模型,并打印出训练的参数。

好的,参数是非常准确的。如果我们检查我们的文件系统,最后4次迭代中保存有文件以及最终的模型。

保存模型时,你会注意到需要4种类型的文件才能保存:

“.META” 文件:包含图形结构。

“数据” 文件:包含变量的值。

“的.index” 文件:标识检查点。

“检查点” 文件:具有最近检查点列表的协议缓冲区。


图1:检查点文件保存到磁盘

调用tf.train.Saver()方法,如上所示,将所有变量保存到一个文件。通过将它们作为参数,表情通过列表或dict传递来保存变量的子集,例如:tf.train.Saver({'hor_estimate': h_est})

Saver构造函数的一些其他有用的参数,也可以控制整个过程,它们是:

1.max_to_keep:最多保留的检查点数。

2.keep_checkpoint_every_n_hours:保存检查点的时间间隔。

如果你想要了解更多信息,请查看官方文档Saver类,它提供了其它有用的信息,你可以探索查看。

3.重建模型

恢复TensorFlow模型时要做的第一件事就是将图形结构从“.META”文件加载到当前图形中。

也可以使用以下命令探索当前图形tf.get_default_graph()。接着第二步是加载变量的值。提醒:值仅存在于会话(会话)中。

如前面所提到的,这种方法只保存图形结构和变量,这意味着通过占位符“X”和“Y”输入的训练数据不会被保存。无论如何,在这个例子中,我们将使用我们定义的训练数据tf ,并且可视化模型拟合。


Saver这个类允许使用一个简单的方法来保存和恢复你的TensorFlow模型(图形和变量)到/从文件,并保留你工作中的多个检查点,这可能是有用的,它可以帮助你的模型在训练过程中进行微调。

4.SavedModel 格式(格式)

在TensorFlow中保存和恢复模型的一种新方法是使用SavedModel ,Builder 和loader功能。这个方法实际上是Saver提供的更高级别的序列化,它更适合于商业目的。

这种虽然SavedModel方法似乎不被开发人员完全接受,但它的创作者指出:显然它的英文未来与Saver主要关注变量的类相比,SavedModel尝试将一些有用的功能包含在一个包中,例如Signatures:网求允许保存具有一组输入和输出的图形,Assets:包含初始化中使用的外部文件。

4.1 使用SavedModel Builder保存模型

我们接下来尝试使用SavedModelBuilder类完成模型的保存。在我们的示例中,我们不使用任何符号,但也足以说明该过程。

运行此代码时,你会注意到我们的模型已保存到位于“./SavedModel/saved_model.pb”的文件中。

4.2 使用SavedModel Loader程序恢复模型

恢复模型使用tf.saved_model.loader并且可以恢复会话范围中保存的变量,符号。

在下面的例子中,我们将加载模型,并打印出我们的两个系数(h_estv_est)的数值。数值如预期的那样,我们的模型已经被成功地恢复了

5. 结论

如果你知道你的深度学习网络的训练可能会花费很长时间,保存和恢复TensorFlow模型是非常有用的功能。该主题太广泛,无法在一篇博客文章中详细介绍。不管怎样,在这篇文章中我们介绍了两个工具:Saver,并创建一个文件结构,使用简单的线性回归来说明实例。希望这些能够帮助到你训练出更好的神经网络模型。

热门文章推荐

黑科技|Adobe出图象技术神器!视频也可以PS了!!

史上第一个被授予公民身份的机器人索菲亚和人对答如流!

浙大90后女黑客在GeekPwn2017上秒破人脸识别系统!

周志华点评AlphaGo Zero:这6大特点非常值得注意!

汤晓鸥教授:人工智能让天下没有难吹的牛!

英伟达发布全球首款人工智能全自动驾驶平台

未来 3~5 年内,哪个方向的机器学习人才最紧缺?

中科院步态识别技术:不看脸 50米内在人群中认出你!

厉害|黄仁勋狂怼CPU:摩尔定律已死 未来属于GPU!

干货|7步让你从零开始掌握Python机器学习!

登录查看更多
1

相关内容

Google发布的第二代深度学习系统TensorFlow
【高能所】如何做好⼀份学术报告& 简单介绍LaTeX 的使用
Sklearn 与 TensorFlow 机器学习实用指南,385页pdf
专知会员服务
129+阅读 · 2020年3月15日
《深度学习》圣经花书的数学推导、原理与Python代码实现
TensorFlow Lite指南实战《TensorFlow Lite A primer》,附48页PPT
专知会员服务
69+阅读 · 2020年1月17日
KGCN:使用TensorFlow进行知识图谱的机器学习
专知会员服务
81+阅读 · 2020年1月13日
【模型泛化教程】标签平滑与Keras, TensorFlow,和深度学习
专知会员服务
20+阅读 · 2019年12月31日
初学者的 Keras:实现卷积神经网络
Python程序员
24+阅读 · 2019年9月8日
Tensorflow框架是如何支持分布式训练的?
AI100
9+阅读 · 2019年3月26日
一个小例子带你轻松Keras图像分类入门
云栖社区
4+阅读 · 2018年1月24日
TensorFlow神经网络教程
Python程序员
4+阅读 · 2017年12月4日
深度学习入门篇--手把手教你用 TensorFlow 训练模型
全球人工智能
4+阅读 · 2017年10月21日
如何为LSTM重新构建输入数据(Keras)
全球人工智能
6+阅读 · 2017年10月13日
深度学习实战(二)——基于Keras 的深度学习
乐享数据DataScientists
15+阅读 · 2017年7月13日
TensorFlow学习笔记2:构建CNN模型
黑龙江大学自然语言处理实验室
3+阅读 · 2016年6月14日
Learning Implicit Fields for Generative Shape Modeling
Arxiv
10+阅读 · 2018年12月6日
dynnode2vec: Scalable Dynamic Network Embedding
Arxiv
14+阅读 · 2018年12月6日
Arxiv
3+阅读 · 2018年6月1日
Arxiv
10+阅读 · 2018年3月23日
Arxiv
10+阅读 · 2018年2月17日
VIP会员
相关资讯
初学者的 Keras:实现卷积神经网络
Python程序员
24+阅读 · 2019年9月8日
Tensorflow框架是如何支持分布式训练的?
AI100
9+阅读 · 2019年3月26日
一个小例子带你轻松Keras图像分类入门
云栖社区
4+阅读 · 2018年1月24日
TensorFlow神经网络教程
Python程序员
4+阅读 · 2017年12月4日
深度学习入门篇--手把手教你用 TensorFlow 训练模型
全球人工智能
4+阅读 · 2017年10月21日
如何为LSTM重新构建输入数据(Keras)
全球人工智能
6+阅读 · 2017年10月13日
深度学习实战(二)——基于Keras 的深度学习
乐享数据DataScientists
15+阅读 · 2017年7月13日
TensorFlow学习笔记2:构建CNN模型
黑龙江大学自然语言处理实验室
3+阅读 · 2016年6月14日
Top
微信扫码咨询专知VIP会员