【干货】深度学习实验流程及PyTorch提供的解决方案

2018 年 2 月 7 日 专知 huaiwen

【导读】近日,专知小组博士生huaiwen创作了一系列PyTorch实战教程,致力于介绍如何用PyTorch实践你的科研想法。今天推出其创作的第一篇《深度学习实验流程及PyTorch提供的解决方案》。在研究深度学习的过程中,当你脑中突然迸发出一个灵感,你是否发现没有趁手的工具可以快速实现你的想法?看完本文之后,你可能会多出一个选择。本文简要的分析了研究深度学习问题时常见的工作流, 并介绍了怎么使用PyTorch来快速构建你的实验。如果本文能为您的科研道路提供一丝便捷,我们将不胜荣幸。


专知公众号以前推出PyTorch手把手系列教程:

    【教程】专知-PyTorch手把手深度学习教程系列完整版

      势头强劲: PyTorch周年大事记盘点

常见的Research workflow




某一天, 你坐在实验室的椅子上, 突然:

  • 你脑子里迸发出一个idea

  • 你看了关于某一theory的文章, 想试试: 要是把xx也加进去会怎么样

  • 你老板突然给你一张纸, 然后说: 那个谁, 来把这个东西实现一下

于是, 你设计实验流程, 并为这一idea 挑选了合适的数据集和运行环境, 然后你废寝忘食的实现模型, 经过长时间的训练和测试, 你发现:

  • 这idea不work  --> 那算了 or 再调调

  • 这idea很work  --> 可以paper了

我们可以把上述流程用下图表示:


实际上, 常见的流程由下面几项组成起来:


  1.  一旦选定了数据集, 你就要写一些函数去load 数据集, 然后pre-process数据集, normalize 数据集, 可以说这是一个实验中占比重最多的部分, 因为:

    1. 每个数据集的格式都不太一样

    2. 预处理和正则化的方式也不尽相同

    3. 需要一个快速的dataloader 来feed data, 越快越好

  2. 然后, 你就要实现自己的模型, 如果你是CV方向的你可能想实现一个ResNet,如果你是NLP相关的你可能想实现一个Seq2Seq

  3. 接下来, 你需要实现训练步骤, 分batch, 循环epoch

  4. 在若干轮的训练后, 总要checkpoint一下, 才是最安全的

  5. 你还需要构建一些baseline,以验证自己idea的有效性

  6. 如果你实现的是神经网络模型, 当然离不开GPU的支持

  7. 很多深度学习框架提供了常见的损失函数, 但大部分时间, 损失函数都要和具体任务结合起来, 然后重新实现

  8. 使用优化方法, 优化构建的模型, 动态调整学习率


Pytorch 给出的解决方案




对于加载数据, Pytorch提出了多种解决办法


  • Pytorch 是一个Python包,而不是某些大型C++库的Python 接口, 所以, 对于数据集本身提供Python API的, Pytorch 可以直接调用, 不必特殊处理.

  • Pytorch 集成了常用数据集的data loader

  • 虽然以上措施已经能涵盖大部分数据集了, 但Pytorch还开展了两个项目: vision, 和text, 见下图灰色背景部分.  这两个项目, 采用众包机制, 收集了大量的dataloader, pre-process 以及 normalize, 分别对应于图像和文本信息.


  • 如果你要自定义数据集,也只需要继承torch.utils.data.dataset

对于构建模型, Pytorch也提供了三种方案


  • 众包的模型: torch.utils.model_zoo , 你可以使用这个工具, 加载大家共享出来的模型

  • 使用torch.nn.Sequential 模块快速构建

net = torch.nn.Sequential(
torch.nn.Linear(1, 10),
   
torch.nn.ReLU(),
   
torch.nn.Linear(10, 1)
)
print(net)
'''
Sequential (
 (0): Linear (1 -> 10)
 (1): ReLU ()
 (2): Linear (10 -> 1)
)
'''
  • 集成torch.nn.Module 深度定制

