PyTorch使用总览

2019 年 3 月 25 日 极市平台

极市正在计划做CVPR2019的专题直播分享会邀请CVPR2019的论文作者进行线上直播,分享优秀的科研工作和技术干货,也欢迎各位小伙伴自荐或推荐优秀的CVPR论文作者到极市进行技术分享~

本周四(3月28日)晚,澳大利亚阿德莱德大学博士生王鑫龙,将为我们分享联合点云分割中的实例和语义(CVPR2019),公众号回复“39”即可获取直播详情。

作者简介

魏凯峰:计算机视觉、深度学习、机器学习爱好者,CSDN博客专家“AI之路”。


深度学习框架训练模型时的代码主要包含数据读取、网络构建和其他设置三方面,基本上掌握这三方面就可以较为灵活地使用框架训练模型。PyTorch是Facebook的官方深度学习框架之一,开源以来,势头非常猛,相信使用过的人都会被其轻便和快速等特点深深吸引,因此本文从整体上介绍如何使用PyTorch。

 

PyTorch的官方github地址:

https://github.com/pytorch/pytorch 

PyTorch官方文档:

http://pytorch.org/docs/0.3.0/


接下来就按照上述的3个方面来介绍如何使用PyTorch。


一、数据读取

数据读取部分包含如何将你的图像和标签数据转换成PyTorch框架的Tensor数据类型,官方代码库中有一个接口例子:torchvision.ImageFolder。因为这个接口针对的数据存放方式是每个文件夹包含一个类的图像,但是实际应用中可能你的数据不是这样维护的,或者你的数据是多标签的,或者其他更复杂的形式,那么就需要自定义一个数据读取接口,这个时候就不得不提一个PyTorch中数据读取基类:torch.utils.data.Dataset,包括前面提到的torchvision.ImageFolder接口的对应类也是继承torch.utils.data.Dataset实现的,因此torch.utils.data.Dataset类是PyTorch框架中数据读取的核心。


在自定义数据读取接口时还有一步很重要的操作:数据预处理。常常我们在论文中看到的data argumentation就是指的数据预处理,对实验结果影响还是比较大的。该操作在PyTorch中可以通过torchvision.transforms接口来实现。


经过上述的两个操作后,还需再进行一次封装,将数据和标签封装成数据迭代器,这样才方便模型训练的时候一个batch一个batch地进行,这就要用到torch.utils.data.DataLoader接口,该接口的一个输入就是前面继承自torch.utils.data.Dataset类的自定义了的对象(比如torchvision.ImageFolder类的对象)。


至此,从图像和标签文件就生成了Tensor类型的数据迭代器,后续仅需将Tensor对象用torch.autograd.Variable接口封装成Variable类型(比如

train_data=torch.autograd.Variable(train_data),如果要在gpu上运行则是:train_data=torch.autograd.Variable(train_data.cuda()))就可以作为模型的输入了。


其他自定义的数据读取接口例子可以参考:https://github.com/miraclewkf/MobileNetV2-PyTorch,该项目中的read_ImageNetData.py脚本自定义了读取ImageNet数据集的接口,训练数据的读取和验证数据的读取采取不同的接口实现,比较有特点。


二、网络构建

PyTorch框架中提供了一些方便使用的网络结构及预训练模型接口:torchvision.models。该接口可以直接导入指定的网络结构,并且可以选择是否用预训练模型初始化导入的网络结构。


那么如何自定义网络结构呢?在PyTorch中,构建网络结构的类都是基于torch.nn.Module这个基类进行的,也就是说所有网络结构的构建都可以通过继承该类来实现,包括torchvision.models接口中的模型实现类也是继承这个基类进行重写的。自定义网络结构可以参考:

1、https://github.com/miraclewkf/MobileNetV2-PyTorch。该项目中的MobileNetV2.py脚本自定义了网络结构。

