如果我们要将 TensorFlow 模型嵌入移动设备,首先需要确定模型的格式。正如官方文档所述,TensorFlow 模型有很多格式,例如,Checkpoint,Exporter,SaveModel,Frozen graph 等。
Checkpoint 格式用于在训练时存模型快照以及恢复训练。所以不要将它作为最终的保存格式,尽管它包含了推理所需的全部变量。
SaceModel 格式通常用于基于 TensorFlow Serving 的在线服务。你可以使用 TensorFlow C++ API 和 Python API 加载这种格式的模型。而在实际应用中,在线客户端通常倾向于使用开源、轻量级、RESTful 的服务,即 Simple TensorFlow Serving。
对于移动设备,我们需要使用 GraphDef 对象和 Checkpoint 文件生成 Frozen graph。TensorFLow 提供了导出 GraphDef 对象的 API,你可以通过这段代码来轻松实现。
graph_file_name = "graph.pb"
tf.train.write_graph(sess.graph_def, FLAGS.model_path, graph_file_name, as_text=False)
然后我们可以使用 TensorFlow 库中的 freeze_graph.py 脚本生成二进制 protobuf 格式的 Frozen graph 文件。
python ./freeze_graph.py --input_graph=/Users/tobe/code/tensorflow_template_application/model/graph.pb --input_checkpoint=/Users/tobe/code/tensorflow_template_application/checkpoint/checkpoint.ckpt-200 --output_graph=./frozen_graph.pb --output_node_names=output_keys,output_prediction,output_softmax --input_binary=True
需要注意的是,我们应当为生成的模型指定可用于推理的输出节点名称,不同的 TensorFlow 应用可能会根据其用途更改输出点的名称。
如果你不想自己生成 Frozen graph 文件,可以直接克隆 teansorflow template application 的源代码,其中包含了移动端模型文件。
现在我们有了 TensorFlow 模型文件,接下来只需要加载模型并使用我们的数据进行推理就可以了。
多亏了 TensorFlow Mobile 的工作,我们不需要自己编写 C++ 和 JNI 代码来加载 TensorFlow 模型。有一个名为“TensorFlowInferenceInterface”的封装好的类可以用来加载模型并进行推理。
AssetManager assetManager = getAssets();
String MODEL_FILE = "file:///android_asset/tensorflow_template_application_model.pb";
TensorFlowInferenceInterface inferenceInterface = new TensorFlowInferenceInterface(assetManager, MODEL_FILE);
众所周知,张量 (Tensor) 是 TensorFlow 的核心概念。因此我们需要使用自己构建的张量数据作为输入,而不是原始的图像像素,但是它主要用于 Python 和 C++ 接口。对于安卓客户端,我们需要构造与张量尺寸相同的“nd-array”作为输入。
int[] keysValues = new int[2];
keysValues[0] = 1;
keysValues[1] = 2;
float[] featruesValues = new float[18];
for (int i = 0; i < 18; i++) {
featruesValues[i] = 1f;
}
String[] inputNames = new String[2];
inputNames[0] = "keys";
inputNames[1] = "features";
inferenceInterface.feed(inputNames[0], keysValues, 2, 1);
inferenceInterface.feed(inputNames[1], featruesValues, 2, 9);
这是一种使用 TensorFlow 模型的普遍方式。如果你需要访问自然语言处理 (NLP) 模型或者图像模型,你也可以将输入解析为 Java 的"nd-array"格式。实例代码展示了怎样为一个 tensorflow_template_application 模型构建张量并针对你的模型修改实际数据。
最后,我们可以使用 Java 的“nd-array”对象进行推理并得到输出。
String[] outputNames = new String[3];
outputNames[0] = "output_keys";
outputNames[1] = "output_prediction";
outputNames[2] = "output_softmax";
inferenceInterface.run(outputNames, logStats);
int[] keysOutputs = new int[2];
long[] predictionOutput = new long[2];
float[] softmaxOutput = new float[4];
inferenceInterface.fetch(outputNames[0], keysOutputs);
inferenceInterface.fetch(outputNames[1], predictionOutput);
inferenceInterface.fetch(outputNames[2], softmaxOutput);
当然,请确保这些张量数据的类型和形状能够与 TensorFlow 的 Python 脚本所定义的模型兼容。得到了模型的输出后,你就可以使用这些输出数组实现其他功能了。
离线推理的完整过程不需要 gRPC 客户端或者任何网络连接。这是一个 TensorFlow 示例模型嵌入安卓设备的教程,不过你可以将它扩展到所有其他模型。本文中所有的代码都来自于 GitHub 上的开源项目。也许,你可能不知道 TensorFlow Mobile 的所有细节,但是仍可以通过 tesnorflow template application 随意尝试编译你自己的安卓客户端。
总的来说,TensorFlow 的移动端模型实现起来并不难。尽管官方的安卓 Demo 应用对于一般情况来说并不够好,但我们可以从这篇文章中学到一些实践经验,并有信心将我们所有的机器学习模型都移植到离线设备上。
如果觉得内容不错,记得给我们「留言」和「点赞」,给编辑鼓励一下!