社区分享 | “漫画脸”是这样诞生的:从模型到部署全解析

2020 年 8 月 7 日 TensorFlow

文 / Margaret Maynard-Reid 和 Sayak Paul,ML GDE


本教程分为三个部分,您可以按端到端部署教程顺序进行学习,也可以跳到最感兴趣或与您相关度最高的部分:
  • 第 1 部分:本项目(使用 TensorFlow Lite 完成 Selfie2Anime 项目)概览和介绍。
  • 第 2 部分:如何创建  SavedModel  并将其转换为 TFLite 模型。模型保存步骤在 TensorFlow 1.14 运行环境执行,因为这是编写模型代码使用的版本,而相同的方法也可以应用于大多数用 TensorFlow 1.x 编写的模型。模型转换步骤在 TensorFlow 2.x 运行环境利用 TFLiteConverter 的最新功能执行。
  • 第 3 部分:将 TFLite 模型部署到 Android 应用。


端到端教程:使用 TensorFlow Lite 实现 Selfie2Anime



概览与介绍

在第一部分,我们将主要介绍如何将 TF 1.x 模型转换为 TensorFlow Lite (TFLite),然后将其部署到 Android 应用,从而合理地将自拍人像转换为动漫人物。


以下为本端到端教程的目标:
  • 为开发者提供参考,助其了解如何使用最新 (v2) 转换器新功能(例如 MLIR 转换器、支持性更好的算子和改进的内核等)将在 TensorFlow 1.x 内编写的模型转换为 TFLite 变体。
  • 助开发者了解如何使用 TFLite 工具,例如 Android Benchmark 工具、模型元数据和 Codegen。
  • 通过 Android Studio 中的 ML 模型绑定功能,指导开发者如何使用 TFLite 模型轻松创建移动应用。


请参考此 Colab Notebook,了解模型保存/转换,并点击此 GitHub 链接获取 Android 代码。如果您不熟悉 SavedModel 格式,请参阅 TensorFlow 文档了解详情。


Colab Notebook
https://github.com/margaretmz/selfie2anime-e2e-tutorial/blob/master/ml/Selfie2Anime_Model_Conversion.ipynb


GitHub
https://github.com/margaretmz/selfie2anime-e2e-tutorial/tree/master/android/selfie2animev


📝 TensorFlow 文档
https://tensorflow.google.cn/guide/saved_model

🎵 awesome-tflite
https://github.com/margaretmz/awesome-tflite


U-GAT-IT

我们使用了论文《具有自适应层实例规范化且可用于图像到图像转换的无监督生成注意网络(Unsupervised Generative Attentional Networks with Adaptive Layer-Instance Normalization for Image-to-Image Translation) 中提出的生成对抗网络 (GAN) 模型(也称为 U-GAT-IT)。本论文提供了两种生成器:一种可将自拍照转换为动漫风格的图像,另一种可将动漫图像转换为自拍照。在此处,我们仅实现了 Selfie2Anime 模型,因为这与真实场景更相似。

  • 具有自适应层实例规范化且可用于图像到图像转换的无监督生成注意网络
    https://arxiv.org/abs/1907.10830V


限制

训练数据不够全面,仅包含女性人脸和动漫面孔,因此 U-GAT-IT 的 selfie2anime 模型似乎仅在女性面孔转换上表现良好。为改进这一模型,我们使用不同种族、性别和年龄的多种人脸图像(如 Fairface 数据集)来重新训练模型。我们将把这一工作留给读者做练习。


转换后的 TFLite 模型已量化,但 Android 上的 GPU 代理尚不支持该模型。因此,您可能会注意到模型推理的延迟会略微延长。


尽管存在这样一些限制,我们仍然认为与所有人分享本端到端流程以及我们面临的挑战非常有价值。希望本教程与示例应用可帮助您实现实际应用。



TFLite 模型

我们来看一下模型保存和转换。主要介绍如何将 TF 1.x 模型转换为 TensorFlow Lite (TFLite),然后将其部署到 Android,从而将自拍图像转换为合理的动漫图像。


