博客 | CIFAR10 数据预处理

2018 年 10 月 12 日 AI研习社

本系列文章已由作者授权在AI研习社发布。

欢迎关注我的AI研习社博客:

http://www.gair.link/page/center/myPage/5104751

或订阅我的 CSDN:

https://blog.csdn.net/Kuo_Jun_Lin


  Brief 概述

在上一章中我们使用了 MNIST 手写数字数据集,套入一个非常简单的线性模型中,得到了大约 90% 左右的正确率,用意在于熟悉神经网络节点的架构和框架的使用方法,接下来这章将把前一章的数据集和方法全面提升一个档次,使用的是 CIFAR10 与 CNN 卷积神经网络的架构,同时也可以做为探讨深层神经网络如 VGG19,GoogleNet,与 ResNet 的敲门砖。

CNN 卷积神经网络假设大家已经有一个大致的了解,它不像线性回归的方法,从每个像素着手发现归类到不同标签的规则,而是使用卷积核逐步扫描整张图片的方式抽取出图像特征,经过逐个卷积和在逐层维度上特征的抽离处理,最终把他们与全连阶层相连,通往标签的归类,但是说着简单,其实操作上还有许多细节需要注意如下面几点:

  1. 借鉴上一次的代码运行过程,首先第一件事就是减少「类」中函数的冗长定义,因为每次呼叫类的方法时,其实类中的内容都会被重新刷新一遍,多次反复下来就是一个有负担的计算量。

  2. 图片不再是简单的手写数字,CIFAR10 有背景与对应标签的图案,因此为了更好的训练,图片需要做预处理,随机的:旋转角度,灰阶度,对比度,图像尺寸调整,明暗度,色调,与裁切,都是可以尝试的手法。

下面我们将探讨一个数据集在多个维度的比较,并尝试出最好正确率的排列组合。

p.s. 卷积神经网络搭建开始之前必须先确定自身电脑内存是否 >= 8G,虽然这个网络在 CNN 算法中非常简单,但如果从最开始的神经网络加总,一共也会有几十万个参数的量,需要注意电脑是否能够承载。

Code[1]

import sys
import tensorflow as tf
print(tf.__version__)

1.10.1


  CIFAR10 Dataset

它是一个内涵六万张图片的数据集,比起 MNIST,它的通道数是三个,用来表示其彩色的画面,并且图像尺寸是 32*32,其中分成训练集五万张与测试集一万张,制作人在打包数据的时候分了几个文档如下图:


其内部排列方式为一个大的字典,图片数据对应 'data' 字典键,标签数据对应 'labels' 字典键,而单张图片数据排布方式为一个一维列表:[...1024 red ... ...1024 green... ...1024 blue...],读取的方式可以直接点击 官网网址

为了使自己能够更加熟悉数据集内部结构的解析,同时 CIFAR10 官网只告知了打开它们数据集的方法,我们需要如使用 MNIST 的情况一样开始自己定义我们所需要的函数,不外乎数据读取,数据格式转换,one_hot 等大类,步骤如下:

1. Define functions without being iterated with class

定义的函数分别在如下陈列:

  • time_counter(): 是一个装饰器,功能是用来计时一个函数启动的时间

  • one_hot(): 用来把标签转换成 one hot 形式,方便后面神经网络归类匹配使用

  • get_random_batch(): 随机抽取样本做为一个簇后,方便小批量训练

