Pytorch 数据流中常见Trick总结

2021 年 12 月 7 日 极市平台
↑ 点击 蓝字  关注极市平台

作者丨zlhroughlove@知乎(已授权)
来源丨https://zhuanlan.zhihu.com/p/441317369
编辑丨极市平台

极市导读

 

本文对使用Pytorch建模过程中的一些思考以及遇到的问题做了一个总结,希望能给各位读者一个比较通用的模板~ >>加入极市CV技术交流群,走在计算机视觉的最前沿

前言

在使用Pytorch建模时,常见的流程为先写Model,再写Dataset,最后写Trainer。Dataset 是整个项目开发中投入时间第二多,也是中间关键的步骤。往往需要事先对于其设计有明确的思考,不然可能会因为Dataset的一些问题又要去调整Model,Trainer。本文将目前开发中的一些思考以及遇到的问题做一个总结,提供给各位读者一个比较通用的模版,抛砖引玉~

一、Dataset的定义

from torch.utils.data import Dataset, DataLoader, RandomSampler

对于不同类型的建模任务,模型的输入各不相同。自然语言,多模态,点击率预估,往往这些场景输入模型的数据并不是来自于单一文件,而且可能无法全部存入内存。Dataset需要整合项目的数据,对于单条样本涉及到的数据做一个提取与归纳。不但如此,项目可能还涉及到多种模型,任务的训练。Dataset需要为不同的模型以及训练任务提供不同的单条样本输入,作为一个数据生成器,把后续模型训练任务需要的所有基础数据,标签全返回了。所以往往我们可以定义一个BaseDataset类,继承torch.utils.data.Dataset,这个类可以初始化一些文件路径,配置等。后面不同的模型训练任务定义相应的Dataset类继承BaseDataset。

Dataset通用的结构为:

class BaseDataset(Dataset):

    def __init__(self, config):
        self.config = config
        if os.path.isfile(config.file_path) is False:
            raise ValueError(f"Input file path {config.file_path} not found")
        logger.info(f"Creating features from dataset file at {config.file_path}")
        # 一次性全读进内存
        self.data = joblib.load(config.file_path)
        self.nums = len(self.data)

    def __len__(self):
        return self.nums

    def __getitem__(self, i) -> Dict[str, tensor]:
        sample_i = self.data[i]
        return {"f1":torch.tensor(sample_i["f1"]).long(),"f2":torch.tensor(sample_i["f2"]).long(),torch.LongTensor([sample_i["label"]])}

如果无法全部读取进内存需要再__getitem__方法内构建数据,做自然语言则可以吧tokenizer初始化到该类中,在__getitem__方法内完成tokenizer。改方法的输出推荐做成字典形式。

对于不同的训练任务可以通过以下方法返回响应的数据生成器

def build_dataset(task_type, features, **kwargs):
    assert task_type in ['task1''task2'], 'task mismatch'

    if task_type == 'task1':
        dataset = task1Dataset(features))
    else:
        dataset = task2Dataset(features)

    return dataset

有时模型的训练任务需要做数据增强,对比学习,构造多种的预训练任务输入。Dataset的职能边界是提供一套基础的单样本数据输入生成器。如果是MLM任务,可以在Dataset内生成maskposition以及label。如果是在batch内的对比学习则应该在DataLoader生产batch数据后再进行。

二、DataLoader的定义

DataLoader的作用是对Dataset进行多进程高效地构建每个训练批次的数据。传入的数据可以认为是长度为batch大小的多个__getitem__ 方法返回的字典list。DataLoader的职能边界是根据Dataset提供的单条样本数据有选择的构建一个batch的模型输入数据。

其通常的结构为对Train,Valid,Test分别建立:

train_sampler = RandomSampler(train_dataset)
train_loader = DataLoader(dataset=train_dataset,
                              batch_size=args.train_batch_size,
                              sampler=train_sampler,
                              shuffle=(train_sampler is None)
                              collate_fn=None, # 一般不用设置
                              num_workers=4)

