基于PyTorch的「Keras」:除了核心逻辑通通都封装

2019 年 8 月 5 日 机器之心

机器之心报道

参与:思源、一鸣

Keras 和 PyTorch 都是对初学者最友好的深度学习框架,它们用起来就像描述架构的简单语言一样,告诉框架哪一层该用什么就行。很多研究者和开发者都在考虑到底哪一个框架更好,但目前两个框架都非常流行,它们都各有优势。最近,Facebook 研究员 William Falcon 为 PyTorch 披上了一件 Keras 的外衣,他表明用这样的框架做研究简直不要太爽。


PyTorch Lightning 地址:https://github.com/williamFalcon/pytorch-lightning



看起来像 Keras 的 PyTorch


Keras 本身的目的就是对深度学习框架(TensorFlow、Theano)进行了进一步的 API 封装。作为 TensorFlow 的高度封装,Keras 的抽象层次非常高,很多 API 细节都隐藏了起来。虽然 PyTorch 同样使用动态计算图,也方便快捷,但总体上 Keras 隐藏的细节更多一些。


反观 PyTorch,它提供一个相对较低级别的实验环境,使用户可以更加自由地编写自定义层、查看数值优化任务等等。例如在 PyTorch 1.0 中,编译工具 torch.jit 就包含一种名为 Torch Script 的语言,它是 Python 的子语言,开发者使用它能进一步对模型进行优化。


用 PyTorch 写模型,除了数据加载和模型定义部分外,整个训练和验证的逻辑、配置都需要我们手动完成,这些步骤都较为繁琐。甚至可以说,研究者需要耗费相当多的精力处理这一部分的代码,还要祈祷不出 Bug。但是对于大多数研究实验来说,训练和验证的循环体差不多都是一样的,实现的功能也相当一致,所以为什么不将这些通用的东西都打包在一起,这样训练不就简单了么?


William Falcon 正是这样想的,他将 PyTorch 开发中的各种通用配置全都包装起来,我们只需要写核心逻辑就行。通过 PyTorch Lightning,PyTorch 就类似于 Keras,它能以更高级的形式快速搭建模型。


项目作者是谁


要完成这样的工作,工作量肯定是非常大的,因为从超参搜索、模型 Debug、分布式训练、训练和验证的循环逻辑到模型日志的打印,都需要写一套通用的方案,确保各种任务都能用得上。所以 Facebook 的这位小哥哥 William Falcon 还是很厉害的。




他是一位 NYU 和 Facebook 的开发者。目前在 NYU 攻读 PhD。从 GitHub 的活动来看,小哥是一位比较活跃的开发者。



这是一位披着 Keras 外衣的 PyTorch


Lightning 是 PyTorch 非常轻量级的包装,研究者只需要写最核心的训练和验证逻辑,其它过程都会自动完成。因此这就有点类似 Keras 那种高级包装,它隐藏了绝大多数细节,只保留了最通俗易懂的接口。Lightning 能确保自动完成部分的正确性,对于核心训练逻辑的提炼非常有优势。


那么我们为什么要用 Lightning?


当我们开始构建新项目,最后你希望做的可能就是记录训练循环、多集群训练、float16 精度、提前终止、模型加载/保存等等。这一系列过程可能需要花很多精力来解决各式各样、千奇百怪的 Bug,因此很难把精力都放在研究的核心逻辑上。


通过使用 Lightning,这些部分都能保证是 Work 的,因此能抽出精力关注我们要研究的东西:数据、训练、验证逻辑。此外,我们完全不需要担心使用多 GPU 加速会很难,因为 Lightning 会把这些东西都做好。


所以 Lightning 都能帮我们干什么?


下图展示了构建一个机器学习模型都会经历哪些过程,很多时候最困难的还不是写模型,是各种配置与预处理过程。如下蓝色的部分需要用 LightningModule 定义,而灰色部分 Lightning 可以自动完成。我们需要做的,差不多也就加载数据、定义模型、确定训练和验证过程。



下面的伪代码展示了大致需要定义的几大模块,它们再加上模型架构定义就能成为完整的模型。


# what to do in the training loopdef training_step(self, data_batch, batch_nb):# what to do in the validation loopdef validation_step(self, data_batch, batch_nb):
# how to aggregate validation_step outputsdef validation_end(self, outputs):
# and your dataloadersdef tng_dataloader():def val_dataloader():def test_dataloader():


除了需要定义的模块外,以下步骤均可通过 Lightning 自动完成。当然,每个模块可以单独进行配置。




Lightning 怎么用


Lightning 的使用也非常简单,只需要两步就能完成:定义 LightningModel;拟合训练器。


以经典的 MNIST 图像识别为例,如下展示了 LightningModel 的示例。我们可以照常导入 PyTorch 模块,但这次不是继承 nn.Module,而是继承 LightningModel。然后我们只需要照常写 PyTorch 就行了,该调用函数还是继续调用。这里看上去似乎没什么不同,但注意方法名都是确定的,这样才能利用 Lightning 的后续过程。


