一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系

2022 年 7 月 20 日 极市平台
↑ 点击 蓝字  关注极市平台

作者丨marsggbo@知乎 (已授权)
来源丨https://zhuanlan.zhihu.com/p/76893455
编辑丨极市平台

极市导读

 

很多文章都是从Dataset等对象自下往上进行介绍,但是对于初学者而言,其实这并不好理解,因为有的时候会不自觉地陷入到一些细枝末节中去,而不能把握重点,所以本文将会自上而下地对Pytorch数据读取方法进行介绍。 >>加入极市CV技术交流群,走在计算机视觉的最前沿

以下内容都是针对Pytorch 1.0-1.1介绍。

自上而下理解三者关系

首先我们看一下DataLoader.__next__的源代码长什么样(代码链接:https://github.com/pytorch/pytorch/blob/0b868b19063645afed59d6d49aff1e43d1665b88/torch/utils/data/dataloader.py#L557-L563)。

为方便理解我只选取了num_works为0的情况(num_works简单理解就是能够并行化地读取数据)。

class DataLoader(object):
    ...
    
    def __next__(self):
        if self.num_workers == 0:  
            indices = next(self.sample_iter)  # Sampler
            batch = self.collate_fn([self.dataset[i] for i in indices]) # Dataset
            if self.pin_memory:
                batch = _utils.pin_memory.pin_memory_batch(batch)
            return batch

在阅读上面代码前,我们可以假设我们的数据是一组图像,每一张图像对应一个index,那么如果我们要读取数据就只需要对应的index即可,即上面代码中的indices,而选取index的方式有多种,有按顺序的,也有乱序的,所以这个工作需要Sampler完成,现在你不需要具体的细节,后面会介绍,你只需要知道DataLoader和Sampler在这里产生关系。

那么Dataset和DataLoader在什么时候产生关系呢?没错就是下面一行。我们已经拿到了indices,那么下一步我们只需要根据index对数据进行读取即可了。

再下面的if语句的作用简单理解就是,如果pin_memory=True,那么Pytorch会采取一系列操作把数据拷贝到GPU,总之就是为了加速。

综上可以知道DataLoader,Sampler和Dataset三者关系如下:



在阅读后文的过程中,你始终需要将上面的关系记在心里,这样能帮助你更好地理解。

Sampler

参数传递

要更加细致地理解Sampler原理,我们需要先阅读一下DataLoader 的源代码,如下:

class DataLoader(object):
    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0, collate_fn=default_collate,
                 pin_memory=False, drop_last=False, timeout=0,
                 worker_init_fn=None)

可以看到初始化参数里有两种sampler:sampler和batch_sampler,都默认为None。前者的作用是生成一系列的index,而batch_sampler则是将sampler生成的indices打包分组,得到一个又一个batch的index。例如下面示例中,BatchSampler将SequentialSampler生成的index按照指定的batch size分组。

>>>in : list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
>>>out: [[012], [345], [678], [9]]

Pytorch中已经实现的Sampler有如下几种:

  • SequentialSampler
  • RandomSampler
  • WeightedSampler
  • SubsetRandomSampler

