【导读】深度学习框架TensorFlow不仅在学术界得到了普及,在工业界也有非常广泛的应用。日常我们接触到的TensorFlow的用法大多为基于Python的实验用法,并不能直接用于工业界的线上产品。本文介绍一种简单的发布TensorFlow模型的方法。
在工业产品中使用TensorFlow模型的方法
在工业产品中TensorFlow大概有下面几种使用方法:
用TensorFlow的C++/Java/Nodejs API直接使用保存的TensorFlow模型:类似Caffe,适合做桌面软件。
直接将使用TensorFlow的Python代码放到Flask等Web程序中,提供Restful接口:实现和调试方便,但效率不太高,不大适合高负荷场景,且没有版本管理、模型热更新等功能。
将TensorFlow模型托管到TensorFlow Serving中,提供RPC或Restful服务:实现方便,高效,自带版本管理、模型热更新等,很适合大规模线上业务。
本文介绍的是方法3,如何用最简单的方法将TensorFlow发布到TensorFlow Serving中。
一句代码保存TensorFlow模型
# coding=utf-8
import tensorflow as tf
# 模型版本号
model_version = 1
# 定义模型
x = tf.placeholder(tf.float32, shape=[None, 4], name="x")
y = tf.layers.dense(x, 10, activation=tf.nn.softmax)
with tf.Session() as sess:
# 初始化变量
sess.run(tf.global_variables_initializer())
# 模型训练过程,省略
# ......
# 保存训练好的模型到"model/版本号"中
tf.saved_model.simple_save(
session=sess,
export_dir="model/{}".format(model_version),
inputs={"x": x},
outputs={"y": y}
)
代码中除了最后一句,其它部分都是常规的TensorFlow代码,模型定义、进入Session、模型训练等。代码的最后用tf.saved_model.simple_save将模型保存为SavedModel。注意,这里将模型保存在了"model/版本号"文件夹中,而不是直接保存在了"model"文件夹中,这是因为TensorFlow Serving要求在模型目录下加一层版本目录,来进行版本维护、热更新等:
安装TensorFlow Serving
方法一:用apt-get安装
对于Ubuntu或Debian(Bash on Windows10也可以),可以使用apt-get安装Tensorflow Serving。先用下面的命令添加软件源:
echo "deb [arch=amd64] http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal" | sudo tee /etc/apt/sources.list.d/tensorflow-serving.list && \
curl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | sudo apt-key add -
添加成功后可直接用apt-get进行安装:
apt-get update && apt-get install tensorflow-model-server
方法二:用Docker安装
TensorFlow Serving官方提供了Docker容器,可以一键安装:
docker pull tensorflow/serving
将模型发布到TensorFlow Serving中
下面的方法基于在本机使用apt-get安装TensorFlow Serving的方法。对于Docker用户,需要将模型挂载或复制到Docker中,按照Docker中的路径来执行下面的教程。
用下面这行命令,就可以启动TensorFlow Serving,并将刚才保存的模型发布到TensorFlow Serving中。注意,这里的模型所在路径是刚才"model"目录的路径,而不是"model/版本号"目录的路径,因为TensorFlow Serving认为用户的模型所在路径中包含了多个版本的模型。
tensorflow_model_server --port=8500 --rest_api_port=8501 --model_name=模型名 --model_base_path=模型所在路径
客户端可以用GRPC和Restful两种方式来调用TensorFlow Serving,这里我们介绍基于Restful的方法,可以看到,命令中指定的Restful服务端口为8501,我们可以用curl命令来查看服务的状态:
curl http://localhost:8501/v1/models/model
执行结果:
{
"model_version_status": [
{
"version": "1",
"state": "AVAILABLE",
"status": {
"error_code": "OK",
"error_message": ""
}
}
]
}
下面我们用curl向TensorFlow Serving发送一个输入x=[1.1, 1.2, 0.8, 1.3],来获取预测的输出信息y:
curl -d '{"instances": [[1.1,1.2,0.8,1.3]]}' -X POST http://localhost:8501/v1/models/模型名:predict
服务器返回的结果如下:
{
"predictions": [[0.0649088, 0.0974758, 0.0456831, 0.297224, 0.152209, 0.0177431, 0.104193, 0.0450511, 0.13074, 0.044771]]
}
我们的模型成功地输出了y=[0.0649088, 0.0974758, 0.0456831, 0.297224, 0.152209, 0.0177431, 0.104193, 0.0450511, 0.13074, 0.044771]
这里我们使用的是curl命令,在实际工程中,使用requests(Python)、OkHttp(Java)等Http请求库可以用类似的方法方便地请求TensorFlow Serving来获取模型的预测结果。
版本维护和模型热更新
刚才我们将模型保存在了"model/1"中,其中1是模型的版本号。如果我们的算法工程师研发出了更好的模型,此时我们并不需要将TensorFlow Serving重启,只需要将新模型发布在"model/新版本号"中,如"model/2"。TensorFlow Serving就会自动发布新版本的模型,客户端也可以请求新版本对应的API了。
-END-
专 · 知
人工智能领域26个主题知识资料全集获取与加入专知人工智能服务群: 欢迎微信扫一扫加入专知人工智能知识星球群,获取专业知识教程视频资料和与专家交流咨询!
PC登录www.zhuanzhi.ai或者点击阅读原文,可以获取更多AI知识资料!
加入专知主题群(请备注主题类型:AI、NLP、CV、 KG等)可以其他同行一起交流~ 请加专知小助手微信(扫一扫如下二维码添加),
请关注专知公众号,获取人工智能的专业知识!
点击“阅读原文”,使用专知