TensorFlow 2.0 常用模块3:tf.data

2019 年 11 月 8 日 TensorFlow

文 /  李锡涵,Google Developers Expert

本文节选自《简单粗暴 TensorFlow 2.0》

在上一篇文章中,我们介绍了 TensorBoard 的使用方式,本篇文章将介绍 tf.data  ,一套灵活的数据集构建 API,能够帮助我们快速、高效地构建数据输入的流水线。


本文内容包括:
  • tf.data.Dataset 数据集对象的建立;

  • 数据集对象的预处理(变换、打散、分批次等);

  • 数据集元素的获取与使用;

  • 一个完整的 tf.data 在图像分类任务的使用示例。


很多时候,我们希望使用自己的数据集来训练模型。然而,面对一堆格式不一的原始数据文件,将其预处理并读入程序的过程往往十分繁琐,甚至比模型的设计还要耗费精力。比如,为了读入一批图像文件,我们可能需要纠结于 python 的各种图像处理包(比如  pillow  ),自己设计 Batch 的生成方式,最后还可能在运行的效率上不尽如人意。为此,TensorFlow 提供了  tf.data  这一模块,包括了一套灵活的数据集构建 API,能够帮助我们快速、高效地构建数据输入的流水线,尤其适用于数据量巨大的场景。



数据集对象的建立

tf.data 的核心是  tf.data.Dataset 类,提供了对数据集的高层封装。 tf.data.Dataset 由一系列的可迭代访问的元素(element)组成,每个元素包含一个或多个张量。比如说,对于一个由图像组成的数据集,每个元素可以是一个形状为  长×宽×通道数 的图片张量,也可以是由图片张量和图片标签张量组成的元组(Tuple)。
最基础的建立  tf.data.Dataset 的方法是使用  tf.data.Dataset.from_tensor_slices() ,适用于数据量较小(能够整个装进内存)的情况。具体而言,如果我们的数据集中的所有元素通过张量的第 0 维,拼接成一个大的张量(例如,前节的 MNIST 数据集的训练集即为一个  [60000, 28, 28, 1] 的张量,表示了 60000 张 28*28 的单通道灰度图像),那么我们提供一个这样的张量或者第 0 维大小相同的多个张量作为输入,即可按张量的第 0 维展开来构建数据集,数据集的元素数量为张量第 0 位的大小。具体示例如下:
 1import tensorflow as tf
2import numpy as np
3
4X = tf.constant([20132014201520162017])
5Y = tf.constant([1200014000150001650017500])
6
7# 也可以使用NumPy数组,效果相同
8# X = np.array([2013, 2014, 2015, 2016, 2017])
9# Y = np.array([12000, 14000, 15000, 16500, 17500])
10
11dataset = tf.data.Dataset.from_tensor_slices((X, Y))
12
13for x, y in dataset:
14    print(x.numpy(), y.numpy()) 
输出:
12013 12000
22014 14000
32015 15000
42016 16500
52017 17500

警告
当提供多个张量作为输入时,张量的第 0 维大小必须相同,且必须将多个张量作为元组 (Tuple,即使用 Python 中的小括号) 拼接并作为输入。

类似地,我们可以载入前章的 MNIST 数据集:

 1import matplotlib.pyplot as plt 
2
3(train_data, train_label), (_, _) = tf.keras.datasets.mnist.load_data()
4train_data = np.expand_dims(train_data.astype(np.float32) / 255.0, axis=-1)      # [60000, 28, 28, 1]
5mnist_dataset = tf.data.Dataset.from_tensor_slices((train_data, train_label))
6
7for image, label in mnist_dataset:
8    plt.title(label.numpy())
9    plt.imshow(image.numpy()[:, :, 0])
10    plt.show()

输出:

提示1
TensorFlow Datasets 提供了一个基于 tf.data.Datasets 的开箱即用的数据集集合,相关内容可参考 TensorFlow Datasets (https://tf.wiki/zh/appendix/tfds.html) 。例如,使用以下语句:

1import tensorflow_datasets as tfds
2dataset = tfds.load("mnist", split=tfds.Split.TRAIN)
即可快速载入 MNIST 数据集。

提示 2
对于特别巨大而无法完整载入内存的数据集,我们可以先将数据集处理为 TFRecord 格式,然后使用 tf.data.TFRocrdDataset() 进行载入。我们会在后面的连载文章中介绍 TFRecord 格式的详细使用方式,或者你也可以参考 下文 以了解详情:
  • https://tensorflow.google.cn/tutorials/load_data/tfrecord 



数据集对象的预处理

tf.data.Dataset 类为我们提供了多种数据集预处理方法。最常用的如:
  • Dataset.map(f) :对数据集中的每个元素应用函数 f ,得到一个新的数据集(这部分往往结合 tf.io 进行读写和解码文件, tf.image 进行图像处理);

  • Dataset.shuffle(buffer_size) :将数据集打乱(设定一个固定大小的缓冲区(Buffer),取出前 buffer_size 个元素放入,并从缓冲区中随机采样,采样后的数据用后续数据替换);

  • Dataset.batch(batch_size) :将数据集分成批次,即对每 batch_size 个元素,使用 tf.stack() 在第 0 维合并,成为一个元素。

  • Dataset.prefetch() :预取出数据集中的若干个元素(可提升训练流程并行效率)。

除此以外,还有 Dataset.repeat() (重复数据集的元素)、 Dataset.reduce() (与 Map 相对的聚合操作)、 Dataset.take ()(截取数据集中的前若干个元素)等,可参考 API 文档 进一步了解。
注:API 文档 链接
https://tensorflow.google.cn/versions/r2.0/api_docs/python/tf/data/Dataset

以下以 MNIST 数据集进行示例。

使用 Dataset.map() 将所有图片旋转 90 度:

 1def rot90(image, label):
2    image = tf.image.rot90(image)
3    return image, label
4
5mnist_dataset = mnist_dataset.map(rot90)
6
7for image, label in mnist_dataset:
8    plt.title(label.numpy())
9    plt.imshow(image.numpy()[:, :, 0])
10    plt.show()

输出:

使用 Dataset.batch() 将数据集划分批次,每个批次的大小为 4:

1mnist_dataset = mnist_dataset.batch(4)
2
3for images, labels in mnist_dataset:    # image: [4, 28, 28, 1], labels: [4]
4    fig, axs = plt.subplots(14)
5    for i in range(4):
6        axs[i].set_title(labels.numpy()[i])
7        axs[i].imshow(images.numpy()[i, :, :, 0])
8    plt.show()

输出:

使用 Dataset.shuffle() 将数据打散后再设置批次,缓存大小设置为 10000:

1mnist_dataset = mnist_dataset.shuffle(buffer_size=10000).batch(4)
2
3for images, labels in mnist_dataset:
4    fig, axs = plt.subplots(14)
5    for i in range(4):
6        axs[i].set_title(labels.numpy()[i])
7        axs[i].imshow(images.numpy()[i, :, :, 0])
8    plt.show()

输出:

第一次运行
第二次运行

可见每次的数据都会被随机打散。

Dataset.shuffle() 时缓冲区大小 buffer_size 的设置
tf.data.Dataset 作为一个针对大规模数据设计的迭代器,本身无法方便地获得自身元素的数量或随机访问元素。因此,为了高效且较为充分地打散数据集,需要一些特定的方法。Dataset.shuffle() 采取了以下方法:

  • 设定一个固定大小为 buffer_size 的缓冲区(Buffer);

  • 初始化时,取出数据集中的前 buffer_size 个元素放入缓冲区;

  • 每次需要从数据集中取元素时,即从缓冲区中随机采样一个元素并取出,然后从后续的元素中取出一个放回到之前被取出的位置,以维持缓冲区的大小。

因此,缓冲区的大小需要根据数据集的特性和数据排列顺序特点来进行合理的设置。比如:

  • 当 buffer_size 设置为 1 时,其实等价于没有进行任何打散;

  • 当数据集的标签顺序分布极为不均匀(例如二元分类时数据集前 N 个的标签为 0,后 N 个的标签为 1)时,较小的缓冲区大小会使得训练时取出的 Batch 数据很可能全为同一标签,从而影响训练效果。一般而言,数据集的顺序分布若较为随机,则缓冲区的大小可较小,否则则需要设置较大的缓冲区。



数据集元素的获取与使用

构建好数据并预处理后,我们需要从其中迭代获取数据以用于训练。tf.data.Dataset 是一个 Python 的可迭代对象,因此可以使用 For 循环迭代获取数据,即:

1dataset = tf.data.Dataset.from_tensor_slices((A, B, C, ...))
2for a, b, c, ... in dataset:
3    # 对张量a, b, c等进行操作,例如送入模型进行训练

也可以使用 iter() 显式创建一个 Python 迭代器并使用 next() 获取下一个元素,即:

1dataset = tf.data.Dataset.from_tensor_slices((A, B, C, ...))
2it = iter(dataset)
3a_0, b_0, c_0, ... = next(it)
4a_1, b_1, c_1, ... = next(it)

Keras 支持使用 tf.data.Dataset 直接作为输入。当调用 tf.keras.Model  的 fit()  和  evaluate() 方法时,可以将参数中的输入数据 x 指定为一个元素格式为 (输入数据, 标签数据) 的 Dataset ,并忽略掉参数中的标签数据 y 。例如,对于上述的 MNIST 数据集,常规的 Keras 训练方式是:

1model.fit(x=train_data, y=train_label, epochs=num_epochs, batch_size=batch_size)

使用 tf.data.Dataset 后,我们可以直接传入 Dataset :

1model.fit(mnist_dataset, epochs=num_epochs)
由于已经通过  Dataset.batch() 方法划分了数据集的批次,所以这里也无需提供批次的大小。



实例:cats_vs_dogs 图像分类 

以下代码以猫狗图片二分类任务为示例,展示了使用 tf.data 结合 tf.io 和 tf.image 建立 tf.data.Dataset 数据集,并进行训练和测试的完整过程。数据集下载 (https://www.floydhub.com/fastai/datasets/cats-vs-dogs)

 1import tensorflow as tf
2import os
3
4num_epochs = 10
5batch_size = 32
6learning_rate = 0.001
7data_dir = 'C:/datasets/cats_vs_dogs'
8train_cats_dir = data_dir + '/train/cats/'
9train_dogs_dir = data_dir + '/train/dogs/'
10test_cats_dir = data_dir + '/valid/cats/'
11test_dogs_dir = data_dir + '/valid/dogs/'
12
13def _decode_and_resize(filename, label):
14    image_string = tf.io.read_file(filename)
15    image_decoded = tf.image.decode_jpeg(image_string)
16    image_resized = tf.image.resize(image_decoded, [256256]) / 255.0
17    return image_resized, label
18
19if __name__ == '__main__':
20    # 构建训练数据集
21    train_cat_filenames = tf.constant([train_cats_dir + filename for filename in os.listdir(train_cats_dir)])
22    train_dog_filenames = tf.constant([train_dogs_dir + filename for filename in os.listdir(train_dogs_dir)])
23    train_filenames = tf.concat([train_cat_filenames, train_dog_filenames], axis=-1)
24    train_labels = tf.concat([
25        tf.zeros(train_cat_filenames.shape, dtype=tf.int32), 
26        tf.ones(train_dog_filenames.shape, dtype=tf.int32)], 
27        axis=-1)
28
29    train_dataset = tf.data.Dataset.from_tensor_slices((train_filenames, train_labels))
30    train_dataset = train_dataset.map(_decode_and_resize)
31    # 取出前buffer_size个数据放入buffer,并从其中随机采样,采样后的数据用后续数据替换
32    train_dataset = train_dataset.shuffle(buffer_size=23000)    
33    train_dataset = train_dataset.batch(batch_size)
34
35    model = tf.keras.Sequential([
36        tf.keras.layers.Conv2D(323, activation='relu', input_shape=(2562563)),
37        tf.keras.layers.MaxPooling2D(),
38        tf.keras.layers.Conv2D(325, activation='relu'),
39        tf.keras.layers.MaxPooling2D(),
40        tf.keras.layers.Flatten(),
41        tf.keras.layers.Dense(64, activation='relu'),
42        tf.keras.layers.Dense(2, activation='softmax')
43    ])
44
45    model.compile(
46        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
47        loss=tf.keras.losses.sparse_categorical_crossentropy,
48        metrics=[tf.keras.metrics.sparse_categorical_accuracy]
49    )
50
51    model.fit(train_dataset, epochs=num_epochs)

使用以下代码进行测试:

 1    # 构建测试数据集
2    test_cat_filenames = tf.constant([test_cats_dir + filename for filename in os.listdir(test_cats_dir)])
3    test_dog_filenames = tf.constant([test_dogs_dir + filename for filename in os.listdir(test_dogs_dir)])
4    test_filenames = tf.concat([test_cat_filenames, test_dog_filenames], axis=-1)
5    test_labels = tf.concat([
6        tf.zeros(test_cat_filenames.shape, dtype=tf.int32), 
7        tf.ones(test_dog_filenames.shape, dtype=tf.int32)], 
8        axis=-1)
9
10    test_dataset = tf.data.Dataset.from_tensor_slices((test_filenames, test_labels))
11    test_dataset = test_dataset.map(_decode_and_resize)
12    test_dataset = test_dataset.batch(batch_size)
13
14    print(model.metrics_names)
15    print(model.evaluate(test_dataset))



福利 | 问答环节

我们知道在入门一项新的技术时有许多挑战与困难需要克服。如果您有关于 TensorFlow 的相关问题,可在本文后留言,我们的工程师和 GDE 将挑选其中具有代表性的问题在下一期进行回答~


在上一篇文章《TensorFlow 2.0 常用模块2:TensorBoard》中,我们对于部分具有代表性的问题回答如下:


Q1:2.0 是不是没有 slim 了?感觉 slim 模块中自带的 demo 挺好用的,2.0 中是否相应替代品?

A:TensorFlow 2.0 对大量 API 进行了合并优化以确保 API 的简洁性,部分作用类似的高层 API(如tf.layers)均在 TensorFlow 2.0 中移除。建议使用 Keras 以替代 Slim。可参考:

  • https://tensorflow.google.cn/guide/migrate#a_note_on_slim_contriblayers


Q2 既然 2.0 鼓励使用 eager 模式,那么 estimator 处于什么地位?

A:estimator 依旧会被保留并维护,且我们提供这个函数 tf.keras.estimator.model_to_estimator 将 Keras 模型直接转换成 estimator。


Q3:在 1.x 版本里,我们可以将模型转为 pb 后用 c++ 调用实现部署。在 2.0 中如何实现?

A:2.0 可以使用 Keras 导出 SavedModel,freeze 模型后使用 C++ 部署。


Q4:get_variable() 是弃用了吗?tf2.0 用这个函数报错呢?注意到《简单粗暴 tf》中使用了的,本文中改用了 Variable。

A:TensorFlow 2.0 中不再使用 get_variable() 建立或获取变量,转而使用更为面向对象(也更加自然易懂)的 tf.Variable 来建立变量。如果需要使用 1.X 的 API,可使用 tf.compat.v1.get_variable()。


Q5:请教一个问题,如何用 KerasClassifier 来实现 fit_generator 方法,网上的教程都是用的 fit 方法,如果数据集很大的时候,fit 方法就不合适了。是不是用 KerasClassifier 搜索超参数的时候,不能用 fit_generator 方法?毕竟在实际的应用中,图像大小不可能是像 mnist 中 28*28 那么大。

A:KerasClassifier 是一个 wrapper,目的是为了让接口跟 scikit 近似,没有 fit_generator 函数。如果数据集很大的时候,建议直接使用 Keras 而不是这个 wrapper。



《简单粗暴 TensorFlow 2.0 》目录


公众号回复关键字“手册”获取内容合集及 FAQ

登录查看更多
0

相关内容

数据集,又称为资料集、数据集合或资料集合,是一种由数据所组成的集合。
Data set(或dataset)是一个数据的集合,通常以表格形式出现。每一列代表一个特定变量。每一行都对应于某一成员的数据集的问题。它列出的价值观为每一个变量,如身高和体重的一个物体或价值的随机数。每个数值被称为数据资料。对应于行数,该数据集的数据可能包括一个或多个成员。
Sklearn 与 TensorFlow 机器学习实用指南,385页pdf
专知会员服务
129+阅读 · 2020年3月15日
【模型泛化教程】标签平滑与Keras, TensorFlow,和深度学习
专知会员服务
20+阅读 · 2019年12月31日
【干货】谷歌Joshua Gordon 《TensorFlow 2.0讲解》,63页PPT
专知会员服务
27+阅读 · 2019年11月2日
社区分享 | Spark 玩转 TensorFlow 2.0
TensorFlow
15+阅读 · 2020年3月18日
tf.GradientTape 详解
TensorFlow
120+阅读 · 2020年2月21日
官方解读:TensorFlow 2.0 新的功能特性
云头条
3+阅读 · 2019年1月23日
TF Boys必看!一文搞懂TensorFlow 2.0新架构!
引力空间站
18+阅读 · 2019年1月16日
Tensorflow Eager Execution入门指南
专知
6+阅读 · 2018年4月16日
tensorflow系列笔记:流程,概念和代码解析
北京思腾合力科技有限公司
30+阅读 · 2017年11月11日
手把手教TensorFlow(附代码)
深度学习世界
15+阅读 · 2017年10月17日
Arxiv
8+阅读 · 2018年11月27日
Semantics of Data Mining Services in Cloud Computing
Arxiv
4+阅读 · 2018年10月5日
Bidirectional Attention for SQL Generation
Arxiv
4+阅读 · 2018年6月21日
Arxiv
3+阅读 · 2018年4月9日
Arxiv
3+阅读 · 2018年3月2日
Arxiv
4+阅读 · 2018年2月13日
Arxiv
6+阅读 · 2016年1月15日
VIP会员
相关资讯
社区分享 | Spark 玩转 TensorFlow 2.0
TensorFlow
15+阅读 · 2020年3月18日
tf.GradientTape 详解
TensorFlow
120+阅读 · 2020年2月21日
官方解读:TensorFlow 2.0 新的功能特性
云头条
3+阅读 · 2019年1月23日
TF Boys必看!一文搞懂TensorFlow 2.0新架构!
引力空间站
18+阅读 · 2019年1月16日
Tensorflow Eager Execution入门指南
专知
6+阅读 · 2018年4月16日
tensorflow系列笔记:流程,概念和代码解析
北京思腾合力科技有限公司
30+阅读 · 2017年11月11日
手把手教TensorFlow(附代码)
深度学习世界
15+阅读 · 2017年10月17日
相关论文
Arxiv
8+阅读 · 2018年11月27日
Semantics of Data Mining Services in Cloud Computing
Arxiv
4+阅读 · 2018年10月5日
Bidirectional Attention for SQL Generation
Arxiv
4+阅读 · 2018年6月21日
Arxiv
3+阅读 · 2018年4月9日
Arxiv
3+阅读 · 2018年3月2日
Arxiv
4+阅读 · 2018年2月13日
Arxiv
6+阅读 · 2016年1月15日
Top
微信扫码咨询专知VIP会员