PyTorch 学习笔记(五):Finetune和各层定制学习率

2019 年 5 月 5 日 极市平台

加入极市专业CV交流群,与6000+来自腾讯,华为,百度,北大,清华,中科院等名企名校视觉开发者互动交流!更有机会与李开复老师等大牛群内互动!

同时提供每月大咖直播分享、真实项目需求对接、干货资讯汇总,行业技术交流点击文末“阅读原文”立刻申请入群~


作者 | 余霆嵩

来源专栏 | PyTorch学习笔记


本文截取自一个github上千星的火爆教程——《PyTorch 模型训练实用教程》教程内容主要为在 PyTorch 中训练一个模型所可能涉及到的方法及函数的详解等,本文为作者整理的学习笔记(五),后续会继续更新这个系列,欢迎关注。

项目代码:https://github.com/tensor-yu/PyTorch_Tutorial


系列回顾:



我们知道一个良好的权值初始化,可以使收敛速度加快,甚至可以获得更好的精度。而在实际应用中,我们通常采用一个已经训练模型的模型的权值参数作为我们模型的初始化参数,也称之为Finetune,更宽泛的称之为迁移学习。迁移学习中的Finetune技术,本质上就是让我们新构建的模型,拥有一个较好的权值初始值。


finetune权值初始化三步曲,finetune就相当于给模型进行初始化,其流程共用三步:


第一步:保存模型,拥有一个预训练模型; 第二步:加载模型,把预训练模型中的权值取出来; 第三步:初始化,将权值对应的“放”到新模型中


一、Finetune之权值初始化

在进行finetune之前我们需要拥有一个模型或者是模型参数,因此需要了解如何保存模型。官方文档中介绍了两种保存模型的方法,一种是保存整个模型,另外一种是仅保存模型参数(官方推荐用这种方法),这里采用官方推荐的方法。


第一步:保存模型参数

若拥有模型参数,可跳过这一步。假设创建了一个net = Net(),并且经过训练,通过以下方式保存:torch.save(net.state_dict(), 'net_params.pkl')


第二步:加载模型

进行三步曲中的第二步,加载模型,这里只是加载模型的参数:pretrained_dict = torch.load('net_params.pkl')


第三步:初始化

进行三步曲中的第三步,将取到的权值,对应的放到新模型中:首先我们创建新模型,并且获取新模型的参数字典net_state_dict:net = Net() # 创建netnet_state_dict = net.state_dict() # 获取已创建net的state_dict
接着将pretrained_dict里不属于net_state_dict的键剔除掉:pretrained_dict_1 = {k: v for k, v in pretrained_dict.items() if k in net_state_dict}
然后,用预训练模型的参数字典 对 新模型的参数字典net_state_dict 进行更新:net_state_dict.update(pretrained_dict_1)
最后,将更新了参数的字典 “放”回到网络中:net.load_state_dict(net_state_dict)


这样,利用预训练模型参数对新模型的权值进行初始化过程就做完了。

采用finetune的训练过程中,有时候希望前面层的学习率低一些,改变不要太大,而后面的全连接层的学习率相对大一些。这时就需要对不同的层设置不同的学习率,下面就介绍如何为不同层配置不同的学习率。


二、不同层设置不同的学习率

在利用pre-trained model的参数做初始化之后,我们可能想让fc层更新相对快一些,而希望前面的权值更新小一些,这就可以通过为不同的层设置不同的学习率来达到此目的。


为不同层设置不同的学习率,主要通过优化器对多个参数组进行设置不同的参数。所以,只需要将原始的参数组,划分成两个,甚至更多的参数组,然后分别进行设置学习率。 这里将原始参数“切分”成fc3层参数和其余参数,为fc3层设置更大的学习率。


请看代码:

ignored_params = list(map(id, net.fc3.parameters())) # 返回的是parameters的 内存地址base_params = filter(lambda p: id(p) not in ignored_params, net.parameters()) optimizer = optim.SGD([{'params': base_params},{'params': net.fc3.parameters(), 'lr': 0.001*10}], 0.001, momentum=0.9, weight_decay=1e-4)


第一行+ 第二行的意思就是,将fc3层的参数net.fc3.parameters()从原始参数net.parameters()中剥离出来 base_params就是剥离了fc3层的参数的其余参数,然后在优化器中为fc3层的参数单独设定学习率。