以下是分步总结:

  • 从预训练的 U-GAT-IT 模型检查点中生成 SavedModel
  • 使用最新的 TFLiteConverter 转换 SavedModel
  • 使用转换后的模型在 Python 中运行推理
  • 添加元数据以轻松与移动应用集成
  • 运行模型基准测试,确保模型在移动设备上运行良好


使用 TF1 保存模型:从预训练的检查点创建 SavedModel

请注意,这一部分需要在 TensorFlow 1.x 运行时中运行。我们使用的是 TensorFlow 1.14,因为这是编写模型代码使用的版本。


U-GAT-IT 作者提供了两个检查点:一个检查点在 50 个 epoch 后提取 (约 4.6GB),另一个在 100 个 epoch 后提取 (4.7GB)。我们将使用 Kaggle 提供的更轻量级的版本,因为此版本适用于基于移动端的部署。

  • 50 个 
    https://drive.google.com/file/d/1V6GbSItG3HZKv3quYs7AP0rr1kOCT3QO/view?usp=sharing

  • 100 个 epoch
    https://drive.google.com/file/d/19xQK2onIy-3S5W5K-XIh85pAg_RNvBVf/view?usp=sharing


从 Kaggle 下载并提取模型检查点

首先,从最重要的事情开始!我们使用 Kaggle API从 Kaggle 下载检查点。在 kaggle.com 上,转到“我的帐号/API”,点击“新建 API 令牌”,即可触发下载 kaggle.json,其中包含您的 API 凭据。然后,您可以在 Colab 中指定以下内容并设置环境变量:

os.environ['KAGGLE_USERNAME'] = "" # TODO: enter your Kaggle user name hereos.environ['KAGGLE_KEY'] = "" TODO: enter your Kaggle key here
  • Kaggle API
    https://github.com/Kaggle/kaggle-api


接着我们下载并提取检查点:

$ kaggle datasets download -d t04glovern/ugatit-selfie2anime-pretrained$ unzip -qq /content/ugatit-selfie2anime-pretrained.zip


加载模型检查点并连接张量

此步骤通常因模型而异。此步骤遵循的一般工作流程如下:

  1. 定义模型的输入和输出张量。
  2. 实例化模型并连接输入和输出张量,从而构建计算图。
  3. 将预训练的检查点加载到模型的计算图中。
  4. 生成 SavedModel。


