社区分享 | 在 Flutter 中使用 TensorFlow Lite 插件实现文字分类

2020 年 9 月 16 日 TensorFlow

本文来自 Flutter 社区的投稿。


如果您希望能有一种简单、高效且灵活的方式把 TensorFlow 模型集成到 Flutter 应用里,那请您一定不要错过我们今天介绍的这个全新插件 tflite_flutter。这个插件的开发者是 Google Summer of Code (GSoC) 的一名实习生 Amish Garg。


tflite_flutter 插件的核心特性:
  • 插件提供了与 TFLite Java 和 Swift API 相似的 Dart API,所以其灵活性和在这些平台上的效果是完全一样的;
  • 插件通过 dart:ffi 直接与 TensorFlow Lite C API 相绑定,所以它比其它平台集成方式更加高效;
  • 无需编写特定平台的代码;
  • 通过 NNAPI 提供加速支持,在 Android 上使用 GPU Delegate,在 iOS 上使用 Metal Delegate。


本文中,我们将使用 tflite_flutter 构建一个文字分类 Flutter 应用,带您体验 tflite_flutter 插件。首先从新建一个 Flutter 项目text_classification_app开始。 



初始化配置

Linux 和 Mac 用户

install.sh 拷贝到您应用的根目录,然后在根目录执行 sh install.sh,本例中就是目录 text_classification_app/


Windows 用户

将 install.bat 文件拷贝到应用根目录,并在根目录运行批处理文件 install.bat,本例中就是目录 text_classification_app/


它会自动从 GitHub 仓库的 Releases 里下载最新的二进制资源,然后把它放到指定的目录下。


请点击到 README 文件里查看更多 关于初始配置的信息。

  • tflite_flutter 的 GitHub 仓库
    https://github.com/am15h/tflite_flutter_plugin



获取插件

在 pubspec.yaml 添加 tflite_flutter: ^<latest_version>

  • 最新版本情况参考插件的发布地址
    https://pub.flutter-io.cn/packages/tflite_flutter



下载模型

要在移动端上运行 TensorFlow 训练模型,我们需要使用 .tflite 格式。如果需要了解如何将 TensorFlow 训练的模型转换为 .tflite 格式,请参阅官方指南


这里我们准备使用 TensorFlow 官方站点上预训练的文字分类模型


该预训练的模型可以预测当前段落的情感是积极还是消极。它是基于来自 Mass 等人的  Large Movie Review Dataset v1.0 数据集进行训练的。数据集由基于 IMDB 电影评论所标记的积极或消极标签组成,查看更多信息。


text_classification.tflitetext_classification_vocab.txt 文件拷贝到 text_classification_app/assets/ 目录下。


pubspec.yaml 文件中添加 assets/

assets:
- assets/


现在万事俱备,我们可以开始写代码了。🚀

  • 模型转换器(Converter)的 Python API 指南
    https://tensorflow.google.cn/lite/convert/python_api

  • 预训练的文字分类模型 (text_classification.tflite)
    https://files.flutter-io.cn/posts/flutter-cn/2020/tensorflow-lite-plugin/text_classification.tflite

  • 数据集 (text_classification_vocab.txt)
    https://files.flutter-io.cn/posts/flutter-cn/2020/tensorflow-lite-plugin/text_classification_vocab.txt



实现分类器

预处理

正如 文字分类模型页面里所提到的。可以按照下面的步骤使用模型对段落进行分类:
  • 对段落文本进行分词,然后使用预定义的词汇集将它转换为一组词汇 ID;
  • 将生成的这组词汇 ID 输入 TensorFlow Lite 模型里;
  • 从模型的输出里获取当前段落是积极或者是消极的概率值。


我们首先写一个方法对原始字符串进行分词,其中使用 text_classification_vocab.txt作为词汇集。


lib/ 文件夹下创建一个新文件 classifier.dart


这里先写代码加载 text_classification_vocab.txt 到字典里。

import 'package:flutter/services.dart';

class Classifier {
final _vocabFile = 'text_classification_vocab.txt';

Map<String, int> _dict;

Classifier() {
_loadDictionary();
}

void _loadDictionary() async {
final vocab = await rootBundle.loadString('assets/$_vocabFile');
var dict = <String, int>{};
final vocabList = vocab.split('\n');
for (var i = 0; i < vocabList.length; i++) {
var entry = vocabList[i].trim().split(' ');
dict[entry[0]] = int.parse(entry[1]);
}
_dict = dict;
print('Dictionary loaded successfully');
}

}