import osimport torchfrom torch.nn import functional as Ffrom torch.utils.data import DataLoaderfrom torchvision.datasets import MNISTimport torchvision.transforms as transforms
import pytorch_lightning as ptl
class CoolModel(ptl.LightningModule): def __init__(self): super(CoolModel, self).__init__()    # not the best model...    self.l1 = torch.nn.Linear(28 * 2810)
  def forward(self, x):    return torch.relu(self.l1(x.view(x.size(0), -1)))   def my_loss(self, y_hat, y):   return F.cross_entropy(y_hat, y)   def training_step(self, batch, batch_nb):   x, y = batch   y_hat = self.forward(x)   return {'loss': self.my_loss(y_hat, y)}   def validation_step(self, batch, batch_nb):    x, y = batch   y_hat = self.forward(x)   return {'val_loss': self.my_loss(y_hat, y)}   def validation_end(self, outputs):   avg_loss = torch.stack([x['val_loss'for x in outputs]).mean()   return {'avg_val_loss': avg_loss}    def configure_optimizers(self):   return [torch.optim.Adam(self.parameters(), lr=0.02)]
@ptl.data_loader def tng_dataloader(self): return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)
@ptl.data_loader def val_dataloader(self): return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)
@ptl.data_loader def test_dataloader(self): return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)


随后,第二步即拟合训练器。这就比较类似 Keras 这类高级包装了,它将训练配置细节、循环体、以及日志输出等更加具体的信息全都隐藏了,一个 fit() 方法就能自动搞定一切。


这相比以前写 PyTorch 更加便捷精炼一些,而且分布式训练也非常容易,只要给出设备 id 就行了。


from pytorch_lightning import Trainerfrom test_tube import Experiment

model = CoolModel()exp = Experiment(save_dir=os.getcwd())
# train on cpu using only 10% of the data (for demo purposes)trainer = Trainer(experiment=exp, max_nb_epochs=1, train_percent_check=0.1)
# train on 4 gpus# trainer = Trainer(experiment=exp, max_nb_epochs=1, gpus=[0, 1, 2, 3])# train on 32 gpus across 4 nodes (make sure to submit appropriate SLURM job)# trainer = Trainer(experiment=exp, max_nb_epochs=1, gpus=[0, 1, 2, 3, 4, 5, 6, 7], nb_gpu_nodes=4)# train (1 epoch only here for demo)trainer.fit(model)# view tensorflow logs print(f'View tensorboard logs by running\ntensorboard --logdir {os.getcwd()}')print('and going to http://localhost:6006 on your browser')


其他特性


Pytorch-Lightning 还可以和 TensorBoard 无缝对接。



只需要定义运行的路径:


from test_tube import Experimentfrom pytorch-lightning import Trainer
exp = Experiment(save_dir = 『/some/path』)trainer = Trainer(experiment = exp)


将 TensorBoard 连接到路径上即可:


tensorboard -logdir /some/path


8月13日晚,腾讯将在澳门 IJCAI 2019 大会期间举办腾讯学术工业交流会(TAIC),诚邀AI从业者前来参加,共同探讨 AI 的应用与未来发展。点击「阅读原文」了解详情并参与报名。



登录查看更多
0

相关内容

【实用书】Python机器学习Scikit-Learn应用指南,247页pdf
专知会员服务
266+阅读 · 2020年6月10日
【WWW2020】DGL深度图神经网络实战教程,PPT+代码
专知会员服务
175+阅读 · 2020年4月12日
一网打尽!100+深度学习模型TensorFlow与Pytorch代码实现集合
【书籍】深度学习框架:PyTorch入门与实践(附代码)
专知会员服务
164+阅读 · 2019年10月28日
盘一盘 Python 系列 10 - Keras (上)
平均机器
5+阅读 · 2019年8月26日
开发 | PyTorch好助手:PyTorch Hub一键复现各路模型
PyTorch 1.0 稳定版正式发布!
新智元
3+阅读 · 2018年12月8日
TensorFlow 2.0和PyTorch谁更好?大牛们争了好几天
Github 项目推荐 | 用 PyTorch 0.4 实现的 YoloV3
AI研习社
9+阅读 · 2018年8月11日
重磅 | PyTorch 0.4.0和官方升级指南来了!
AI前线
3+阅读 · 2018年4月25日
干货|7步掌握基于Keras的深度学习!
全球人工智能
4+阅读 · 2017年11月14日
Arxiv
3+阅读 · 2019年9月5日
Arxiv
5+阅读 · 2019年8月22日
Arxiv
8+阅读 · 2018年6月19日
Arxiv
3+阅读 · 2018年3月2日
VIP会员
相关资讯
盘一盘 Python 系列 10 - Keras (上)
平均机器
5+阅读 · 2019年8月26日
开发 | PyTorch好助手:PyTorch Hub一键复现各路模型
PyTorch 1.0 稳定版正式发布!
新智元
3+阅读 · 2018年12月8日
TensorFlow 2.0和PyTorch谁更好?大牛们争了好几天
Github 项目推荐 | 用 PyTorch 0.4 实现的 YoloV3
AI研习社
9+阅读 · 2018年8月11日
重磅 | PyTorch 0.4.0和官方升级指南来了!
AI前线
3+阅读 · 2018年4月25日
干货|7步掌握基于Keras的深度学习!
全球人工智能
4+阅读 · 2017年11月14日
Top
微信扫码咨询专知VIP会员