class Net(torch.nn.Module):
def __init__(self, n_feature, n_hidden, n_output):
super(Net, self).__init__()
self.hidden = torch.nn.Linear(n_feature, n_hidden)
self.predict = torch.nn.Linear(n_hidden, n_output)

def forward(self, x):
x = F.relu(self.hidden(x))
x = self.predict(x)
return x

net = Net(1, 10, 1)
print(net)
'''
Net (
 (hidden): Linear (1 -> 10)
 (predict): Linear (10 -> 1)
)
'''


对于训练过程的Pytorch实现


你当然可以自己实现数据的batch, shuffer等,但Pytorch 建议用类torch.utils.data.DataLoader加载数据,并对数据进行采样,生成batch迭代器。

# 创建数据加载器
loader = Data.DataLoader(
dataset=torch_dataset,      # TensorDataset类型数据集
   
batch_size=BATCH_SIZE,      # mini batch size
   
shuffle=True,               # 设置随机洗牌
   
num_workers=2,              # 加载数据的进程个数
)

for epoch in range(3): # 训练3轮
   
for step, (batch_x, batch_y) in enumerate(loader): # 每一步
       # 在这里写训练代码...
       
print('Epoch: ', epoch)


对于保存和加载模型Pytorch提供两种方案


  • 保存和加载整个网络

# 保存和加载整个模型, 包括: 网络结构, 模型参数等
torch.save(resnet, 'model.pkl')
model = torch.load('model.pkl')


  • 保存和加载网络中的参数

torch.save(resnet.state_dict(), 'params.pkl')
resnet.load_state_dict(torch.load('params.pkl'))


对于GPU支持


你可以直接调用Tensor的.cuda() 直接将Tensor的数据迁移到GPU的显存上, 当然, 你也可以用.cpu() 随时将数据移回内存

if torch.cuda.is_available():
linear = linear.cuda() # 将网络中的参数和缓存移到GPU显存中


对于Loss函数, 以及自定义Loss

在Pytorch的包torch.nn里, 不仅包含常用且经典的Loss函数, 还会实时跟进新的Loss 包括: CosineEmbeddingLoss, TripletMarginLoss等.


如果你的idea非常新颖, Pytorch提供了三种自定义Loss的方式


  • 继承torch.nn.module

import torch
import torch.nn as nn
import torch.nn.functional as func
class MyLoss(nn.Module):
# 设置超参
   
def __init__(self, a, b, c):
super(TripletLossFunc, self).__init__()
self.a = a
self.b = b
self.c = c
return

   def
forward(self, a, b, c):
# 具体实现
       
loss = a + b + c
return loss

然后

loss_instance = MyLoss(...)
loss = loss_instance(a, b, c)

这样做, 你能够用torch.nn.functional里优化过的各种函数来组成你的Loss


  • 继承torch.autograd.Function

import torch
from torch.autograd import Function
from torch.autograd import Variable
class MyLoss(Function):
def forward(input_tensor):
# 具体实现
       
result = ......
return torch.Tensor(result)

def backward(grad_output):
# 如果你只是需要在loss中应用这个操作的时候,这里直接return输入就可以了
       # 如果你需要在nn中用到这个,需要写明具体的反向传播操作
       
return grad_output

这样做,你能够用常用的numpy和scipy函数来组成你的Loss


  • 写一个Pytorch的C扩展

        这里就不细讲了,未来会有内容专门介绍这一部分。


对于优化算法以及调节学习率


Pytorch集成了常见的优化算法, 包括SGD, Adam, SparseAdam, AdagradRMSprop, Rprop等等.

torch.optim.lr_scheduler  提供了多种方式来基于epoch迭代次数调节学习率 torch.optim.lr_scheduler.ReduceLROnPlateau 还能够基于实时的学习结果, 动态调整学习率.


希望第一篇《深度学习实验流程及PyTorch提供的解决方案》,大家会喜欢,后续会推出系列实战教程,敬请期待。