2、https://github.com/miraclewkf/SENet-PyTorch。该项目中的se_resnet.py和se_resnext.py脚本分别自定义了不同的网络结构。


如果要用某预训练模型为自定义的网络结构进行参数初始化,可以用torch.load接口导入预训练模型,然后调用自定义的网络结构对象的load_state_dict方式进行参数初始化,具体可以看https://github.com/miraclewkf/MobileNetV2-PyTorch项目中的train.py脚本中if args.resume条件语句。


三、其他设置

优化函数通过torch.optim包实现,比如torch.optim.SGD()接口表示随机梯度下降。更多优化函数可以看官方文档:

http://pytorch.org/docs/0.3.0/optim.html。


学习率策略通过torch.optim.lr_scheduler接口实现,比如torch.optim.lr_scheduler.StepLR()接口表示按指定epoch数减少学习率。更多学习率变化策略可以看官方文档:

http://pytorch.org/docs/0.3.0/optim.html。


损失函数通过torch.nn包实现,比如torch.nn.CrossEntropyLoss()接口表示交叉熵等。


多GPU训练通过torch.nn.DataParallel接口实现,比如:

model = torch.nn.DataParallel(model, device_ids=[0,1])表示在gpu0和1上训练模型。




*延伸阅读



点击左下角阅读原文”,即可申请加入极市目标跟踪、目标检测、工业检测、人脸方向、视觉竞赛等技术交流群,更有每月大咖直播分享、真实项目需求对接、干货资讯汇总,行业技术交流,一起来让思想之光照的更远吧~



觉得有用麻烦给个好看啦~  

登录查看更多
5

相关内容

【干货书】现代数据平台架构,636页pdf
专知会员服务
256+阅读 · 2020年6月15日
【Amazon】使用预先训练的Transformer模型进行数据增强
专知会员服务
57+阅读 · 2020年3月6日
《强化学习—使用 Open AI、TensorFlow和Keras实现》174页pdf
专知会员服务
139+阅读 · 2020年3月1日
一网打尽!100+深度学习模型TensorFlow与Pytorch代码实现集合
《动手学深度学习》(Dive into Deep Learning)PyTorch实现
专知会员服务
120+阅读 · 2019年12月31日
【干货】用BRET进行多标签文本分类(附代码)
专知会员服务
85+阅读 · 2019年12月27日
【书籍】深度学习框架:PyTorch入门与实践(附代码)
专知会员服务
165+阅读 · 2019年10月28日
PyTorch  深度学习新手入门指南
机器学习算法与Python学习
9+阅读 · 2019年9月16日
开发 | PyTorch好助手:PyTorch Hub一键复现各路模型
PyTorch 学习笔记(一):让PyTorch读取你的数据集
极市平台
16+阅读 · 2019年4月24日
用PyTorch做物体检测和追踪
AI研习社
12+阅读 · 2019年1月6日
教程 | 从头开始了解PyTorch的简单实现
机器之心
20+阅读 · 2018年4月11日
从基础概念到实现,小白如何快速入门PyTorch
机器之心
13+阅读 · 2018年2月26日
A survey on deep hashing for image retrieval
Arxiv
14+阅读 · 2020年6月10日
Paraphrase Generation with Deep Reinforcement Learning
VIP会员
相关资讯
PyTorch  深度学习新手入门指南
机器学习算法与Python学习
9+阅读 · 2019年9月16日
开发 | PyTorch好助手:PyTorch Hub一键复现各路模型
PyTorch 学习笔记(一):让PyTorch读取你的数据集
极市平台
16+阅读 · 2019年4月24日
用PyTorch做物体检测和追踪
AI研习社
12+阅读 · 2019年1月6日
教程 | 从头开始了解PyTorch的简单实现
机器之心
20+阅读 · 2018年4月11日
从基础概念到实现,小白如何快速入门PyTorch
机器之心
13+阅读 · 2018年2月26日
Top
微信扫码咨询专知VIP会员