这篇文章来自于旷视。旷视内部有一个基础模型组,孙剑老师也是很看好 NAS 相关的技术,相信这篇文章无论从学术上还是工程落地上都有可以让人借鉴的地方。回到文章本身,模型剪枝算法能够减少模型计算量,实现模型压缩和加速的目的,但是模型剪枝过程中确定剪枝比例等参数的过程实在让人头痛。
这篇文章提出了 PruningNet 的概念,自动为剪枝后的模型生成权重,从而绕过了费时的 retrain 步骤。并且能够和进化算法等搜索方法结合,通过搜索编码 network 的 coding vector,自动地根据所给约束搜索剪枝后的网络结构。和 AutoML 技术相比,这种方法并不是从头搜索,而是从已有的大模型出发,从而缩小了搜索空间,节省了搜索算力和时间。
个人觉得这种剪枝和 NAS 结合的方法,应该会在以后吸引越来越多人的注意。这篇文章的代码已经开源在了 Github:
模型剪枝是一种能够减少模型大小和计算量的方法。模型剪枝一般可以分为三个步骤:
训练一个参数量较多的大网络
将不重要的权重参数剪掉
剪枝后的小网络做 fine tune
其中第二步是模型剪枝中的关键。有很多 paper 围绕“怎么判断权重是否重要”以及“如何剪枝”等问题进行讨论。困扰模型剪枝落地的一个问题就是剪枝比例的确定。
传统的剪枝方法常常需要人工 layer by layer 地去确定每层的剪枝比例,然后进行 fine tune,用起来很耗时,而且很不方便。不过最近的 Rethinking the Value of Network Pruning [1] 指出,剪枝后的权重并不重要,对于 channel pruning 来说,更重要的是找到剪枝后的网络结构,具体来说就是每层留下的 channel 数量。
受这个发现启发,文章提出可以用一个 PruningNet,对于给定的剪枝网络,自动生成 weight,无需进行 retrain,然后评测剪枝网络在验证集上的性能,从而选出最优的网络结构。
具体来说,PruningNet 的输入是剪枝后的网络结构,必须首先对网络结构进行编码,转换为 coding vector。这里可以直接用剪枝后网络每层的 channel 数来编码。在搜索剪枝网络的时候,我们可以尝试各种 coding vector,用 PruningNet 生成剪枝后的网络权重。网络结构和权重都有了,就可以去评测网络的性能。进而用进化算法搜索最优的 coding vector,也就是最优的剪枝结构。在用进化算法搜索的时候,可以使用自定义的目标函数,包括将网络的 accuracy,latency,FLOPS 等考虑进来。
从上一小节已经可以知道,PruningNet 是整个算法的关键。那么怎么才能找到这样一个“神奇网络”呢?
先做一下符号约定,使用 ci 表示剪枝之后第 i 层的 channel 数量, l 为网络的层数, W 表示剪枝后网络的权重。那么 PruningNet 的输入输出如下所示:
训练
先结合下图看一下 forward 部分。PruningNet 是由 l 个 PruningBlock 组成的,每个 PruningBlock 是一个两层的 MLP。
首先看图 b,编码着网络结构信息的 coding vector 输入到当前 block 后,输出经过 Reshape,成了一个 Weight Matrix。注意哦,这里的 WeightMatrix 是固定大小的(也就是未剪枝的原始 Weight shape 大小),和剪枝网络结构无关。
再看图 a,因为要对网络进行剪枝,所以 WeightMatrix 要进行 Crop。对应到图 b,可以看到,Crop 是在两个维度上进行的。首先,由于上一层也进行了剪枝,所以 input channel 数变少了;其次,由于当前层进行了剪枝,所以 output channel 数变少了。这样经过 Crop,就生成了剪枝后的网络 weight。我们再输入一个 mini batch 的训练图片,就可以得到剪枝后的网络的 loss。
在 backward 部分,我们不更新剪枝后网络的权重,而是更新 PruningNet 的权重。由于上面的操作都是可微分的,所以直接用链式法则传过去就行。如果你使用 PyTorch 等支持自动微分的框架,这是很容易的。
下图所示是训练过程的整个 PruningNet(左侧)和剪枝后网络(右侧,即 PrunedNet)。训练过程中的 coding vector 在状态空间里随机采样,随机选取每层的 channel 数量。
PS:和原始论文相比,下图和上图顺序是颠倒的。这里从底向上介绍了 PruningNet 的训练,而论文则是自顶向下。
搜索
训练好 PruningNet 后,就可以用它来进行搜索了!我们只需要输入某个 coding vector,PruningNet 就会为我们生成对应每层的 WeightMatrix。别忘了 coding vector 是编码的网络结构,现在又有了 weight,我们就可以在验证集上测试网络的性能了。进而,可以使用进化算法等优化方法去搜索最优的 coding vector。当我们得到了最优结构的剪枝网络后,再 from scratch 地训练它。
进化算法这里不再赘述,很多优化的书中包括网上都有资料。这里把整个算法流程贴出来:
作者在 ImageNet 上用 MobileNet 和 ResNet 进行了实验。训练 PruningNet 用了 1/4 的原模型的 epochs。数据增强使用常见的标准流程,输入 image 大小为 224×224。
将原始 ImageNet 的训练集做分割,每个类别选 50 张组成 sub-validation(共计 50000),其余作为 sub-training。在训练时,我们使用 sub-training 训练 PruningNet。在搜索时,使用 sub-validation 评估剪枝网络的性能。不过,还要注意,在搜索时,使用 20000 张 sub-training 中的图片重新计算 BatchNorm layer 中的 running mean 和 running variance。
shortcut 剪枝
在进行模型剪枝时,一个比较难处理的问题是 ResNet 中的 shortcut 结构。因为最后有一个 element-wise 的相加操作,必须保证两路 feature map 是严格 shape 相同的,所以不能随意剪枝,否则会造成 channel 不匹配。下面对几种论文中用到的网络结构分别讨论。
MobileNet-v1 是没有 shortcut 结构的。我们为每个 conv layer 都配上相应的 PruningBlock——一个两层的 MLP。PruningNet 的输入 coding vector 中的元素是剪枝后每层的 channel 数量。而输入第 i 个 PruningBlock 的是一个 2D vector,由归一化的第 i-1 层和第 i 层的剪枝比例构成。这部分可以结合代码来看:
注意第 1 个 conv layer 的输入是 1D vector,因为它是第一个被剪枝的 layer。在训练时,coding vector 的搜索空间被以一定步长划分为 grid,采样就是在这些格点上进行的。
MobileNet-v2
MobileNet-v2 引入了类似 ResNet 的 shortcut 结构,这种 resnet block 必须统一看待。具体来说,对于没有在 resnet block 中的conv,处理方法如 MobileNet-v1。对每个 resnet block,配上一个相应的 PruningBlock。由于每个 resnet block 中只有一个中间层(3×3 的 conv),所以输出第 i 个 PruningBlock 的是一个 3D vector,由归一化的第 i-1 个 resnet block,第 i 个 resnet block 和中间 conv 层的剪枝比例构成。其他设置和 MobileNet-v1 相同。这里可以结合代码来看:
ResNet
处理方法如 MobileNet-v2 所示。可以结合代码来看:
实验结果
在相近 FLOPS 情况下,和 MobileNet 论文中改变 ratio 参数得到的模型比较,MetaPruning 得到的模型 accuracy 更高。尤其是压缩比例更大时,该方法更有优势。
和其他剪枝方法(如 AMC [2])等方法比较,该方法也得到了 SOTA 的结果。MetaPruning 方法能够以一种统一的方法处理 ResNet 中的 shortcut 结构,并且不需要人工调整太多的参数。
上面的比较都是基于理论 FLOPS,现在更多人在关注网络在实际硬件上的 latency 怎么样。文章对此也进行了讨论。如何测试网络的 latency?
当然可以每个网络都实际跑一下,不过有些麻烦。基于每个 layer 的 inference 时间是互相独立的这个假设,作者首先构造了各个 layer inference latency 的查找表(参见论文 Fbnet: Hardware-aware efficient convnet design via differentiable neural architecture search [3]),以此来估计实际网络的 latency。作者这里和 MobileNet baseline 做了比较,结果也证明了该方法更优。
PruningNet 结果分析
此外,作者还对 PruningNet 的预测结果进行可视化,试图找出一些可解释性,并找出剪枝参数的一些规律。
down-sampling 的部分 PruningNet 倾向于保留更多的 channel,如 MobileNet-v2 block 中间的那个 conv;
优先剪浅层 layer 的 channel,FLOPS 约束太强剪深层的 channel,但可能会造成网络 accuracy 下降比较多。
这篇论文把剪枝算法和 NAS 结合,取两者之长,用待剪枝的模型缩小了搜索空间,用进化算法自动搜索最优网络结构。使用 coding vector 编码网络结构,用一个很简单的双隐层感知机预测网络权重,并提出了一种 shortcut 的处理方法,在 ImageNet 数据集和几种常用网络结构上取得了不错的结果。文章提出的方法简单易于操作,可以很方便地应用到自己的业务场景中。相关代码已经开源在 Github 上。
[1] https://arxiv.org/abs/1810.05270
[2] https://arxiv.org/abs/1802.03494
[3] https://arxiv.org/abs/1812.03443
点击以下标题查看更多往期内容:
#投 稿 通 道#
让你的论文被更多人看到
如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。
总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。
PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学习心得或技术干货。我们的目的只有一个,让知识真正流动起来。
📝 来稿标准:
• 稿件确系个人原创作品,来稿需注明作者个人信息(姓名+学校/工作单位+学历/职位+研究方向)
• 如果文章并非首发,请在投稿时提醒并附上所有已发布链接
• PaperWeekly 默认每篇文章都是首发,均会添加“原创”标志
📬 投稿邮箱:
• 投稿邮箱:hr@paperweekly.site
• 所有文章配图,请单独在附件中发送
• 请留下即时联系方式(微信或手机),以便我们在编辑发布时和作者沟通
🔍
现在,在「知乎」也能找到我们了
进入知乎首页搜索「PaperWeekly」
点击「关注」订阅我们的专栏吧
关于PaperWeekly
PaperWeekly 是一个推荐、解读、讨论、报道人工智能前沿论文成果的学术平台。如果你研究或从事 AI 领域,欢迎在公众号后台点击「交流群」,小助手将把你带入 PaperWeekly 的交流群里。
▽ 点击 | 阅读原文 | 下载论文 & 源码