用一行tf.data实现数据Shuffle、Batch划分、异步预加载等


【导读】在早期的TensorFlow中,大规模训练数据的Shuffle、Batch划分和异步预加载等一直是开发的难点。tf.data.Dataset的出现统一了数据读取的规范,并提供了便利的接口。本文介绍动态图模式中Dataset的用法。


数据自动Shuffle和Batch自动划分


下面的代码直接以包含所有数据的numpy.ndarray作为数据,用tf.data.Dataset.from_tensor_slices(data).batch(50).shuffle(1000)即可自动实现对ndarray的随机打乱、划分等。

# coding=utf-8
import numpy as np
import tensorflow as tf
tf.enable_eager_execution()

# 随机生成1000个样本,每个包含20个特征
data = np.random.randn(10000, 20)

# data转换为输入
# 按照batch_size50划分batch
#
随机打乱数据,缓存大小为1000
for batch_data in tf.data.Dataset.from_tensor_slices(data).batch(50).shuffle(1000):
print(batch_data.shape)


运行结果:

(50, 20)
(50, 20)
(50, 20)
(50, 20)
......


自定义数据处理操作和异步预加载


上面的例子中,所有需要的数据都直接保存在了ndarray中,而在一些任务中,可能无法事先将所有的数据一次性读到ndarray中。例如如果数据集包含几十万张图片,一次性加载到内存中可能会导致内存消耗过大。这时就需要:

  • 事先只加载图片路径列表,而非真实的图片数据

  • 用tf.data.Dataset基于路径列表进行Shuffle和Batch划分等操作

  • 在tf.data.Dataset中加入自定义操作,将batch中的路径转换为对应的图片数据(通过实时读取)

  • 由于从磁盘读取图片效率较低,需要在训练过程中用后台线程预先读取后续batch包含的图片,这个过程叫prefetch


在下面的示例中,我们从一个叫data的文件夹中读取图片信息,将每张图片转换为像素Tensor,并进行缩放和归一化操作。其中还包含了Shuffle、Batch划分和异步预抓取:


# coding=utf-8
import os
import tensorflow as tf
from tensorflow.python.data.experimental import AUTOTUNE

tf.enable_eager_execution()

# 图片文件夹,包含多张图片
image_dir = "data"
# 获取每个图片的绝对路径
paths = [os.path.join(image_dir, fname) for fname in os.listdir(image_dir)]


# 定义一个读取图片的方法
def read_image(image_path):
# 读取图片内容
content = tf.io.read_file(image_path)
# 读取图片像素
image_tensor = tf.image.decode_jpeg(content)
# 图片缩放
image_tensor = tf.image.resize(image_tensor, [20, 20])
# 归一化
image_tensor = tf.cast(image_tensor, tf.float32) / 255.0
return image_tensor


# 基于文件路径进行ShuffleBatch划分
# 按照batch_size50划分batch
#
随机打乱数据,缓存大小为1000
#
batch中的每个数据执行read_image操作(map映射)
# 即根据图片路径读取图片内容
# 使用prefetch进行异步预抓取
for batch_data in tf.data.Dataset.from_tensor_slices(paths)\
.map(read_image, num_parallel_calls=AUTOTUNE)\
.batch(50)\
.shuffle(1000)\
.prefetch(buffer_size=AUTOTUNE):
print(batch_data.shape)


运行结果:

(50, 20, 20, 3)
(50, 20, 20, 3)
(50, 20, 20, 3)


更多关于tf.data.Dataset的用法可以参考官方教程:

  • https://www.tensorflow.org/alpha/tutorials/load_data/images


-END-

专 · 知

专知《深度学习:算法到实战》课程全部完成!510+位同学在学习,现在报名,限时优惠!网易云课堂人工智能畅销榜首位!

欢迎微信扫一扫加入专知人工智能知识星球群,获取最新AI专业干货知识教程视频资料和与专家交流咨询!

请加专知小助手微信(扫一扫如下二维码添加),加入专知人工智能主题群,咨询《深度学习:算法到实战》课程,咨询技术商务合作~

请PC登录www.zhuanzhi.ai或者点击阅读原文,注册登录专知,获取更多AI知识资料!

点击“阅读原文”,了解报名专知《深度学习:算法到实战》课程

展开全文
Top
微信扫码咨询专知VIP会员