首先对于sampler 还有一种定义方式:

sampler = torch.utils.data.distributed.DistributedSampler(dataset)

至于batch内数据是否需要做shuffle也需要根据损失函数确定(对比学习慎用)

DataLoader会自动合并__getitem__ 方法返回的字典内每个key内每个tensor,在tensor的第0维度新增一个batch大小的维度。如果该方法返回的每条样本长度不同无法拼接,batchsize>1就会报错。但是又一些任务在还没有确定后续的批样本对应的任务时,Dataset可能返回的字典里每个key可能就是长度不同的tensor,甚至是list,这时候需要使用collate_fn参数告诉DataLoader如何取样。我们可以定义自己的函数来准确地实现想要的功能。

如果__getitem__方法返回的是tuple((list, list)) 可以使用:

def merge_sample(x):
    return zip(*x)

train_loader = DataLoader(dataset=train_dataset,
                              batch_size=args.train_batch_size,
                              sampler=train_sampler,
                              shuffle=(train_sampler is None)
                              collate_fn=merge_sample,
                              num_workers=4)

拼接数据,后续再做进一步处理。(此时list内数据还是不等长,无法转为tensor)

如果__getitem_方法返回的是Dict[str,tensor],自定义的collate_fn方法内需要实现:List[Dict[str,tensor(xx)]]->Dict[str,tensor(bs,xx)]的操作,pad_sequence过程也可以在自定义方法内实现。(总之collate_fn中不但可以处理不等长数据,还可以对一个batch的数据做精修。当然也可以在DataLoader之后再做修改batch内的数据。)

值得注意的是在cpu环境下,如果要自定义collate_fn,num_workers必须设置为0,不然就会有问题..

通过以下方式可以检查一下输入后续模型的数据是否已经是想要的格式

for step, batch_data in enumerate(train_loader):
    if step < 1:
        print(batch_data)
    else:
        break

之后数据将数据放入gpu device, 一个batch的数据进入device端后就与内存上的数据不再互相干扰。之后数据就可以喂给模型了:

for key in batch_data.keys():
    batch_data[key] = batch_data[key].to(device)
loss = model(**batch_data)

如果觉得有用,就请分享到朋友圈吧!

△点击卡片关注极市平台,获取 最新CV干货

公众号后台回复“transformer”获取最新Transformer综述论文下载~


极市干货
课程/比赛: 珠港澳人工智能算法大赛 保姆级零基础人工智能教程
算法trick 目标检测比赛中的tricks集锦 从39个kaggle竞赛中总结出来的图像分割的Tips和Tricks
技术综述: 一文弄懂各种loss function 工业图像异常检测最新研究总结(2019-2020)


CV技术社群邀请函 #

△长按添加极市小助手
添加极市小助手微信(ID : cvmart4)

备注:姓名-学校/公司-研究方向-城市(如:小极-北大-目标检测-深圳)


即可申请加入极市目标检测/图像分割/工业检测/人脸/医学影像/3D/SLAM/自动驾驶/超分辨率/姿态估计/ReID/GAN/图像增强/OCR/视频理解等技术交流群


每月大咖直播分享、真实项目需求对接、求职内推、算法竞赛、干货资讯汇总、与 10000+来自港科大、北大、清华、中科院、CMU、腾讯、百度等名校名企视觉开发者互动交流~



觉得有用麻烦给个在看啦~   
登录查看更多
0

相关内容

