文 / 李锡涵,Google Developers Expert
本文节选自《简单粗暴 TensorFlow 2.0》,回复 “手册” 获取合集
在上一篇文章里,我们介绍了在 TensorFlow 2.0 中进行多 GPU/多机分布式训练的方式。本篇文章将介绍快速载入数据集的利器 —— TensorFlow Datasets。
TensorFlow Datasets 数据集载入
TensorFlow Datasets 是一个开箱即用的数据集集合,包含数十种常用的机器学习数据集。通过简单的几行代码即可将数据以 tf.data.Dataset
的格式载入。关于 tf.data.Dataset
的使用可参考 tf.data。
TensorFlow Datasets
https://tensorflow.google.cn/datasets/
该工具是一个独立的 Python 包,可以通过:
pip install tensorflow-datasets
安装。
在使用时,首先使用 import 导入该包:
import tensorflow as tf
import tensorflow_datasets as tfds
然后,最基础的用法是使用 tfds.load
方法,载入所需的数据集。例如,以下三行代码分别载入了 MNIST、猫狗分类和 tf_flowers
三个图像分类数据集:
dataset = tfds.load("mnist", split=tfds.Split.TRAIN)
dataset = tfds.load("cats_vs_dogs", split=tfds.Split.TRAIN, as_supervised=True)
dataset = tfds.load("tf_flowers", split=tfds.Split.TRAIN, as_supervised=True)
WARNING:absl:Dataset mnist is hosted on GCS. It will automatically be downloaded to your local data directory. If you'd instead prefer to read directly from our public GCS bucket (recommended if you're running on GCP), you can instead set data_dir=gs://tfds-data/datasets.
Dl Completed...: 100%|██████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:10<00:00, 2.93s/ file]
Dl Completed...: 100%|██████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:10<00:00, 2.73s/ file]
Dataset mnist downloaded and prepared to C:\Users\snowkylin\tensorflow_datasets\mnist\3.0.0. Subsequent calls will reuse this data.
tfds.load
方法返回一个
tf.data.Dataset
对象。部分重要的参数如下:
as_supervised
:若为 True,则根据数据集的特性返回为
(input, label)
格式,否则返回所有特征组成的字典。
split
:指定返回数据集的特定部分,若不指定,则返回整个数据集。一般有 tfds.Split.TRAIN
(训练集)和 tfds.Split.TEST
(测试集)选项。
当前支持的数据集可在 官方文档 查看,或使用 tfds.list_builders()
查看。
官方文档
https://tensorflow.google.cn/datasets/datasets
tf.data.Dataset
类型的数据集后,我们即可使用
tf.data
对数据集进行各种预处理以及读取数据。例如:
# 使用 TessorFlow Datasets 载入“tf_flowers”数据集
dataset = tfds.load("tf_flowers", split=tfds.Split.TRAIN, as_supervised=True)
# 对 dataset 进行大小调整、打散和分批次操作
dataset = dataset.map(lambda img, label: (tf.image.resize(img, [224, 224]) / 255.0, label)) \
.shuffle(1024) \
.batch(32)
# 迭代数据
for images, labels in dataset:
# 对images和labels进行操作
提示
在使用 TensorFlow Datasets 时,可能需要设置代理。较为简易的方式是设置TFDS_HTTPS_PROXY
环境变量,即export TFDS_HTTPS_PROXY=http://代理服务器IP:端口
《简单粗暴 TensorFlow 2.0 》目录
TensorFlow 2.0 Datasets 数据集载入(本文)
有趣的人都在看