△ 加载字典


现在我们来编写一个函数对原始字符串进行分词。

import 'package:flutter/services.dart';

class Classifier {
final _vocabFile = 'text_classification_vocab.txt';

// 单句的最大长度
final int _sentenceLen = 256;

final String start = '<START>';
final String pad = '<PAD>';
final String unk = '<UNKNOWN>';

Map<String, int> _dict;

List<List<double>> tokenizeInputText(String text) {

// 使用空格进行分词
final toks = text.split(' ');

// 创建一个列表,它的长度等于 _sentenceLen,并且使用 <pad> 的对应的字典值来填充
var vec = List<double>.filled(_sentenceLen, _dict[pad].toDouble());

var index = 0;
if (_dict.containsKey(start)) {
vec[index++] = _dict[start].toDouble();
}

// 对于句子里的每个单词,在映射里找到相应的索引值
for (var tok in toks) {
if (index > _sentenceLen) {
break;
}
vec[index++] = _dict.containsKey(tok)
? _dict[tok].toDouble()
:
_dict[unk].toDouble();
}

// 按照我们的解释器输入 tensor 所需的格式 [1, 256] 返回 List<List<double>>
return [vec];
}
}

△ 分词代码


使用 tflite_flutter 进行分析

这是本文的主体部分,这里我们会讨论 tflite_flutter 插件的用途。

此处的分析指的是在设备上基于输入的数据,使用 TensorFlow Lite 模型的处理过程。要使用 TensorFlow Lite 模型进行分析,需要通过解释器来运行它,了解更多


创建解释器,加载模型

tflite_flutter 提供了一个方法直接通过资源创建解释器。

static Future<Interpreter> fromAsset(String assetName, {InterpreterOptions options})


由于我们的模型在 assets/文件夹下,需要使用上面的方法来创建解析器。对于 InterpreterOptions 的相关说明,请参考这里

import 'package:flutter/services.dart';

// 引入 tflite_flutter
import 'package:tflite_flutter/tflite_flutter.dart';

class Classifier {
// 模型文件的名称
final _modelFile = 'text_classification.tflite';

// TensorFlow Lite 解释器对象
Interpreter _interpreter;

Classifier() {
// 当分类器初始化以后加载模型
_loadModel();
}

void _loadModel() async {

// 使用 Interpreter.fromAsset 创建解释器
_interpreter = await Interpreter.fromAsset(_modelFile);
print('Interpreter loaded successfully');
}

}

△ 创建解释器的代码


如果您不希望将模型放在 assets/ 目录下,tflite_flutter 还提供了工厂构造函数创建解释器,更多信息


我们开始进行分析!


现在用下面方法启动分析:

void run(Object input, Object output);


注意这里的方法和 Java API 中的是一样的。


Object inputObject output 必须是与 Input Tensor 和 Output Tensor 维度相同的列表。


要查看 input tensor 和 output tensor 的维度,可以使用如下代码:

_interpreter.allocateTensors();
// 打印 input tensor 列表
print(_interpreter.getInputTensors());
// 打印 output tensor 列表
print(_interpreter.getOutputTensors());


在本例中 text_classification 模型的输出如下:

