学习教程 | 使用 TensorFlow Lite 在 Android App 中生成超分图片

2020 年 11 月 24 日 TensorFlow

文 / 魏巍, TensorFlow 技术推广工程师


从一张低分辨率的图片生成一张对应的高分辨率图片的任务通常被称为单图超分 (Single Image Super Resolution - SISR)。尽管可以使用传统的插值方法(如双线性插值和双三次插值)来完成这个任务,但是产生的图片质量却经常差强人意。深度学习,尤其是对抗生成网络 GAN,已经被成功应用在超分任务上,比如 SRGANESRGAN 都可以生成比较真实的超分图片。那么在本文里,我们将介绍一下如何使用TensorFlow Hub上的一个预训练的 ESRGAN 模型来在一个安卓 app 中生成超分图片。最终的 app 效果如下图,我们也已经将完整代码开源给大家参考。

  • SRGAN
    https://arxiv.org/abs/1609.04802

  • ESRGAN
    https://arxiv.org/abs/1809.00219

  • 完整代码
    https://github.com/tensorflow/examples/tree/master/lite/examples/super_resolution


首先,我们可以很方便的从 TFHub 上加载 ESRGAN 模型,然后很容易的将其转化为一个 TFLite 模型。注意在这里我们使用了动态范围量化 (dynamic range quantization),并将输入图片的尺寸固定在50x50像素(我们已经将转化后的模型上传到 TFHub 上了):

model = hub.load("https://tfhub.dev/captain-pool/esrgan-tf2/1")
concrete_func = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
concrete_func.inputs[0].set_shape([1, 50, 50, 3])
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# Save the TF Lite model.
with tf.io.gfile.GFile('ESRGAN.tflite', 'wb') as f:
f.write(tflite_model)

esrgan_model_path = './ESRGAN.tflite'
  • TFHub
    https://hub.tensorflow.google.cn/

  • TFHub(转化后模型)
    https://hub.tensorflow.google.cn/captain-pool/lite-model/esrgan-tf2/1


现在 TFLite 已经支持动态大小的输入,所以你也可以在模型转化的时候不指定输入图片的大小,而在运行的时候动态指定。如果你想使用动态输入大小,请参考这个例子

  • 例子
    https://github.com/tensorflow/tensorflow/blob/c58c88b23122576fc99ecde988aab6041593809b/tensorflow/lite/python/lite_test.py#L529-L560


模型转化完之后,我们可以很快验证 ESRGAN 生成的超分图片质量确实比双三次插值要好很多。如果你想更多的了解 ESRGAN 模型,我们还有另外一个教程可供参考:

lr = cv2.imread(test_img_path)
lr = cv2.cvtColor(lr, cv2.COLOR_BGR2RGB)
lr = tf.expand_dims(lr, axis=0)
lr = tf.cast(lr, tf.float32)

# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path=esrgan_model_path)
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Run the model
interpreter.set_tensor(input_details[0]['index'], lr)
interpreter.invoke()

# Extract the output and postprocess it
output_data = interpreter.get_tensor(output_details[0]['index'])
sr = tf.squeeze(output_data, axis=0)
sr = tf.clip_by_value(sr, 0, 255)
sr = tf.round(sr)
sr = tf.cast(sr, tf.uint8)
  • 教程
    https://tensorflow.google.cn/hub/tutorials/image_enhancing

LR: 输入的低分辨率图片,该图从 DIV2K 数据集中的一张蝴蝶图片中切割出来. ESRGAN (x4): ESRGAN 模型生成的超分图片,单边分辨率提升4倍. Bicubic: 双三次插值生成图片. 在这里大家可以很容易看出来,双三次插值生成的图片要比 ESRGAN 模型生成的超分图片模糊很多


你可能已经知道,TensorFlow Lite 是 TensorFlow 用于在端侧运行的官方框架,目前全球已有超过40亿台设备在运行 TFLite,它可以运行在安卓,iOS,基于 Linux 的 IoT 设备以及微处理器上。你可以使用 Java, C/C++ 或其他编程语言来运行 TFLite。在这篇文章中,我们将使用 TFLite C API,因为有许多的开发者表示希望我们能提供这样一个范例。

  • DIV2K 
    https://data.vision.ee.ethz.ch/cvl/DIV2K/

  • Java, C/C++ 
    https://tensorflow.google.cn/lite/guide/android

  • TFLite C API
    https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/c/c_api.h


我们在预先编译好的 AAR 文件中包含了 TFLite C API需要的头文件和库 (包括核心库和 GPU 库)。为了正确的设置好 Android 项目,我们首先需要下载两个 JAR 文件并将相应的头文件和库抽取出来。我们可以在一个 download.gradle 文件中定义这些任务,然后将这些任务导入 build.gradle。下面我们先定义下载 TFLite JAR 文件的两个任务:

task downloadTFLiteAARFile() {
download {
src "https://bintray.com/google/tensorflow/download_file?file_path=org%2Ftensorflow%2Ftensorflow-lite%2F2.3.0%2Ftensorflow-lite-2.3.0.aar"
dest "${project.rootDir}/libraries/tensorflow-lite-2.3.0.aar"
overwrite false
retries 5
}
}


task downloadTFLiteGPUDelegateAARFile() {
download {
src "https://bintray.com/google/tensorflow/download_file?file_path=org%2Ftensorflow%2Ftensorflow-lite-gpu%2F2.3.0%2Ftensorflow-lite-gpu-2.3.0.aar"
dest "${project.rootDir}/libraries/tensorflow-lite-gpu-2.3.0.aar"
overwrite false
retries 5
}
}
  • AAR 文件
    https://tensorflow.google.cn/lite/guide/android#use_tflite_c_api


然后我们定义另一个任务来讲头文件和库解压然后放到正确的位置:

task fetchTFLiteLibs() {
copy {
from zipTree("${project.rootDir}/libraries/tensorflow-lite-2.3.0.aar")
into "${project.rootDir}/libraries/tensorflowlite/"
include "headers/tensorflow/lite/c/*h"
include "headers/tensorflow/lite/*h"
include "jni/**/libtensorflowlite_jni.so"
}
copy {
from zipTree("${project.rootDir}/libraries/tensorflow-lite-gpu-2.3.0.aar")
into "${project.rootDir}/libraries/tensorflowlite-gpu/"
include "headers/tensorflow/lite/delegates/gpu/*h"
include "jni/**/libtensorflowlite_gpu_jni.so"
}


因为我们是用安卓 NDK 来编译这个 app,我们需要让 Android Studio 知道如何处理对应的原生文件。我们在 CMakeList.txt 文件中这样写:

set(TFLITE_LIBPATH "${CMAKE_CURRENT_SOURCE_DIR}/../../../../libraries/tensorflowlite/jni")
set(TFLITE_INCLUDE "${CMAKE_CURRENT_SOURCE_DIR}/../../../../libraries/tensorflowlite/headers")
set(TFLITE_GPU_LIBPATH "${CMAKE_CURRENT_SOURCE_DIR}/../../../../libraries/tensorflowlite-gpu/jni")
set(TFLITE_GPU_INCLUDE "${CMAKE_CURRENT_SOURCE_DIR}/../../../../libraries/tensorflowlite-gpu/headers")

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=gnu++14")
set(CMAKE_CXX_STANDARD 14)

add_library(SuperResolution SHARED SuperResolution_jni.cpp SuperResolution.cpp)

add_library(lib_tensorflowlite SHARED IMPORTED)
set_target_properties(lib_tensorflowlite PROPERTIES IMPORTED_LOCATION
${TFLITE_LIBPATH}/${ANDROID_ABI}/libtensorflowlite_jni.so)

add_library(lib_tensorflowlite_gpu SHARED IMPORTED)
set_target_properties(lib_tensorflowlite_gpu PROPERTIES IMPORTED_LOCATION
${TFLITE_GPU_LIBPATH}/${ANDROID_ABI}/libtensorflowlite_gpu_jni.so)

find_library(log-lib log)

include_directories(${TFLITE_INCLUDE})
target_include_directories(SuperResolution PRIVATE
${TFLITE_INCLUDE})

include_directories(${TFLITE_GPU_INCLUDE})

target_include_directories(SuperResolution PRIVATE
${TFLITE_GPU_INCLUDE})

target_link_libraries(SuperResolution
android
lib_tensorflowlite
lib_tensorflowlite_gpu
${log-lib})


我们在 app 里包含了3个示例图片,这样用户可能会运行同一个模型多次,这意味着为了提高运行效率,我们需要将 TFLite 解释执行器进行缓存。这一点我们可以在解释执行器成功建立后通过将其指针从 C++ 传回到 Java 来实现:

extern "C" JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_examples_superresolution_MainActivity_initWithByteBufferFromJNI(JNIEnv *env, jobject thiz, jobject model_buffer, jboolean use_gpu)
{
const void *model_data = static_cast<void *>(env->GetDirectBufferAddress(model_buffer));
jlong model_size_bytes = env->GetDirectBufferCapacity(model_buffer);
SuperResolution *super_resolution = new SuperResolution(model_data, static_cast<size_t>(model_size_bytes), use_gpu);
if (super_resolution->IsInterpreterCreated()) {
LOGI("Interpreter is created successfully");
return reinterpret_cast<jlong>(super_resolution);
} else {
delete super_resolution;
return 0;
}
}


解释执行器建立之后,运行模型实际上就非常简单了,我们只需要按照 TFLite C API 来就好。不过我们需要注意的是如何从每个像素中抽取 RGB 值:

// Extract RGB values from each pixel
float input_buffer[kNumberOfInputPixels * kImageChannels];
for (int i = 0, j = 0; i < kNumberOfInputPixels; i++) {
// Alpha is ignored
input_buffer[j++] = static_cast<float>((lr_img_rgb[i] >> 16) & 0xff);
input_buffer[j++] = static_cast<float>((lr_img_rgb[i] >> 8) & 0xff);
input_buffer[j++] = static_cast<float>((lr_img_rgb[i]) & 0xff);
}
  • TFLite C API 
    https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/c/c_api.h


运行完模型后我们需要再将 RGB 值再打包进每个像素:

// Postprocess the output from TFLite
int clipped_output[kImageChannels];
auto rgb_colors = std::make_unique<int[]>(kNumberOfOutputPixels);
for (int i = 0; i < kNumberOfOutputPixels; i++) {
for (int j = 0; j < kImageChannels; j++) {
clipped_output[j] = std::max<float>(0, std::min<float>(255, output_buffer[i * kImageChannels + j]));
}
// When we have RGB values, we pack them into a single pixel.
// Alpha is set to 255.
rgb_colors[i] = (255u & 0xff) << 24 | (clipped_output[0] & 0xff) << 16 |
(clipped_output[1] & 0xff) << 8 |
(clipped_output[2] & 0xff);
}


那么到这里我们就完成了这个 app 的关键步骤,我们可以用这个 app 来生成超分图片。您可以在对应的代码库中看到更多信息。我们希望这个范例能作为一个好的参考来帮助刚刚起步的开发者更快的掌握如何使用 TFLite C/C++ API 来搭建自己的机器学习 app。

  • 对应的代码库中
    https://github.com/tensorflow/examples/tree/master/lite/examples/super_resolution



致谢

作者十分感谢 @captain__pool 将他实现的 ESRGAN 模型上传到 TFHub, 以及 TFLite 团队的 Tian Lin 和 Jared Duke 提供十分有帮助的反馈。



— 参考 —

[1] Christian Ledig, Lucas Theis, Ferenc Huszar, Jose Caballero, Andrew Cunningham, Alejandro Acosta, Andrew Aitken, Alykhan Tejani, Johannes Totz, Zehan Wang, Wenzhe Shi. 2016. Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network.

[2] Xintao Wang, Ke Yu, Shixiang Wu, Jinjin Gu, Yihao Liu, Chao Dong, Chen Change Loy, Yu Qiao, Xiaoou Tang. 2018. ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.

[3] Tensorflow 2.x based implementation of EDSR, WDSR and SRGAN for single image super-resolution

  • https://github.com/krasserm/super-resolution

[4] @captain__pool 的 ESGRAN 代码实现

  • https://github.com/captain-pool/GSOC

[5] Eirikur Agustsson, Radu Timofte. 2017. NTIRE 2017 Challenge on Single Image Super-Resolution: Dataset and Study.



— 推荐阅读 —



了解更多请点击 “阅读原文” 访问 GitHub 代码库。

分享 💬  点赞 👍  在看 ❤️ 

以“三连”行动支持优质内容!

登录查看更多
1

相关内容

【ECCV2020】基于场景图分解的自然语言描述生成
专知会员服务
23+阅读 · 2020年9月3日
专知会员服务
108+阅读 · 2020年8月28日
TensorFlow Lite指南实战《TensorFlow Lite A primer》,附48页PPT
专知会员服务
69+阅读 · 2020年1月17日
【GitHub实战】Pytorch实现的小样本逼真的视频到视频转换
专知会员服务
35+阅读 · 2019年12月15日
【机器学习课程】Google机器学习速成课程
专知会员服务
164+阅读 · 2019年12月2日
TensorFlow 2.0 学习资源汇总
专知会员服务
66+阅读 · 2019年10月9日
TensorFlow Lite 2019 年发展蓝图
谷歌开发者
6+阅读 · 2019年3月12日
教程帖:深度学习模型的部署
论智
8+阅读 · 2018年1月20日
Arxiv
0+阅读 · 2021年1月29日
Arxiv
3+阅读 · 2020年4月29日
FIGR: Few-shot Image Generation with Reptile
Arxiv
5+阅读 · 2019年1月8日
Arxiv
3+阅读 · 2018年3月21日
VIP会员
Top
微信扫码咨询专知VIP会员