需要注意的是DataLoader的部分初始化参数之间存在互斥关系,这个你可以通过阅读源码更深地理解(https://github.com/pytorch/pytorch/blob/0b868b19063645afed59d6d49aff1e43d1665b88/torch/utils/data/dataloader.py#L157-L182)。这里只做总结:

  • 如果你自定义了batch_sampler,那么这些参数都必须使用默认值:batch_size, shuffle,sampler,drop_last.

  • 如果你自定义了sampler,那么shuffle需要设置为False

  • 如果sampler和batch_sampler都为None,那么batch_sampler使用Pytorch已经实现好的BatchSampler,而sampler分两种情况:

    • 若shuffle=True,则sampler=RandomSampler(dataset)
    • 若shuffle=False,则sampler=SequentialSampler(dataset)

如何自定义Sampler和BatchSampler?

仔细查看源代码其实可以发现,所有采样器其实都继承自同一个父类,即Sampler,其代码定义如下:
  
  
    
class Sampler(object):
     r"""Base class for all Samplers.
    Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
    way to iterate over indices of dataset elements, and a :meth:`__len__` method
    that returns the length of the returned iterators.
    .. note:: The :meth:`__len__` method isn't strictly required by
              :class:`~torch.utils.data.DataLoader`, but is expected in any
              calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
    """


     def __init__(self, data_source):
         pass

     def __iter__(self):
         raise NotImplementedError
        
     def __len__(self):
         return len(self.data_source)
所以你要做的就是定义好__iter__(self)函数,不过要注意的是该函数的返回值需要是可迭代的。例如SequentialSampler返回的是iter(range(len(self.data_source)))。
另外BatchSampler与其他Sampler的主要区别是它需要将Sampler作为参数进行打包,进而每次迭代返回以batch size为大小的index列表。也就是说在后面的读取数据过程中使用的都是batch sampler。

Dataset

Dataset定义方式如下:
  
  
    
class Dataset(object):
     def __init__(self):
        ...
        
     def __getitem__(self, index):
         return ...
    
     def __len__(self):
         return ...
上面三个方法是最基本的,其中__getitem__是最主要的方法,它规定了如何读取数据。但是它又不同于一般的方法,因为它是python built-in方法,其主要作用是能让该类可以像list一样通过索引值对数据进行访问。假如你定义好了一个dataset,那么你可以直接通过dataset[0]来访问第一个数据。在此之前我一直没弄清楚__getitem__是什么作用,所以一直不知道该怎么进入到这个函数进行调试。
现在如果你想对__getitem__方法进行调试,你可以写一个for循环遍历dataset来进行调试了,而不用构建dataloader等一大堆东西了,建议学会使用ipdb这个库,非常实用!!!以后有时间再写一篇ipdb的使用教程。另外,其实我们通过最前面的Dataloader的__next__函数可以看到DataLoader对数据的读取其实就是用了for循环来遍历数据,不用往上翻了,我直接复制了一遍,如下:
  
  
    
class DataLoader(object): 
    ... 
     
     def __next__(self): 
         if self.num_workers ==  0:   
            indices = next(self.sample_iter)  
            batch = self.collate_fn([self.dataset[i]  for i  in indices])  # this line 
             if self.pin_memory: 
                batch = _utils.pin_memory.pin_memory_batch(batch) 
             return batch
我们仔细看可以发现,前面还有一个self.collate_fn方法,这个是干嘛用的呢?在介绍前我们需要知道每个参数的意义:
  • indices: 表示每一个iteration,sampler返回的indices,即一个batch size大小的索引列表
  • self.dataset[i]: 前面已经介绍了,这里就是对第i个数据进行读取操作,一般来说self.dataset[i]=(img, label)
看到这不难猜出collate_fn的作用就是将一个batch的数据进行合并操作。默认的collate_fn是将img和label分别合并成imgs和labels,所以如果你的__getitem__方法只是返回 img, label,那么你可以使用默认的collate_fn方法,但是如果你每次读取的数据有img, box, label等等,那么你就需要自定义collate_fn来将对应的数据合并成一个batch数据,这样方便后续的训练步骤。

公众号后台回复“ECCV2022”获取论文分类资源下载~

△点击卡片关注极市平台,获取 最新CV干货

极市干货
算法项目: CV工业项目落地实战大航海 | 附极力值攻略 极市打榜|目标检测算法上新!(年均分成5万)
实操教程 Pytorch - 弹性训练原理分析《CUDA C 编程指南》导读
极视角动态: 极视角作为重点项目入选「2022青岛十大资本青睐企业」榜单! 极视角发布EQP激励计划,招募优质算法团队展开多维度生态合作!

点击阅读原文进入CV社区

收获更多技术干货


登录查看更多
1

相关内容

Python图像处理,366页pdf,Image Operators Image Processing in Python
【干货书】流畅Python,766页pdf,中英文版
专知会员服务
224+阅读 · 2020年3月22日
TensorFlow Lite指南实战《TensorFlow Lite A primer》,附48页PPT
专知会员服务
69+阅读 · 2020年1月17日
解决PyTorch半精度(AMP)训练nan问题
CVer
3+阅读 · 2022年1月4日
实践教程 | 浅谈 PyTorch 中的 tensor 及使用
极市平台
1+阅读 · 2021年12月14日
正则化技巧:标签平滑以及在 PyTorch 中的实现
极市平台
2+阅读 · 2021年12月10日
PyTorch 深度剖析:如何保存和加载PyTorch模型?
极市平台
0+阅读 · 2021年11月28日
图神经网络库PyTorch geometric
图与推荐
17+阅读 · 2020年3月22日
PyTorch模型训练特征图可视化(TensorboardX)
极市平台
33+阅读 · 2019年6月29日
如何给你PyTorch里的Dataloader打鸡血
极市平台
15+阅读 · 2019年5月21日
PyTorch 学习笔记(一):让PyTorch读取你的数据集
极市平台
16+阅读 · 2019年4月24日
国家自然科学基金
0+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2013年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2011年12月31日
国家自然科学基金
0+阅读 · 2009年12月31日
Arxiv
0+阅读 · 2022年9月17日
已删除
Arxiv
32+阅读 · 2020年3月23日
VIP会员
相关资讯
解决PyTorch半精度(AMP)训练nan问题
CVer
3+阅读 · 2022年1月4日
实践教程 | 浅谈 PyTorch 中的 tensor 及使用
极市平台
1+阅读 · 2021年12月14日
正则化技巧:标签平滑以及在 PyTorch 中的实现
极市平台
2+阅读 · 2021年12月10日
PyTorch 深度剖析:如何保存和加载PyTorch模型?
极市平台
0+阅读 · 2021年11月28日
图神经网络库PyTorch geometric
图与推荐
17+阅读 · 2020年3月22日
PyTorch模型训练特征图可视化(TensorboardX)
极市平台
33+阅读 · 2019年6月29日
如何给你PyTorch里的Dataloader打鸡血
极市平台
15+阅读 · 2019年5月21日
PyTorch 学习笔记(一):让PyTorch读取你的数据集
极市平台
16+阅读 · 2019年4月24日
相关基金
国家自然科学基金
0+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2013年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2011年12月31日
国家自然科学基金
0+阅读 · 2009年12月31日
Top
微信扫码咨询专知VIP会员