值得注意的是,此工作流程中的第 2 步因模型而异,因此很难预知具体步骤。在此部分,我们将仅专注于需要重点理解的代码部分,对于完整的实现,请查看本教程随附的 Colab Notebook (https://github.com/margaretmz/selfie2anime-e2e-tutorial/blob/master/ml/Selfie2Anime_Model_Conversion.ipynb)


在我们的示例中,我们可以从主模型类的实例访问输入和输出张量及其详细信息。因此,我们首先会实例化 UGATIT 模型类的实例:

with tf.Graph().as_default(), tf.Session() as sess:   gan = UGATIT(sess, data)   gan.build_model()   load_checkpoint(sess, ckpt_path)


data 指模型配置,具体如此处 (https://github.com/taki0112/UGATIT/blob/master/main.py)所示。获取 UGATIT 类(https://github.com/taki0112/UGATIT/blob/master/UGATIT.py)。此时,我们应该已完成模型的实例化。现在,我们需要通过加载模型的会话将检查点加载到模型中,这就是 load_checkpoint() 方法的操作方式:

def load_checkpoint(sess, ckpt_path):   model_saver = tf.train.Saver(tf.global_variables())   checkpoint = os.path.expanduser(checkpoint)   if tf.gfile.IsDirectory(checkpoint):       checkpoint = tf.train.latest_checkpoint(checkpoint)       tf.logging.info('loading latest checkpoint file: {}'.format(checkpoint))   model_saver.restore(sess, checkpoint)


此时,只需按几次键即可创建 SavedModel。请记住,我们仍处于 session 的上下文之中。

tf.saved_model.simple_save(       sess,       saved_model_dir,       inputs={gan.test_domain_A.name: gan.test_domain_A},       outputs={gan.test_fake_B.name: gan.test_fake_B})


正如上方代码所示,我们可以从模型图本身访问输入和输出张量。执行此代码后,SavedModel 文件应该已准备就绪。接着,我们可以继续将此 SavedModel 转换为 TFLite 模型。


准备 TFLite 模型

现在,是时候作出改变,进入 TensorFlow 2.x(2.2.0 或任何更高的 Nightly 版本)。在此部分,我们将使用之前生成的 SavedModel,并将其转换为 TFLite 平面缓冲区,该缓冲区大小约为 10 MB,非常适合在移动应用中使用。然后,我们将使用一些最新的 TensorFlow Lite 工具来准备要部署的模型:

  • 使用 TFLite 模型在 Python 中运行推理,以确保模型在转换后有良好的性能表现。

  • 将元数据添加到 TFLite 模型中,以便通过 Android Studio 的 ML 模型绑定插件轻松地将模型集成到 Android 应用。

  • 使用基准测试工具查看模型在移动设备上的性能。


使用 TF2 将 SavedModel 转换为 TFLite

首先,我们加载 SavedModel 文件并从中创建一个具体函数

model = tf.saved_model.load(saved_model_path)concrete_func = model.signatures[ tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]


用这种方式执行转换的优点在于,我们可以灵活设置所得 TFLite 模型的输入和输出张量的形状。您可在以下代码片段中看到相关内容:

concrete_func.inputs[0].set_shape([1, 256, 256, 3])concrete_func.outputs[0].set_shape([1, 256, 256, 3])


建议使用训练模型时使用的相应输入和输出张量的原始形状。在此示例中,此形状为 (1, 256, 256, 3),其中 1 表示批次维度。这是必要设置,因为模型需要数据的形状为:BATCH_SIZE、IMAGE_SHAPE、IMAGE_SHAPE、NB_CHANNELS。为执行实际转换,我们执行运行以下代码:

converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]tflite_model = converter.convert()


除非我们明确为 converter 指定任何优化选项,否则该模型仍将是浮点模型。您可浏览 TFLite 中各种可用的优化选项

  • 优化选项
    https://www.tensorflow.org/lite/performance/model_optimization


使用 TFLite 模型运行推理

转换之后和部署 .tflite 模型之前,在 Python 中运行推理以确认其能够正常工作,这始终是不错的实操做法。


我们已在数个面孔上测试过该模型,结果表明,与男性面孔相比,该模型在女性面孔上产生的效果要好得多。如果仔细研究训练数据集,我们可以发现所有面孔都是女性面孔,并且因为模型仅在女性面孔上进行过训练,所以存在偏差。


以下为测试结果屏幕:


将元数据添加到 TFLite 模型

现在,我们将元数据添加到 TensorFlow 模型中,以便在 Android 上自动生成模型推理代码。


选项 A:通过命令行

如果要通过命令行使用 Python 脚本添加元数据,请务必先在 conda 或 virtualenv 环境中执行 pip install tflite-support。按如下所示设置文件夹结构:

metadata_writer_for_selfie2anime.py|-- model_without_metadata|   |--selfie2anime.tflite|-- model_with_metadata


然后使用 metadata_writer_for_selfie2anime.py 脚本将元数据添加到 selfie2anime.tflite 模型:

python ./metadata_writer_for_selfie2anime.py \--model_file=./model_without_metadata/selfie2anime.tflite \--export_directory=model_with_metadata


选项 B:通过 Colab

或者,您也可以改用此 Colab Notebook。另外,也要记得首先执行 $pip install tflite-support。如果您不熟悉如何在命令行中运行 Python 脚本,则使用此选项可能会更简单一些。您所需要做的就是在浏览器中启动 Notebook,上传 selfie2anime.tflite 文件并执行所有单元。

  • Colab Notebook
    https://github.com/margaretmz/selfie2anime-e2e-tutorial/blob/master/ml/add-meta-data-Colab/Add%20metadata%20to%20selfie2anime.ipynb


元数据添加完成

model_with_metadata 文件夹下新建两个文件:selfie2anime.tfliteselfie2anime.json。此新 selfie2anime.tflite 文件包含模型元数据,在将模型部署到 Android 时,我们可以将元数据用作 Android Studio 中 ML 模型绑定的输入。然后,您可以使用 selfie2anime.json 验证添加到模型中的元数据是否正确。

metadata_writer_for_selfie2anime.py|-- model_without_metadata|   |--selfie2anime.tflite|-- model_with_metadata|   |--selfie2anime.tflite|   |--selfie2anime.json


如需详细了解 TFLite 元数据工作原理,请参阅文档

  • 文档
    https://www.tensorflow.org/lite/convert/metadata


Android 上的模型性能基准测试(可选)

这是一个可选步骤,在部署模型前,我们使用 TFLite Android 模型基准测试工具获取了模型在 Android 上的运行时性能表现。如需了解详情,请参考基准测试工具中的说明。

  • 基准测试工具
    https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/tools/benchmark/android


以下高度概括了各个步骤:

  • 配置 Android NDK/SDK:存在一些 Android SDK/NDK 先决条件,然后您可以使用 Bazel 构建该工具。

  • 构建基准 APK

bazel build -c opt \    --config=android_arm64 \    //tensorflow/lite/tools/benchmark:benchmark_model
  • 使用 adb(Android 调试桥)安装基准测试工具并将 selfie2anime.tflite 模型部署到 Android 设备:

adb install -r -d -g bazel-bin/tensorflow/lite/tools/benchmark/android/benchmark_model.apkadb push selfie2anime.tflite /data/local/tmp
  • 运行基准测试工具

adb shell /data/local/tmp/benchmark_model --graph=/data/local/tmp/selfie2anime.tflite --num_threads=4


我们将看到如下的基准测试结果,推理速度有些慢:推理时间(毫秒)(Inference timings in us):初始化 (Init):7135,首次推理 (First inference):7428506,预热(平均值)(Warmup (avg)):7.42851e+06,推理(平均值)(Inference (avg)):7.26313e+06


现在您已经构建好 TensorFlow Lite 模型,接下来让我们一起来学习如何在 Android 上实现该模型。



Android 应用

从这里开始,我们将主要介绍如何将 TF 1.x 模型转换为 TensorFlow Lite (TFLite),然后将其部署到 Android,从而将自拍图像转换为合理的动漫图像。


在上一节,我们已准备好 TFLite 模型,现在即可开始将 selfie2anime.tflite 模型部署到 Android!Android 代码已在 GitHub 上发布,点击此处获取。

  • 此处
    https://github.com/margaretmz/selfie2anime-e2e-tutorial/tree/master/android/selfie2anime


以下是 Android 应用的关键功能:
  • 用于界面导航的 Jetpack 导航组件
  • 用于图像捕获的 CameraX
  • 用于导入 tflite 模型的 ML 模型绑定
  • 用于异步处理模型推理的 Kotlin Coroutine


以下是在 Android 上实现 TFLite 模型的分步说明:
0. 下载 Android Studio 4.1 预览版
1. 创建新 Android 项目并设置界面导航
2. 设置用于图像捕获的 CameraX API
3. 使用 ML 模型绑定导入 selfie2anime.tflite 模型。
4. 组合所有要素:
  • 模型输入:使用 CameraX 捕获自拍图像
  • 在自拍图像上运行推理并创建动漫图像
  • 在界面中同时显示自拍图像和动漫图像
  • 使用 Kotlin 协程以避免模型推理阻塞界面主线程


0. 下载 Android Studio 4.1 预览版

我们将安装 Android Studio 预览版(4.1 Beta 1),以便使用新的 ML 模型绑定功能导入 .tflite 模型并自动生成代码。您可以直观地浏览 tfllite 模型,也可以在 Android 项目中直接使用生成的类。


点击此处下载 Android Studio 预览版。稳定版本应该能够与预览版本并行运行。请务必将 Gradle 插件更新至 4.1.0-alpha10 或更高版本,否则 ML 绑定菜单可能无法使用。

  • 此处
    https://developer.android.google.cn/studio/preview


1. 新建 Android 项目

首先,我们新建一个带有空 Activity(名为 MainActivity )的 Android 项目,该 Activity 中包含一个伴生对象,定义了将用于存储图像(使用 CameraX 捕获)的输出目录。


使用 Jetpack 导航组件在应用内导航。请参考此处教程,详细了解此支持库。

  • 此处
    https://medium.com/@margaretmz/android-ui-with-jetpack-nav-component-5ef46d9e0cfc


此示例应用中有 3 个屏幕:

  • PermissionsFragment:处理相机权限检查
  • CameraFragment:处理相机设置和图像捕获
  • Selfie2animeFragment:处理界面中自拍照和动漫图像的显示


nav_graph.xml 中的导航图将定义三个屏幕的导航,以及 CameraFragment 和 Selfie2animeFragment 之间的数据传递。


2. 设置用于图像捕获的 CameraX

CameraX 是 Jetpack 支持库,可简化相机应用的开发。


Camera1 API 易于使用,但缺少大量功能。与 Camera 1 相比,Camera 2 API 提供了更精细的控制,但其复杂度非常高,仅一个非常基本的示例中就包含了近 1000 行代码。


另一方面,CameraX 的设置要简单得多,所需代码仅为十分之一。此外,CameraX 具有生命周期感知能力,因此您无需编写额外代码来处理生命周期。


以下是为此示例应用设置 CameraX 的步骤:

  • 更新 build.gradle 依赖项
  • 使用 CameraFragment.kt 保留 CameraX 代码
  • 请求相机权限
  • 更新 AndroidManifest.ml
  • 检查 MainActivity.kt 中的权限
  • 使用 CameraX Preview 类实现取景器
  • 实现图像捕获
  • 捕获图像并将其转换为 Bitmap


捕获图像后,我们会将其转换为位图,再将 Bitmap 其传递给 TFLite 模型以进行推理。导航至新屏幕 Selfie2animeFragment.kt,在此屏幕中会同时显示原始自拍图像和动漫图像。


3. 导入 TensorFlow Lite 模型

现在,界面代码已完成。是时候导入 TensorFlow Lite 进行推理了。ML 模型绑定可轻松解决此问题。在 Android Studio 中,转至 文件 > 新建 > 其他 > TensorFlow Lite 模型


  • 指定 selfie2anime.tflite 文件位置。
  • 默认情况下会勾选“将构建功能和必要依赖项自动添加到 Gradle”。
  • 此外,请务必勾选“将 TensorFlow Lite GPU 依赖项自动添加到 Gradle”,因为 selfie2anime 模型运行速度非常慢,我们需要启用 GPU 代理。



此次导入将执行两个操作:

  • 自动创建 ml 文件夹,并将模型文件 selfie2anime.tflite 文件放入此文件夹下。
  • 自动在 app/build/generated/ml_source_out/debug/com/tflite/selfie2anime/ml 文件夹下生成一个名为 Selfie2anime.java 的 Java 类,此类将处理模型加载、图像预处理和后处理等各种任务,并运行模型推理以将自拍图像转换为动漫图像。


导入完成后,我们将看到 selfie2anime.tflite 显示模型元数据信息以及 Kotlin 和 Java 中的代码片段,而我可以复制并粘贴这些代码片段来使用模型:


4. 组合所有要素

现在,我们已设置了界面导航,配置好用于图像捕获的 CameraX,并且已导入了 selfie2anime.tflite 模型,可以将所有要素组合到一起了!首先,我们使用位于 imageCaptue?.takePicture() 下 CameraFragment.kt 中的 CameraX 捕获自拍照,然后 onCaptureSuccess() 中将返回一个 ImageProxy。我们将 ImageProxy 转换为位图,然后将其保存到之前在 MainActivity 中定义的输出目录下。


通过 JetPack 导航组件,我们可以轻松导航到 Selfie2animeFragment 并将图像目录位置以字符串参数形式传递。


接着在 Selfie2animeFragment.kt 中检索存储自拍照的文件目录字符串,创建图像文件,然后将其转换为位图,以用作 tflite 模型的输入。


在自拍图像上运行模型推理并创建动漫图像。我们在界面中同时显示自拍图像和动漫图像。


注意:推理需要相当长的时间,因此我们使用 Kotlin 协程来防止模型推理阻塞界面主线程。显示 progressBar,直到模型推理完成。


所有要素组合完成后,我们将看到以下内容:


至此,本教程结束。希望您阅读愉快,并使用 TensorFlow Lite 将所学知识运用到实际应用中。无论您已使用此处所学知识创建了什么炫酷示例,都别忘了把它们添加到 awesome-tflite 中!



致谢

本教程的创建离不开 ML GDEs 与 TensorFlow Lite 团队的大力协作。请查看 awesome-tflite 存储库,获取各种应用创意。此外,您还能从该存储库中找到大量 TensorFlow Lite 模型、示例、教程、工具和学习资源。


感谢 Khanh LeViet 和 Lu Wang(TensorFlow Lite 团队)、Hoi Lam (Android ML) 和 Soonson Kwon(ML GDEs,Google Developer Experts Program)与我们合作并持续提供支持。

  • awesome-tflite
    https://github.com/margaretmz/awesome-tflite



了解更多请点击 “阅读原文” 访问 GitHub。

登录查看更多
13

相关内容

Google发布的第二代深度学习系统TensorFlow
Transformer模型-深度学习自然语言处理,17页ppt
专知会员服务
103+阅读 · 2020年8月30日
【Amazon】使用预先训练的Transformer模型进行数据增强
专知会员服务
56+阅读 · 2020年3月6日
Transformer文本分类代码
专知会员服务
116+阅读 · 2020年2月3日
TensorFlow Lite指南实战《TensorFlow Lite A primer》,附48页PPT
专知会员服务
69+阅读 · 2020年1月17日
BERT进展2019四篇必读论文
专知会员服务
67+阅读 · 2020年1月2日
【电子书】Flutter实战305页PDF免费下载
专知会员服务
22+阅读 · 2019年11月7日
TensorFlow 2.0 学习资源汇总
专知会员服务
66+阅读 · 2019年10月9日
社区分享|如何让模型在生产环境上推理得更快
2019热门开源机器学习项目汇总
专知
9+阅读 · 2020年1月3日
利用 AutoML 的功能构建和部署 TensorFlow.js 模型
TensorFlow
6+阅读 · 2019年12月16日
【GitHub】BERT模型从训练到部署全流程
专知
34+阅读 · 2019年6月28日
用Now轻松部署无服务器Node应用程序
前端之巅
16+阅读 · 2019年6月19日
5T技术资料免费资料分享,欢迎大家加入社区获取
大数据和云计算技术
4+阅读 · 2018年1月11日
Weight Poisoning Attacks on Pre-trained Models
Arxiv
5+阅读 · 2020年4月14日
Arxiv
7+阅读 · 2018年12月10日
Arxiv
4+阅读 · 2018年10月31日
Arxiv
5+阅读 · 2018年4月13日
Arxiv
8+阅读 · 2018年1月25日
VIP会员
相关VIP内容
Transformer模型-深度学习自然语言处理,17页ppt
专知会员服务
103+阅读 · 2020年8月30日
【Amazon】使用预先训练的Transformer模型进行数据增强
专知会员服务
56+阅读 · 2020年3月6日
Transformer文本分类代码
专知会员服务
116+阅读 · 2020年2月3日
TensorFlow Lite指南实战《TensorFlow Lite A primer》,附48页PPT
专知会员服务
69+阅读 · 2020年1月17日
BERT进展2019四篇必读论文
专知会员服务
67+阅读 · 2020年1月2日
【电子书】Flutter实战305页PDF免费下载
专知会员服务
22+阅读 · 2019年11月7日
TensorFlow 2.0 学习资源汇总
专知会员服务
66+阅读 · 2019年10月9日
相关资讯
相关论文
Weight Poisoning Attacks on Pre-trained Models
Arxiv
5+阅读 · 2020年4月14日
Arxiv
7+阅读 · 2018年12月10日
Arxiv
4+阅读 · 2018年10月31日
Arxiv
5+阅读 · 2018年4月13日
Arxiv
8+阅读 · 2018年1月25日
Top
微信扫码咨询专知VIP会员