文 / Kangyi Zhang,Sandeep Gupta 和 Brijesh Krishnaswami
TensorFlow.js 是一个开源代码库,开发者可以通过 JavaScript 语言定义、训练和运行机器学习模型。这让大多数的 JavaScript 开发者也能参与构建和部署机器学习模型,并由此产生了很多新的机器学习用例。如 TensorFlow.js 可以在所有主流浏览器中运行,服务端有 Node.js,还有最近 微信小程序插件 和 React Native 开始实现在混合移动应用中的机器学习相关操作,开发者无需离开 JS 生态。现在,我们很高兴为 Node.js 开发者提供一种新方法,可以无需进行模型转换,轻松高效地部署预训练的 TensorFlow SavedModel 。
TensorFlow.js 的主要优势之一是 JavaScript 开发者可以轻松地部署预训练的 TensorFlow 模型进行推理。TensorFlow.js 提供了转换工具 tfjs-converter ,可将 TensorFlow SavedModel、TFHub 模型或 Keras 模型转换为 JavaScript 兼容格式。但是,转换工具需要 JavaScript 开发者安装 TensorFlow 的 Python 工具包并学习如何使用它。此外,转换工具不支持全部的 TensorFlow 算子(支持的算子参见此文),因此,如果模型包含不支持的算子,则无法使用此工具。
在 Node.js 中执行原生模型
我们很高兴宣布现在可以在 Node.js 中执行原生 TensorFlow SavedModel。现在,您可以把预训练的 TensorFlow 模型存为 SavedModel 格式,并通过 @tensorflow/tfjs-node 或 tfjs-node-gpu 包将模型加载到 Node.js 进行推理,且无需使用转换工具 tfjs-converter。
TensorFlow SavedModel 通常含有一个或几个命名函数,称为 SignatureDef。预训练的TensorFlow SavedModel 可以通过一行代码在 JavaScript 中加载模型的 SignatureDef,随后该模型便可用于推理。
const model = await tf.node.loadSavedModel(path, [tag], signatureKey);
const output = model.predict(input);
也可以将多个输入以数组或图的形式提供给模型:
const model1 = await tf.node.loadSavedModel(path1, [tag], signatureKey);
const outputArray = model1.predict([inputTensor1, inputTensor2]);
const model2 = await tf.node.loadSavedModel(path2, [tag], signatureKey);
const outputMap = model2.predict({input1: inputTensor1, input2:inputTensor2});
如需查看 TensorFlow SavedModel 的详细信息,查找模型标签和签名信息(又称为 MetaGraph),可以通过一个 JavaScript helper API 对其进行解析,类似于 TensorFlow SavedModel 客户端工具:
const modelInfo = await tf.node.getMetaGraphsFromSavedModel(path);
此项新功能可在 1.3.2 或更高版本的 @tensorflow/tfjs-node 包中使用,同时支持 CPU 和 GPU。它支持在 TensorFlow Python 1.x 和 2.0 版本中训练和导出的 TensorFlow SavedModel。由此带来的好处除了无需进行任何转换,原生执行 TensorFlow SavedModel 意味着您可以在模型中使用 TensorFlow.js 尚未支持的算子。这要通过将 SavedModel 作为 TensorFlow 会话加载到 C++ 中进行绑定予以实现。
除了可用性上的优点,性能上的表现同样有亮点。在下图的性能基准测试中(使用 MobileNetV2 模型,横轴为推理用时),可以看到直接在 Node.js 中执行 SavedModel,CPU 和 GPU 的推理用时均有所降低。
您可以到 @tensorflow/tfjs-examples 仓库查看我们的示例 。欢迎加入我们的 讨论组 并分享您的反馈!
如果您想详细了解 本文提及 的相关内容,请参阅以下文档。这些文档深入探讨了这篇文章中提及的许多主题:
TensorFlow.js (阅读原文直接跳转)
http://tensorflow.google.cn/js
微信小程序插件
https://github.com/tensorflow/tfjs-wechat
React Native
https://github.com/tensorflow/tfjs/tree/master/tfjs-react-native
SavedModel
https://tensorflow.google.cn/guide/saved_model
tfjs-converter
https://github.com/tensorflow/tfjs/tree/master/tfjs-converter
此文
https://js.tensorflow.org/api/latest/#Operations
@tensorflow/tfjs-node
https://www.npmjs.com/package/@tensorflow/tfjs-node
tfjs-node-gpu
https://www.npmjs.com/package/@tensorflow/tfjs-node-gpu
SignatureDef
https://tensorflow.google.cn/guide/saved_model#identifying_a_signature_to_export
客户端工具
https://tensorflow.google.cn/guide/saved_model#show_command
示例
https://github.com/tensorflow/tfjs-examples/tree/master/firebase-object-detection-node
讨论组
https://groups.google.com/a/tensorflow.org/forum/#!forum/tfjs