文 / Margaret Maynard-Reid 和 Sayak Paul,ML GDE
第 3 部分:将 TFLite 模型部署到 Android 应用。
端到端教程:使用 TensorFlow Lite 实现 Selfie2Anime
概览与介绍
在第一部分,我们将主要介绍如何将 TF 1.x 模型转换为 TensorFlow Lite (TFLite),然后将其部署到 Android 应用,从而合理地将自拍人像转换为动漫人物。
通过 Android Studio 中的 ML 模型绑定功能,指导开发者如何使用 TFLite 模型轻松创建移动应用。
请参考此 Colab Notebook,了解模型保存/转换,并点击此 GitHub 链接获取 Android 代码。如果您不熟悉 SavedModel 格式,请参阅 TensorFlow 文档了解详情。
Colab Notebook
https://github.com/margaretmz/selfie2anime-e2e-tutorial/blob/master/ml/Selfie2Anime_Model_Conversion.ipynb
GitHub
https://github.com/margaretmz/selfie2anime-e2e-tutorial/tree/master/android/selfie2animev
📝 TensorFlow 文档
https://tensorflow.google.cn/guide/saved_model
🎵 awesome-tflite
https://github.com/margaretmz/awesome-tflite
U-GAT-IT
我们使用了论文《具有自适应层实例规范化且可用于图像到图像转换的无监督生成注意网络》(Unsupervised Generative Attentional Networks with Adaptive Layer-Instance Normalization for Image-to-Image Translation) 中提出的生成对抗网络 (GAN) 模型(也称为 U-GAT-IT)。本论文提供了两种生成器:一种可将自拍照转换为动漫风格的图像,另一种可将动漫图像转换为自拍照。在此处,我们仅实现了 Selfie2Anime 模型,因为这与真实场景更相似。
具有自适应层实例规范化且可用于图像到图像转换的无监督生成注意网络
https://arxiv.org/abs/1907.10830V
限制
训练数据不够全面,仅包含女性人脸和动漫面孔,因此 U-GAT-IT 的 selfie2anime 模型似乎仅在女性面孔转换上表现良好。为改进这一模型,我们使用不同种族、性别和年龄的多种人脸图像(如 Fairface 数据集)来重新训练模型。我们将把这一工作留给读者做练习。
转换后的 TFLite 模型已量化,但 Android 上的 GPU 代理尚不支持该模型。因此,您可能会注意到模型推理的延迟会略微延长。
尽管存在这样一些限制,我们仍然认为与所有人分享本端到端流程以及我们面临的挑战非常有价值。希望本教程与示例应用可帮助您实现实际应用。
TFLite 模型
我们来看一下模型保存和转换。主要介绍如何将 TF 1.x 模型转换为 TensorFlow Lite (TFLite),然后将其部署到 Android,从而将自拍图像转换为合理的动漫图像。
以下是分步总结:
运行模型基准测试,确保模型在移动设备上运行良好
使用 TF1 保存模型:从预训练的检查点创建 SavedModel
请注意,这一部分需要在 TensorFlow 1.x 运行时中运行。我们使用的是 TensorFlow 1.14,因为这是编写模型代码使用的版本。
U-GAT-IT 作者提供了两个检查点:一个检查点在 50 个 epoch 后提取 (约 4.6GB),另一个在 100 个 epoch 后提取 (4.7GB)。我们将使用 Kaggle 提供的更轻量级的版本,因为此版本适用于基于移动端的部署。
50 个
https://drive.google.com/file/d/1V6GbSItG3HZKv3quYs7AP0rr1kOCT3QO/view?usp=sharing
100 个 epoch
https://drive.google.com/file/d/19xQK2onIy-3S5W5K-XIh85pAg_RNvBVf/view?usp=sharing
从 Kaggle 下载并提取模型检查点
首先,从最重要的事情开始!我们使用 Kaggle API从 Kaggle 下载检查点。在 kaggle.com 上,转到“我的帐号/API”,点击“新建 API 令牌”,即可触发下载 kaggle.json,其中包含您的 API 凭据。然后,您可以在 Colab 中指定以下内容并设置环境变量:
os.environ['KAGGLE_USERNAME'] = "" # TODO: enter your Kaggle user name here
os.environ['KAGGLE_KEY'] = "" # TODO: enter your Kaggle key here
Kaggle API
https://github.com/Kaggle/kaggle-api
接着我们下载并提取检查点:
kaggle datasets download -d t04glovern/ugatit-selfie2anime-pretrained
unzip -qq /content/ugatit-selfie2anime-pretrained.zip
加载模型检查点并连接张量
此步骤通常因模型而异。此步骤遵循的一般工作流程如下:
生成 SavedModel。
值得注意的是,此工作流程中的第 2 步因模型而异,因此很难预知具体步骤。在此部分,我们将仅专注于需要重点理解的代码部分,对于完整的实现,请查看本教程随附的 Colab Notebook (https://github.com/margaretmz/selfie2anime-e2e-tutorial/blob/master/ml/Selfie2Anime_Model_Conversion.ipynb)。
在我们的示例中,我们可以从主模型类的实例访问输入和输出张量及其详细信息。因此,我们首先会实例化 UGATIT 模型类的实例:
with tf.Graph().as_default(), tf.Session() as sess:
gan = UGATIT(sess, data)
gan.build_model()
load_checkpoint(sess, ckpt_path)
data 指模型配置,具体如此处 (https://github.com/taki0112/UGATIT/blob/master/main.py)所示。获取 UGATIT 类(https://github.com/taki0112/UGATIT/blob/master/UGATIT.py)。此时,我们应该已完成模型的实例化。现在,我们需要通过加载模型的会话将检查点加载到模型中,这就是 load_checkpoint() 方法的操作方式:
def load_checkpoint(sess, ckpt_path):
model_saver = tf.train.Saver(tf.global_variables())
checkpoint = os.path.expanduser(checkpoint)
if tf.gfile.IsDirectory(checkpoint):
checkpoint = tf.train.latest_checkpoint(checkpoint)
latest checkpoint file: {}'.format(checkpoint))
checkpoint)
此时,只需按几次键即可创建 SavedModel。请记住,我们仍处于 session 的上下文之中。
tf.saved_model.simple_save(
sess,
saved_model_dir,
inputs={gan.test_domain_A.name: gan.test_domain_A},
outputs={gan.test_fake_B.name: gan.test_fake_B}
)
正如上方代码所示,我们可以从模型图本身访问输入和输出张量。执行此代码后,SavedModel 文件应该已准备就绪。接着,我们可以继续将此 SavedModel 转换为 TFLite 模型。
准备 TFLite 模型
现在,是时候作出改变,进入 TensorFlow 2.x(2.2.0 或任何更高的 Nightly 版本)。在此部分,我们将使用之前生成的 SavedModel,并将其转换为 TFLite 平面缓冲区,该缓冲区大小约为 10 MB,非常适合在移动应用中使用。然后,我们将使用一些最新的 TensorFlow Lite 工具来准备要部署的模型:
使用 TFLite 模型在 Python 中运行推理,以确保模型在转换后有良好的性能表现。
将元数据添加到 TFLite 模型中,以便通过 Android Studio 的 ML 模型绑定插件轻松地将模型集成到 Android 应用。
使用基准测试工具查看模型在移动设备上的性能。
使用 TF2 将 SavedModel 转换为 TFLite
首先,我们加载 SavedModel 文件并从中创建一个具体函数:
model = tf.saved_model.load(saved_model_path)
concrete_func = model.signatures[ tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
用这种方式执行转换的优点在于,我们可以灵活设置所得 TFLite 模型的输入和输出张量的形状。您可在以下代码片段中看到相关内容:
concrete_func.inputs[0].set_shape([1, 256, 256, 3])
concrete_func.outputs[0].set_shape([1, 256, 256, 3])
建议使用训练模型时使用的相应输入和输出张量的原始形状。在此示例中,此形状为 (1, 256, 256, 3),其中 1 表示批次维度。这是必要设置,因为模型需要数据的形状为:BATCH_SIZE、IMAGE_SHAPE、IMAGE_SHAPE、NB_CHANNELS。为执行实际转换,我们执行运行以下代码:
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
tflite_model = converter.convert()
除非我们明确为 converter 指定任何优化选项,否则该模型仍将是浮点模型。您可浏览 TFLite 中各种可用的优化选项。
优化选项
https://www.tensorflow.org/lite/performance/model_optimization
使用 TFLite 模型运行推理
转换之后和部署 .tflite 模型之前,在 Python 中运行推理以确认其能够正常工作,这始终是不错的实操做法。
我们已在数个面孔上测试过该模型,结果表明,与男性面孔相比,该模型在女性面孔上产生的效果要好得多。如果仔细研究训练数据集,我们可以发现所有面孔都是女性面孔,并且因为模型仅在女性面孔上进行过训练,所以存在偏差。
以下为测试结果屏幕:
将元数据添加到 TFLite 模型
现在,我们将元数据添加到 TensorFlow 模型中,以便在 Android 上自动生成模型推理代码。
选项 A:通过命令行
如果要通过命令行使用 Python 脚本添加元数据,请务必先在 conda 或 virtualenv 环境中执行 pip install tflite-support。按如下所示设置文件夹结构:
metadata_writer_for_selfie2anime.py
|-- model_without_metadata
| |--selfie2anime.tflite
|-- model_with_metadata
然后使用 metadata_writer_for_selfie2anime.py 脚本将元数据添加到 selfie2anime.tflite 模型:
python ./metadata_writer_for_selfie2anime.py \
--model_file=./model_without_metadata/selfie2anime.tflite \
--export_directory=model_with_metadata
选项 B:通过 Colab
或者,您也可以改用此 Colab Notebook。另外,也要记得首先执行 $pip install tflite-support。如果您不熟悉如何在命令行中运行 Python 脚本,则使用此选项可能会更简单一些。您所需要做的就是在浏览器中启动 Notebook,上传 selfie2anime.tflite 文件并执行所有单元。
Colab Notebook
https://github.com/margaretmz/selfie2anime-e2e-tutorial/blob/master/ml/add-meta-data-Colab/Add%20metadata%20to%20selfie2anime.ipynb
元数据添加完成
在 model_with_metadata 文件夹下新建两个文件:selfie2anime.tflite 和 selfie2anime.json。此新 selfie2anime.tflite 文件包含模型元数据,在将模型部署到 Android 时,我们可以将元数据用作 Android Studio 中 ML 模型绑定的输入。然后,您可以使用 selfie2anime.json 验证添加到模型中的元数据是否正确。
metadata_writer_for_selfie2anime.py
|-- model_without_metadata
| |--selfie2anime.tflite
|-- model_with_metadata
| |--selfie2anime.tflite
| |--selfie2anime.json
如需详细了解 TFLite 元数据工作原理,请参阅文档。
文档
https://www.tensorflow.org/lite/convert/metadata
Android 上的模型性能基准测试(可选)
这是一个可选步骤,在部署模型前,我们使用 TFLite Android 模型基准测试工具获取了模型在 Android 上的运行时性能表现。如需了解详情,请参考基准测试工具中的说明。
基准测试工具
https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/benchmark/android
以下高度概括了各个步骤:
配置 Android NDK/SDK:存在一些 Android SDK/NDK 先决条件,然后您可以使用 Bazel 构建该工具。
构建基准 APK
bazel build -c opt \
--config=android_arm64 \
//tensorflow/lite/tools/benchmark:benchmark_model
使用 adb(Android 调试桥)安装基准测试工具并将 selfie2anime.tflite 模型部署到 Android 设备:
adb install -r -d -g bazel-bin/tensorflow/lite/tools/benchmark/android/benchmark_model.apk
adb push selfie2anime.tflite /data/local/tmp
运行基准测试工具
adb shell /data/local/tmp/benchmark_model --graph=/data/local/tmp/selfie2anime.tflite --num_threads=4
我们将看到如下的基准测试结果,推理速度有些慢:推理时间(毫秒)(Inference timings in us):初始化 (Init):7135,首次推理 (First inference):7428506,预热(平均值)(Warmup (avg)):7.42851e+06,推理(平均值)(Inference (avg)):7.26313e+06
现在您已经构建好 TensorFlow Lite 模型,接下来让我们一起来学习如何在 Android 上实现该模型。
Android 应用
从这里开始,我们将主要介绍如何将 TF 1.x 模型转换为 TensorFlow Lite (TFLite),然后将其部署到 Android,从而将自拍图像转换为合理的动漫图像。
在上一节,我们已准备好 TFLite 模型,现在即可开始将 selfie2anime.tflite 模型部署到 Android!Android 代码已在 GitHub 上发布,点击此处获取。
此处
https://github.com/margaretmz/selfie2anime-e2e-tutorial/tree/master/android/selfie2anime
用于异步处理模型推理的 Kotlin Coroutine
使用 Kotlin 协程以避免模型推理阻塞界面主线程
0. 下载 Android Studio 4.1 预览版
我们将安装 Android Studio 预览版(4.1 Beta 1),以便使用新的 ML 模型绑定功能导入 .tflite 模型并自动生成代码。您可以直观地浏览 tfllite 模型,也可以在 Android 项目中直接使用生成的类。
点击此处下载 Android Studio 预览版。稳定版本应该能够与预览版本并行运行。请务必将 Gradle 插件更新至 4.1.0-alpha10 或更高版本,否则 ML 绑定菜单可能无法使用。
此处
https://developer.android.google.cn/studio/preview
1. 新建 Android 项目
首先,我们新建一个带有空 Activity(名为 MainActivity )的 Android 项目,该 Activity 中包含一个伴生对象,定义了将用于存储图像(使用 CameraX 捕获)的输出目录。
使用 Jetpack 导航组件在应用内导航。请参考此处教程,详细了解此支持库。
此处
https://medium.com/@margaretmz/android-ui-with-jetpack-nav-component-5ef46d9e0cfc
此示例应用中有 3 个屏幕:
Selfie2animeFragment:处理界面中自拍照和动漫图像的显示
nav_graph.xml 中的导航图将定义三个屏幕的导航,以及 CameraFragment 和 Selfie2animeFragment 之间的数据传递。
2. 设置用于图像捕获的 CameraX
CameraX 是 Jetpack 支持库,可简化相机应用的开发。
Camera1 API 易于使用,但缺少大量功能。与 Camera 1 相比,Camera 2 API 提供了更精细的控制,但其复杂度非常高,仅一个非常基本的示例中就包含了近 1000 行代码。
另一方面,CameraX 的设置要简单得多,所需代码仅为十分之一。此外,CameraX 具有生命周期感知能力,因此您无需编写额外代码来处理生命周期。
以下是为此示例应用设置 CameraX 的步骤:
捕获图像并将其转换为 Bitmap
捕获图像后,我们会将其转换为位图,再将 Bitmap 其传递给 TFLite 模型以进行推理。导航至新屏幕 Selfie2animeFragment.kt,在此屏幕中会同时显示原始自拍图像和动漫图像。
3. 导入 TensorFlow Lite 模型
现在,界面代码已完成。是时候导入 TensorFlow Lite 进行推理了。ML 模型绑定可轻松解决此问题。在 Android Studio 中,转至 文件 > 新建 > 其他 > TensorFlow Lite 模型。
此外,请务必勾选“将 TensorFlow Lite GPU 依赖项自动添加到 Gradle”,因为 selfie2anime 模型运行速度非常慢,我们需要启用 GPU 代理。
此次导入将执行两个操作:
自动在 app/build/generated/ml_source_out/debug/com/tflite/selfie2anime/ml 文件夹下生成一个名为 Selfie2anime.java 的 Java 类,此类将处理模型加载、图像预处理和后处理等各种任务,并运行模型推理以将自拍图像转换为动漫图像。
导入完成后,我们将看到 selfie2anime.tflite 显示模型元数据信息以及 Kotlin 和 Java 中的代码片段,而我可以复制并粘贴这些代码片段来使用模型:
4. 组合所有要素
现在,我们已设置了界面导航,配置好用于图像捕获的 CameraX,并且已导入了 selfie2anime.tflite 模型,可以将所有要素组合到一起了!首先,我们使用位于 imageCaptue?.takePicture() 下 CameraFragment.kt 中的 CameraX 捕获自拍照,然后 onCaptureSuccess() 中将返回一个 ImageProxy。我们将 ImageProxy 转换为位图,然后将其保存到之前在 MainActivity 中定义的输出目录下。
通过 JetPack 导航组件,我们可以轻松导航到 Selfie2animeFragment 并将图像目录位置以字符串参数形式传递。
接着在 Selfie2animeFragment.kt 中检索存储自拍照的文件目录字符串,创建图像文件,然后将其转换为位图,以用作 tflite 模型的输入。
在自拍图像上运行模型推理并创建动漫图像。我们在界面中同时显示自拍图像和动漫图像。
注意:推理需要相当长的时间,因此我们使用 Kotlin 协程来防止模型推理阻塞界面主线程。显示 progressBar,直到模型推理完成。
所有要素组合完成后,我们将看到以下内容:
至此,本教程结束。希望您阅读愉快,并使用 TensorFlow Lite 将所学知识运用到实际应用中。无论您已使用此处所学知识创建了什么炫酷示例,都别忘了把它们添加到 awesome-tflite 中!
致谢
本教程的创建离不开 ML GDEs 与 TensorFlow Lite 团队的大力协作。请查看 awesome-tflite 存储库,获取各种应用创意。此外,您还能从该存储库中找到大量 TensorFlow Lite 模型、示例、教程、工具和学习资源。
感谢 Khanh LeViet 和 Lu Wang(TensorFlow Lite 团队)、Hoi Lam (Android ML) 和 Soonson Kwon(ML GDEs,Google Developer Experts Program)与我们合作并持续提供支持。
awesome-tflite
https://github.com/margaretmz/awesome-tflite
了解更多请点击 “阅读原文” 访问 GitHub。