InputTensorList:
[Tensor{_tensor: Pointer<TfLiteTensor>: address=0xbffcf280, name: embedding_input, type: TfLiteType.float32, shape: [1, 256], data: 1024]
OutputTensorList:
[Tensor{_tensor: Pointer<TfLiteTensor>: address=0xbffcf140, name: dense_1/Softmax, type: TfLiteType.float32, shape: [1, 2], data: 8]


现在,我们实现分类方法,该方法返回值为 1 表示积极,返回值为 0 表示消极。

int classify(String rawText) {

// tokenizeInputText 返回形状为 [1, 256] 的 List<List<double>>
List<List<double>> input = tokenizeInputText(rawText);

// [1,2] 形状的输出
var output = List<double>(2).reshape([1, 2]);

// run 方法会运行分析并且存储输出的值
_interpreter.run(input, output);

var result = 0;
// 如果输出中第一个元素的值比第二个大,那么句子就是消极的
if ((output[0][0] as double) > (output[0][1] as double)) {
result = 0;
} else {
result = 1;
}
return result;
}

△ 用于分析的代码


在 tflite_flutter 的 extension ListShape on List 下面定义了一些使用的扩展:

// 将提供的列表进行矩阵变形,输入参数为元素总数并保持相等
// 用法:List(400).reshape([2,10,20])
// 返回 List<dynamic>

List reshape(List<int> shape)
// 返回列表的形状
List<int> get shape
// 返回列表任意形状的元素数量
int get computeNumElements


最终的 classifier.dart 应该是这样的:

import 'package:flutter/services.dart';

// 引入 tflite_flutter
import 'package:tflite_flutter/tflite_flutter.dart';

class Classifier {
// 模型文件的名称
final _modelFile = 'text_classification.tflite';
final _vocabFile = 'text_classification_vocab.txt';

// 语句的最大长度
final int _sentenceLen = 256;

final String start = '<START>';
final String pad = '<PAD>';
final String unk = '<UNKNOWN>';

Map<String, int> _dict;

// TensorFlow Lite 解释器对象
Interpreter _interpreter;

Classifier() {
// 当分类器初始化的时候加载模型
_loadModel();
_loadDictionary();
}

void _loadModel() async {
// 使用 Intepreter.fromAsset 创建解析器
_interpreter = await Interpreter.fromAsset(_modelFile);
print('Interpreter loaded successfully');
}

void _loadDictionary() async {
final vocab = await rootBundle.loadString('assets/$_vocabFile');
var dict = <String, int>{};
final vocabList = vocab.split('\n');
for (var i = 0; i < vocabList.length; i++) {
var entry = vocabList[i].trim().split(' ');
dict[entry[0]] = int.parse(entry[1]);
}
_dict = dict;
print('Dictionary loaded successfully');
}

int classify(String rawText) {
// tokenizeInputText 返回形状为 [1, 256] 的 List<List<double>>
List<List<double>> input = tokenizeInputText(rawText);

//输出形状为 [1, 2] 的矩阵
var output = List<double>(2).reshape([1, 2]);

// run 方法会运行分析并且将结果存储在 output 中。
_interpreter.run(input, output);

var result = 0;
// 如果第一个元素的输出比第二个大,那么当前语句是消极的
if ((output[0][0] as double) > (output[0][1] as double)) {
result = 0;
} else {
result = 1;
}
return result;
}

List<List<double>> tokenizeInputText(String text) {
// 用空格分词
final toks = text.split(' ');

// 创建一个列表,它的长度等于 _sentenceLen,并且使用 <pad> 对应的字典值来填充
var vec = List<double>.filled(_sentenceLen, _dict[pad].toDouble());

var index = 0;
if (_dict.containsKey(start)) {
vec[index++] = _dict[start].toDouble();
}

// 对于句子中的每个单词,在 dict 中找到相应的 index
for (var tok in toks) {
if (index > _sentenceLen) {
break;
}
vec[index++] = _dict.containsKey(tok)
? _dict[tok].toDouble()
:
_dict[unk].toDouble();
}

// 按照我们的解释器输入 tensor 所需的形状 [1,256] 返回 List<List<double>>
return [vec];
}
}


现在,可以根据您的喜好实现 UI 的代码,分类器的用法比较简单。

// 创建 Classifier 对象
Classifer _classifier = Classifier();
// 将目标语句作为参数,调用 classify 方法
_classifier.classify("I liked the movie");
// 返回 1 (积极的)
_classifier.classify("I didn't liked the movie");
// 返回 0 (消极的)


请在这里查阅完整代码

  • Text Classification Example app with UI
    https://github.com/am15h/tflite_flutter_plugin/tree/master/example

△ 文字分类示例应用


了解更多关于 tflite_flutter 插件的信息,请访问 GitHub repo: am15h/tflite_flutter_plugin



你问我答

问:tflite_flutter 和 tflite v1.0.5 有哪些区别?

tflite v1.0.5 侧重于为特定用途的应用场景提供高级特性,比如图片分类、物体检测等等。而新的 tflite_flutter 则提供了与 Java API 相同的特性和灵活性,而且可以用于任何 tflite 模型中,它还支持 delegate。


由于使用 dart:ffi (dart ↔️ (ffi) ↔️ C),tflite_flutter 非常快 (拥有低延时)。而 tflite 使用平台集成 (dart ↔️ platform-channel ↔️ (Java/Swift) ↔️ JNI ↔️ C)。


问:如何使用 tflite_flutter 创建图片分类应用?有没有类似 TensorFlow Lite Android Support Library 的依赖包?

TensorFlow Lite Flutter Helper Library 为处理和控制输入及输出的 TFLite 模型提供了易用的架构。它的 API 设计和文档与 TensorFlow Lite Android Support Library 是一样的。更多信息请参考 TFLite Flutter Helper 的 GitHub

  • TFLite Flutter Helper 开发库 GitHub 仓库地址
    https://github.com/am15h/tflite_flutter_helper


以上是本文的全部内容,欢迎大家对 tflite_flutter 插件进行反馈,请在 GitHub 报 bug 或提出功能需求。谢谢关注,感谢 Flutter 团队的 Michael Thomsen。

  • 向 tflite_flutter 插件提出建议和反馈
    https://github.com/am15h/tflite_flutter_plugin/issues


阅读文中的链接,请点击阅读原文或者下面 URL 查看:

  • https://flutter.cn/community/tutorials/text-classification-using-tensorflow-lite-plugin-for-flutter



译者:Yuan,谷创字幕组

审校:Xinlei、Lynn Wang、Alex,CFUG 社区



—推荐阅读—



登录查看更多
0

相关内容

Google 发布的面向结构化 web 应用的开语言。 dartlang.org
【2020新书】使用Kubernetes开发高级平台,519页pdf
专知会员服务
66+阅读 · 2020年9月19日
专知会员服务
118+阅读 · 2020年7月22日
TensorFlow Lite指南实战《TensorFlow Lite A primer》,附48页PPT
专知会员服务
69+阅读 · 2020年1月17日
【干货】用BRET进行多标签文本分类(附代码)
专知会员服务
84+阅读 · 2019年12月27日
【电子书】Flutter实战305页PDF免费下载
专知会员服务
22+阅读 · 2019年11月7日
美团:基于跨平台框架Flutter的动态化平台建设
前端之巅
14+阅读 · 2019年6月17日
TensorFlow 2.0如何在Colab中使用TensorBoard
专知
17+阅读 · 2019年3月15日
TensorFlow Lite 2019 年发展蓝图
谷歌开发者
6+阅读 · 2019年3月12日
CNN与RNN中文文本分类-基于TensorFlow 实现
七月在线实验室
13+阅读 · 2018年10月30日
收藏!CNN与RNN对中文文本进行分类--基于TENSORFLOW实现
全球人工智能
12+阅读 · 2018年5月26日
教程 | 如何使用TensorFlow实现音频分类任务
机器之心
5+阅读 · 2017年12月16日
如何用TensorFlow和TF-Slim实现图像标注、分类与分割
数据挖掘入门与实战
3+阅读 · 2017年11月17日
开源|基于tensorflow使用CNN-RNN进行中文文本分类!
全球人工智能
11+阅读 · 2017年11月12日
A Comprehensive Survey on Transfer Learning
Arxiv
121+阅读 · 2019年11月7日
VIP会员
相关资讯
美团:基于跨平台框架Flutter的动态化平台建设
前端之巅
14+阅读 · 2019年6月17日
TensorFlow 2.0如何在Colab中使用TensorBoard
专知
17+阅读 · 2019年3月15日
TensorFlow Lite 2019 年发展蓝图
谷歌开发者
6+阅读 · 2019年3月12日
CNN与RNN中文文本分类-基于TensorFlow 实现
七月在线实验室
13+阅读 · 2018年10月30日
收藏!CNN与RNN对中文文本进行分类--基于TENSORFLOW实现
全球人工智能
12+阅读 · 2018年5月26日
教程 | 如何使用TensorFlow实现音频分类任务
机器之心
5+阅读 · 2017年12月16日
如何用TensorFlow和TF-Slim实现图像标注、分类与分割
数据挖掘入门与实战
3+阅读 · 2017年11月17日
开源|基于tensorflow使用CNN-RNN进行中文文本分类!
全球人工智能
11+阅读 · 2017年11月12日
Top
微信扫码咨询专知VIP会员