Code[2]

    import time
    import numpy as np

    # To set a decorator used to count the time a func spent.
    def time_counter(func):
       # In order to count many func's time, arguments should be *args and **kwargs
       def wrapper(*args, **kwargs):
           t1 = time.time()
           result = func(*args, **kwargs)
           t2 = time.time() - t1
           print('Took {0:.4} sec to run "{1}" func'.format(t2, func.__name__))
           return result
       return wrapper

    # To convert the number labels into one hot mode respectively.
    def one_hot(labels, class_num=10):
       convert = np.eye(class_num, dtype=float)[labels]
       return convert

    # To get a random batch so that we can easily put data to train a model.
    def get_random_batch(data, batch_size=32):
       random = np.random.randint(0, len(data), size=batch_size)
       return data[random]


    2. Define a class used to well organized take apart the dataset

    由于数据是呈现 5 个批次储存,其中的函数设定我希望把他们融合成一块,后面处理和调用也表方便,并且其图片大小为 32x32 的尺寸,并不至于大到没办法一次容纳,因此设置的函数方法如下陈列:

    • load_binary_data(): 把二进制数据读取出来,并依照字典键的要求给出一个 numpy 数组的结果,方便后面数据处理

    • merge_batches(): 把全部批次的训练集数据全部融合起来成为一个大的数组

    • set_validation(): 设置一个验证集在训练集的比例,如果有不同的模型搭建可能会用到此功能

    • format_images(): 把一个 1D 向量表示的数据转换成卷积方法需要用到的 4D 格式(Batch, Height, Width, Channels)

    Code[3]

    # pickle is the module to open cifar10 dataset
    import pickle
    import os, sys

    # This class is used to refer the arranged content of CIFAR10 dataset
    class CIFAR10:
       # The unchangeable variables should be set here.
       image_size = 32
       image_channels = 3
       
       def __init__(self, val_ratio=0.1, data_dir='cifar-10-batches-py'):
           # Validation set can also be set if it is necessary for other purposes
           self.val_ratio = val_ratio
           self.data_dir = data_dir
           
           # Get the overall images data "without formatting"!
           self.img_train = self.merge_batches('data')
           self.img_train_main, self.img_train_val = self.set_validation(self.img_train)
           
           self.lab_train = self.merge_batches('labels')
           self.lab_train_main, self.lab_train_val = self.set_validation(self.lab_train)
           
           self.img_test = self.load_binary_data('test_batch', 'data') / 255.0
           self.lab_test = self.load_binary_data('test_batch', 'labels').astype(np.int)
           
       # The data format is binary mode and we should load them with pickle module
       # which is introduced at the official web page.
       def load_binary_data(self, file_name, dic_key):
           path = os.path.join(self.data_dir, file_name)
           with open(path, 'rb') as file:
               dic = pickle.load(file, encoding='bytes')
           
           # Those binary data are all contained by a dictionary also with
           # binary type of dictionary key. The returned list should also be
           # converted into np.array so that it can be indexed conveniently.
           try:
               dic_key = dic_key.encode(encoding='utf-8')
               return np.array(dic[dic_key])
           except:
               print('dic_key argument accepts only 4 keys as follow:\n',
                     '1.batch_label ; 2.labels ; 3.data ; 4.filenames')
           
       # There are five separated images dataset and we will want to
       # depose of them all at once.
       def merge_batches(self, dic_key):
           merge = []
           for i in range(5):
               filename = 'data_batch_{}'.format(i+1)
               data = self.load_binary_data(filename, dic_key)
               merge.append(data)
               np_merge = np.array(merge)
               
           if dic_key == 'data':
               length = self.image_size * self.image_size * self.image_channels
               np_merge = np_merge.reshape(5*len(data), length)
               return np.array(np_merge) / 255.0
           else:
               np_merge = np_merge.reshape(5*len(data))
               return np.array(np_merge).astype(np.int)
               
       def set_validation(self, data):
           val_set = round(len(data) * self.val_ratio)
           
           val_data = data[0:val_set]
           main_data = data[val_set:]
           return [main_data, val_data]
               
       # The 1D array representing an image should be converted to the format
       # that is as same as the regular image format (H, W, C)
       def format_images(self, images_flat):
           # The format of original data has (10000, 3072) shape matrix
           # with conjoint red 1024, green 1024, blue 1024.
           images = images_flat.reshape([-1, self.image_channels,
                                         self.image_size, self.image_size])
           # when depositing images, channels should stay at the last dimension.
           images = images.transpose([0, 2, 3, 1])
           return images
       
       @property
       def get_class_names(self):
           path = os.path.join(self.data_dir, 'batches.meta')
           
           with open(path, 'rb') as file:
               dic = pickle.load(file, encoding='bytes')
           class_names = [w.decode('utf-8') for w in dic[b'label_names']]
           
           for num, label in enumerate(class_names):
               print('{}: {}'.format(num, label))
           return class_names
       
       @property
       def num_per_batch(self):
           path = os.path.join(self.data_dir, 'batches.meta')
           with open(path, 'rb') as file:
               dic = pickle.load(file, encoding='bytes')
           return dic[b'num_cases_per_batch']
           
    path = input('The directory of CIFAR10 dataset: ')
    cifar = CIFAR10(data_dir=path)
    cifar.get_class_names
    print("Number per batch: {}".format(cifar.num_per_batch))

    The directory of CIFAR10 dataset: /Users/kcl/Documents/Python_Projects/cifar-10-batches-py
    0: airplane
    1: automobile
    2: bird
    3: cat
    4: deer
    5: dog
    6: frog
    7: horse
    8: ship
    9: truck
    Number per batch: 10000

    3. Print Images and Labels respectively

    为了验证导入的数据集是否与标签匹配,避免在模型训练前数据集基础就已经歪得一塌糊涂,结合了上面定义的 .format_images() 方法与 get_random_batch() 函数套入以下定义的绘图函数中,随机抽样查看数据匹配的完整性,代码如下:

    Code[4]

    import matplotlib.pyplot as plt

    images_flat_train = cifar.img_train
    images_train = cifar.format_images(images_flat_train)
    labels_train = cifar.lab_train

    images_flat_test = cifar.img_test
    images_test = cifar.format_images(images_flat_test)
    labels_test = cifar.lab_test

    # To define a universal purpose oriented plotting function here.
    # It should not only be able to plot correct images, but also is
    # capable of plotting the predicted labels.
    def plot_images(images, labels, lab_names, size=[3, 3],
                   pred_labels=False, random=True, smooth=True)
    :

       fig, axes = plt.subplots(size[0], size[1])
       fig.subplots_adjust(hspace=0.6, wspace=0.6)

       for n, ax in enumerate(axes.flat):
           # To decide if the printed images should be smooth or not.
           if smooth:
               interpolation = 'spline16'
           else:
               interpolation = 'nearest'
             
           # To decide if the images should be randomly picked up.
           if random:
               i = np.random.randint(0, len(labels), size=None, dtype=np.int)
           else:
               i = n
               
           ax.imshow(images[i], interpolation=interpolation)
           
           if pred_labels is False:
               xlabel = 'T: {}'.format(lab_names[labels[i]])
           else:
               xlabel = 'T: {0}\nP:{1}'.format(lab_names[labels[i]],
                                               lab_names[pred_labels[i]])
           ax.set_xlabel(xlabel)
           
           ax.set_xticks([])
           ax.set_yticks([])
       plt.show()
       
    plot_images(images_train, labels_train, cifar.get_class_names, size=[3, 5])

    0: airplane
    1: automobile
    2: bird
    3: cat
    4: deer
    5: dog
    6: frog
    7: horse
    8: ship
    9: truck



      Data Preprocessing 数据预处理

    如同概述部分提及的图像预处理步骤,接下来要使用下面 Tensorflow 所提供的方法来实现图像的随机改动:

    • tf.random_crop() 裁切

    • tf.image.random_flip_left_right() 左右镜像翻转

    • tf.image.random_flip_up_down 上下镜像翻转

    • tf.image.random_contrast() 对比度

    • tf.image.random_hue() 色调

    • tf.image.random_brightness() 明暗度

    • tf.image.random_saturation() 饱和度

    • tf.image.per_image_whitening() 图像数据标准化 ????

    • tf.image.resize_image_with_crop_or_pad() 重新定义图像尺寸(多了切掉少了补 0 )

    p.s. 还有很多 Tensorflow 框架支持的图像处理方法,点击此 查看官网

    对输入数据使用上面函数方法做改动就如同给数据集加了几个维度的数据,而丰富的数据集正是神经网络能够达到更高归类准确率的基本要素,同时还可减少过拟合的结果发生,换个角度思考这些产生的数据,它们就如同数据的噪声,为过拟合可能发生的情况提供了一道保险。


    不过当使用此方法在训练的时候,产生数据的过程会添加计算的负担,进而造成时间上的消耗,是我们应用此方法的时候一个重要的考虑要点。

    结合上述方法定义的函数代码如下:

    Code[5]

    def image_preprocessing(single_img, crop=[28, 28], crop_only=False):
       H, W = cifar.image_size, cifar.image_size
       height, width = crop
       
       single_img = tf.random_crop(single_img, size=[height, width, 3])
       single_img = tf.image.random_flip_left_right(single_img)
       single_img = tf.image.random_flip_up_down(single_img)
       single_img = tf.image.random_contrast(single_img, lower=0.5, upper=1.0)
       single_img = tf.image.random_hue(single_img, max_delta=0.03)
       single_img = tf.image.random_brightness(single_img, max_delta=0.2)
       single_img = tf.image.random_saturation(single_img, lower=0.5, upper=1.5)
       single_img = tf.minimum(single_img, 1.0)
       single_img = tf.maximum(single_img, 0.0)
       
       single_img = tf.image.resize_image_with_crop_or_pad(
           single_img, target_height=H, target_width=W)
       return single_img

    此函数的逻辑为下面陈列的几点说明:

    1. 调整我们要随机位置裁切的尺寸大小后

    2. 对裁切下来的图像开始随意颠倒,变化色调等等

    3. 把超出 RGB 三个单元最大值和最小值的部分抹平

    4. 把缩小尺寸的裁切团重新 padding 回到原本未裁切的大小,目的是使用数据流图时测试机不需要预处理图像就能够测试,此一做法更为合理


    上面定义的函数必须强调的是,它只处理 "单一张" 图片,如果关联到批量处理,例如我们习惯于把一整批图像数据用 4D 张量的方式表示,格式分别为 (张数,图高,图宽,颜色阶数),则可以使用 tf.map_fn 配合 lambda 的方式一次随机处理整批图像数据,并且每张图像数据的调整系数本身都不尽相同,最后面即为详细的搭配使用代码与说明。

    A glimpse to the Preprocessed Images

    为了确定我们处理的数据完整性与效果,下面尝试使用我们定义好的函数来随机打印预处理图片集的结果,步骤如下:

    1. 导入数据集,并使用定义的类方法呼叫训练图像

    2. 使用 Tensorflow 框架的构建方法,把导入的数据集放入我们预先定义好的函数中

    3. 启动 tf 会话 .Session() 功能

    4. sess.run() 了上个函数的运算结果后,才把这里的运算结果放入绘图函数中

    5. 等待时间约为一分半钟,预处理好后即自行打印

    Code[6]

      import tensorflow as tf

      lab_train = cifar.lab_train
      format_imgs = cifar.format_images(cifar.img_train)

      # We can put every single element of a list into the argument which
      # is belonging to tf.map_fn()'s fn by using lambda expression so it can
      # iterate all elements to the preset function "image_preprocessing".
      format_imgs = tf.map_fn(lambda img: image_preprocessing(img, crop=[24, 24]), format_imgs)

      sess = tf.Session()
      format_imgs = sess.run(format_imgs)
      plot_images(format_imgs, lab_train, cifar.get_class_names, size=[3, 4])

      0: airplane
      1: automobile
      2: bird
      3: cat
      4: deer
      5: dog
      6: frog
      7: horse
      8: ship
      9: truck


        文章回顾

      01.博客 | MNIST 数据集载入线性模型

      登录查看更多
      11

      相关内容

      干净的数据:数据清洗入门与实践,204页pdf
      专知会员服务
      161+阅读 · 2020年5月14日
      Sklearn 与 TensorFlow 机器学习实用指南,385页pdf
      专知会员服务
      129+阅读 · 2020年3月15日
      【Google AI】开源NoisyStudent:自监督图像分类
      专知会员服务
      54+阅读 · 2020年2月18日
      一网打尽!100+深度学习模型TensorFlow与Pytorch代码实现集合
      【模型泛化教程】标签平滑与Keras, TensorFlow,和深度学习
      专知会员服务
      20+阅读 · 2019年12月31日
      【机器学习课程】Google机器学习速成课程
      专知会员服务
      164+阅读 · 2019年12月2日
      误差反向传播——CNN
      统计学习与视觉计算组
      30+阅读 · 2018年7月12日
      入门 | 从VGG到NASNet,一文概览图像分类网络
      机器之心
      6+阅读 · 2018年4月2日
      keras系列︱深度学习五款常用的已训练模型
      数据挖掘入门与实战
      10+阅读 · 2018年3月27日
      一个小例子带你轻松Keras图像分类入门
      云栖社区
      4+阅读 · 2018年1月24日
      TensorFlow实现神经网络入门篇
      AI研习社
      11+阅读 · 2017年12月11日
      TensorFlow实例: 手写汉字识别
      机器学习研究会
      8+阅读 · 2017年11月10日
      实战|利用卷积神经网络处理CIFAR图像分类
      全球人工智能
      4+阅读 · 2017年7月22日
      Learning in the Frequency Domain
      Arxiv
      11+阅读 · 2020年3月12日
      Arxiv
      6+阅读 · 2018年6月20日
      Arxiv
      3+阅读 · 2018年3月2日
      VIP会员
      相关VIP内容
      相关资讯
      误差反向传播——CNN
      统计学习与视觉计算组
      30+阅读 · 2018年7月12日
      入门 | 从VGG到NASNet,一文概览图像分类网络
      机器之心
      6+阅读 · 2018年4月2日
      keras系列︱深度学习五款常用的已训练模型
      数据挖掘入门与实战
      10+阅读 · 2018年3月27日
      一个小例子带你轻松Keras图像分类入门
      云栖社区
      4+阅读 · 2018年1月24日
      TensorFlow实现神经网络入门篇
      AI研习社
      11+阅读 · 2017年12月11日
      TensorFlow实例: 手写汉字识别
      机器学习研究会
      8+阅读 · 2017年11月10日
      实战|利用卷积神经网络处理CIFAR图像分类
      全球人工智能
      4+阅读 · 2017年7月22日
      Top
      微信扫码咨询专知VIP会员