optimizer = optim.SGD(......)这里的意思就是 base_params中的层,用 0.001, momentum=0.9, weight_decay=1e-4 fc3层设定学习率为: 0.001*10


完整代码位于 :

https://github.com/tensor-yu/PyTorch_Tutorial/blob/master/Code/2_model/2_finetune.py


补充:

挑选出特定的层的机制是利用内存地址作为过滤条件,将需要单独设定的那部分参数,从总的参数中剔除。 base_params 是一个list,每个元素是一个Parameter 类 net.fc3.parameters() 是一个


ignored_params = list(map(id, net.fc3.parameters())) net.fc3.parameters() 是一个 所以迭代的返回其中的parameter,这里有weight 和 bias 最终返回weight和bias所在内存的地址






*延伸阅读



点击左下角阅读原文”,即可申请加入极市目标跟踪、目标检测、工业检测、人脸方向、视觉竞赛等技术交流群,更有每月大咖直播分享、真实项目需求对接、干货资讯汇总,行业技术交流,一起来让思想之光照的更远吧~


觉得有用麻烦给个在看啦~  

登录查看更多
16

相关内容

一份简短《图神经网络GNN》笔记,入门小册
专知会员服务
224+阅读 · 2020年4月11日
简明扼要!Python教程手册,206页pdf
专知会员服务
47+阅读 · 2020年3月24日
专知会员服务
44+阅读 · 2020年3月6日
一网打尽!100+深度学习模型TensorFlow与Pytorch代码实现集合
【模型泛化教程】标签平滑与Keras, TensorFlow,和深度学习
专知会员服务
20+阅读 · 2019年12月31日
【文档】PyTorch中文版官方教程来了...
机器学习算法与Python学习
6+阅读 · 2019年9月8日
PyTorch模型训练特征图可视化(TensorboardX)
极市平台
33+阅读 · 2019年6月29日
PyTorch 学习笔记(七):PyTorch的十个优化器
极市平台
8+阅读 · 2019年5月19日
PyTorch 学习笔记(六):PyTorch的十七个损失函数
极市平台
47+阅读 · 2019年5月13日
PyTorch 学习笔记(四):权值初始化的十种方法
极市平台
14+阅读 · 2019年5月1日
PyTorch 学习笔记(三):transforms的二十二个方法
极市平台
12+阅读 · 2019年4月28日
PyTorch 学习笔记(一):让PyTorch读取你的数据集
极市平台
16+阅读 · 2019年4月24日
快速上手笔记,PyTorch模型训练实用教程(附代码)
教程 | PyTorch经验指南:技巧与陷阱
机器之心
15+阅读 · 2018年7月30日
Arxiv
5+阅读 · 2020年3月26日
Adversarial Reprogramming of Neural Networks
Arxiv
3+阅读 · 2018年6月28日
Arxiv
7+阅读 · 2018年3月22日
Arxiv
4+阅读 · 2017年7月25日
Arxiv
5+阅读 · 2017年7月23日
VIP会员
相关资讯
【文档】PyTorch中文版官方教程来了...
机器学习算法与Python学习
6+阅读 · 2019年9月8日
PyTorch模型训练特征图可视化(TensorboardX)
极市平台
33+阅读 · 2019年6月29日
PyTorch 学习笔记(七):PyTorch的十个优化器
极市平台
8+阅读 · 2019年5月19日
PyTorch 学习笔记(六):PyTorch的十七个损失函数
极市平台
47+阅读 · 2019年5月13日
PyTorch 学习笔记(四):权值初始化的十种方法
极市平台
14+阅读 · 2019年5月1日
PyTorch 学习笔记(三):transforms的二十二个方法
极市平台
12+阅读 · 2019年4月28日
PyTorch 学习笔记(一):让PyTorch读取你的数据集
极市平台
16+阅读 · 2019年4月24日
快速上手笔记,PyTorch模型训练实用教程(附代码)
教程 | PyTorch经验指南:技巧与陷阱
机器之心
15+阅读 · 2018年7月30日
相关论文
Arxiv
5+阅读 · 2020年3月26日
Adversarial Reprogramming of Neural Networks
Arxiv
3+阅读 · 2018年6月28日
Arxiv
7+阅读 · 2018年3月22日
Arxiv
4+阅读 · 2017年7月25日
Arxiv
5+阅读 · 2017年7月23日
Top
微信扫码咨询专知VIP会员