数据集,又称为资料集、数据集合或资料集合,是一种由数据所组成的集合。
Data set(或dataset)是一个数据的集合,通常以表格形式出现。每一列代表一个特定变量。每一行都对应于某一成员的数据集的问题。它列出的价值观为每一个变量,如身高和体重的一个物体或价值的随机数。每个数值被称为数据资料。对应于行数,该数据集的数据可能包括一个或多个成员。
【干货书】Python金融分析,714页pdf掌握数据驱动金融
专知会员服务
94+阅读 · 2021年12月17日
专知会员服务
52+阅读 · 2021年6月17日
专知会员服务
104+阅读 · 2021年5月19日
【干货书】面向计算科学和工程的Python导论,167页pdf
专知会员服务
41+阅读 · 2021年4月7日
【干货书】PyTorch实战-一个解决问题的方法
专知会员服务
144+阅读 · 2021年4月2日
Python编程基础,121页ppt
专知会员服务
48+阅读 · 2021年1月1日
干净的数据:数据清洗入门与实践,204页pdf
专知会员服务
161+阅读 · 2020年5月14日
实例:手写 CUDA 算子,让 Pytorch 提速 20 倍
极市平台
4+阅读 · 2022年3月8日
详解PyTorch编译并调用自定义CUDA算子的三种方式
极市平台
0+阅读 · 2021年11月6日
图神经网络库PyTorch geometric
图与推荐
17+阅读 · 2020年3月22日
PyTorch模型训练特征图可视化(TensorboardX)
极市平台
33+阅读 · 2019年6月29日
PyTorch 学习笔记(一):让PyTorch读取你的数据集
极市平台
16+阅读 · 2019年4月24日
用PyTorch做物体检测和追踪
AI研习社
12+阅读 · 2019年1月6日
国家自然科学基金
4+阅读 · 2015年12月31日
国家自然科学基金
1+阅读 · 2014年12月31日
国家自然科学基金
0+阅读 · 2014年12月31日
国家自然科学基金
1+阅读 · 2014年12月31日
国家自然科学基金
0+阅读 · 2014年12月31日
国家自然科学基金
0+阅读 · 2011年12月31日
国家自然科学基金
0+阅读 · 2009年12月31日
国家自然科学基金
0+阅读 · 2008年12月31日
Arxiv
0+阅读 · 2022年4月19日
Arxiv
0+阅读 · 2022年4月19日
Arxiv
0+阅读 · 2022年4月18日
Automated Data Augmentations for Graph Classification
Arxiv
19+阅读 · 2018年10月25日
VIP会员
相关VIP内容
【干货书】Python金融分析,714页pdf掌握数据驱动金融
专知会员服务
94+阅读 · 2021年12月17日
专知会员服务
52+阅读 · 2021年6月17日
专知会员服务
104+阅读 · 2021年5月19日
【干货书】面向计算科学和工程的Python导论,167页pdf
专知会员服务
41+阅读 · 2021年4月7日
【干货书】PyTorch实战-一个解决问题的方法
专知会员服务
144+阅读 · 2021年4月2日
Python编程基础,121页ppt
专知会员服务
48+阅读 · 2021年1月1日
干净的数据:数据清洗入门与实践,204页pdf
专知会员服务
161+阅读 · 2020年5月14日
相关资讯
实例:手写 CUDA 算子,让 Pytorch 提速 20 倍
极市平台
4+阅读 · 2022年3月8日
详解PyTorch编译并调用自定义CUDA算子的三种方式
极市平台
0+阅读 · 2021年11月6日
图神经网络库PyTorch geometric
图与推荐
17+阅读 · 2020年3月22日
PyTorch模型训练特征图可视化(TensorboardX)
极市平台
33+阅读 · 2019年6月29日
PyTorch 学习笔记(一):让PyTorch读取你的数据集
极市平台
16+阅读 · 2019年4月24日
用PyTorch做物体检测和追踪
AI研习社
12+阅读 · 2019年1月6日
相关基金
国家自然科学基金
4+阅读 · 2015年12月31日
国家自然科学基金
1+阅读 · 2014年12月31日
国家自然科学基金
0+阅读 · 2014年12月31日
国家自然科学基金
1+阅读 · 2014年12月31日
国家自然科学基金
0+阅读 · 2014年12月31日
国家自然科学基金
0+阅读 · 2011年12月31日
国家自然科学基金
0+阅读 · 2009年12月31日
国家自然科学基金
0+阅读 · 2008年12月31日
Top
微信扫码咨询专知VIP会员