-END-

专 · 知

人工智能领域主题知识资料查看获取【专知荟萃】人工智能领域26个主题知识资料全集(入门/进阶/论文/综述/视频/专家等)

请PC登录www.zhuanzhi.ai或者点击阅读原文,注册登录专知,获取更多AI知识资料

请扫一扫如下二维码关注我们的公众号,获取人工智能的专业知识!

请加专知小助手微信(Rancho_Fang),加入专知主题人工智能群交流!

点击“阅读原文”,使用专知

登录查看更多
8

相关内容

华为发布《自动驾驶网络解决方案白皮书》
专知会员服务
126+阅读 · 2020年5月22日
一份循环神经网络RNNs简明教程,37页ppt
专知会员服务
173+阅读 · 2020年5月6日
深度学习自然语言处理概述,216页ppt,Jindřich Helcl
专知会员服务
214+阅读 · 2020年4月26日
一网打尽!100+深度学习模型TensorFlow与Pytorch代码实现集合
【干货】用BRET进行多标签文本分类(附代码)
专知会员服务
85+阅读 · 2019年12月27日
【书籍】深度学习框架:PyTorch入门与实践(附代码)
专知会员服务
165+阅读 · 2019年10月28日
生成式对抗网络GAN异常检测
专知会员服务
117+阅读 · 2019年10月13日
网易云课堂独家 | 基于PyTorch实现的《深度学习》
深度学习与NLP
11+阅读 · 2019年2月15日
Forge:如何管理你的机器学习实验
专知
11+阅读 · 2018年12月1日
干货——图像分类(下)
计算机视觉战队
14+阅读 · 2018年8月28日
手把手教 | 深度学习库PyTorch(附代码)
数据派THU
27+阅读 · 2018年3月15日
6个实验教你用Torch玩转深度学习
七月在线实验室
7+阅读 · 2017年11月21日
wGAN如何解决GAN已有问题(附代码实现)
数据派THU
17+阅读 · 2017年6月27日
RNN | RNN实践指南(3)
KingsGarden
7+阅读 · 2017年6月5日
Optimization for deep learning: theory and algorithms
Arxiv
105+阅读 · 2019年12月19日
Arxiv
6+阅读 · 2018年2月24日
VIP会员
相关VIP内容
华为发布《自动驾驶网络解决方案白皮书》
专知会员服务
126+阅读 · 2020年5月22日
一份循环神经网络RNNs简明教程,37页ppt
专知会员服务
173+阅读 · 2020年5月6日
深度学习自然语言处理概述,216页ppt,Jindřich Helcl
专知会员服务
214+阅读 · 2020年4月26日
一网打尽!100+深度学习模型TensorFlow与Pytorch代码实现集合
【干货】用BRET进行多标签文本分类(附代码)
专知会员服务
85+阅读 · 2019年12月27日
【书籍】深度学习框架:PyTorch入门与实践(附代码)
专知会员服务
165+阅读 · 2019年10月28日
生成式对抗网络GAN异常检测
专知会员服务
117+阅读 · 2019年10月13日
相关资讯
网易云课堂独家 | 基于PyTorch实现的《深度学习》
深度学习与NLP
11+阅读 · 2019年2月15日
Forge:如何管理你的机器学习实验
专知
11+阅读 · 2018年12月1日
干货——图像分类(下)
计算机视觉战队
14+阅读 · 2018年8月28日
手把手教 | 深度学习库PyTorch(附代码)
数据派THU
27+阅读 · 2018年3月15日
6个实验教你用Torch玩转深度学习
七月在线实验室
7+阅读 · 2017年11月21日
wGAN如何解决GAN已有问题(附代码实现)
数据派THU
17+阅读 · 2017年6月27日
RNN | RNN实践指南(3)
KingsGarden
7+阅读 · 2017年6月5日
Top
微信扫码咨询专知VIP会员