使用Batch Normalization折叠来加速模型推理

2020 年 11 月 13 日 极市平台
↑ 点击 蓝字  关注极市平台

作者丨Nathan Hubens
来源丨AI公园
编辑丨极市平台

极市导读

 

本文主要讲解如何去掉batch normalization层来加速神经网络。作者详细描述了在实践中使用Batch Normalization的流程,并展示了使用batch norm的VGG16,ResNet50两种架构的效果。 >>加入极市CV技术交流群,走在计算机视觉的最前沿

介绍

Batch Normalization是将各层的输入进行归一化,使训练过程更快、更稳定的一种技术。在实践中,它是一个额外的层,我们通常添加在计算层之后,在非线性之前。它包括两个步骤:

  • 首先减去其平均值,然后除以其标准差
  • 进一步通过γ缩放,通过β偏移,这些是batch normalization层的参数,当网络不需要数据的时候,均值为0、标准差为1。

Batch normalization在神经网络的训练中具有较高的效率,因此得到了广泛的应用。但它在推理的时候有多少用处呢?


一旦训练结束,每个Batch normalization层都拥有一组特定的γ和β,还有μ和σ,后者在训练过程中使用指数加权平均值进行计算。这意味着在推理过程中,Batch normalization就像是对上一层(通常是卷积)的结果进行简单的线性转换。由于卷积也是一个线性变换,这也意味着这两个操作可以合并成一个单一的线性变换!这将删除一些不必要的参数,但也会减少推理时要执行的操作数量。

在实践中怎么做?

用一点数学知识,我们可以很容易地重新对卷积进行排列来处理batch normalization。提醒一下,对一个输入_x_进行卷积之后再进行batch normalization的运算可以表示为:那么,如果我们重新排列卷积的Wb,考虑batch normalization的参数,如下:
我们可以去掉batch normalization层,仍然得到相同的结果!

注意:通常,在batch normalization层之前的层中是没有bias的,因为这是无用的,也是对参数的浪费,因为任何常数都会被batch normalization抵消掉。

这样做的效果怎样?

我们将尝试两种常见的架构:

  • 使用batch norm的VGG16
  • ResNet50

为了演示,我们使用ImageNet dataset和PyTorch。两个网络都将训练5个epoch,看看参数数量和推理时间的变化。

1. VGG16

我们从训练VGG16 5个epoch开始(最终的准确性并不重要):

参数的数量:

单个图像的初始推理时间为:

如果使用了batch normalization折叠,我们有:

以及:

8448个参数被去掉了,更好的是,几乎快了0.4毫秒!最重要的是,这是完全无损的,在性能方面绝对没有变化:

让我们看看它在Resnet50的情况下是怎么样的!

2. Resnet50

同样的,我们开始训练它5个epochs:

初始参数量为:

推理时间为:

使用batch normalization折叠后,有:

和:

现在,我们有26,560的参数被移除,更惊讶的hi,推理时间减少了1.5ms,性能一点也没降。


推荐阅读


    添加极市小助手微信(ID : cvmart2),备注:姓名-学校/公司-研究方向-城市(如:小极-北大-目标检测-深圳),即可申请加入极市目标检测/图像分割/工业检测/人脸/医学影像/3D/SLAM/自动驾驶/超分辨率/姿态估计/ReID/GAN/图像增强/OCR/视频理解等技术交流群:月大咖直播分享、真实项目需求对接、求职内推、算法竞赛、干货资讯汇总、与 10000+来自港科大、北大、清华、中科院、CMU、腾讯、百度等名校名企视觉开发者互动交流~

    △长按添加极市小助手

    △长按关注极市平台,获取 最新CV干货

    觉得有用麻烦给个在看啦~   
    登录查看更多
    1

    相关内容

    【伯克利】再思考 Transformer中的Batch Normalization
    专知会员服务
    41+阅读 · 2020年3月21日
    专知会员服务
    45+阅读 · 2020年3月6日
    模型压缩究竟在做什么?我们真的需要模型压缩么?
    专知会员服务
    28+阅读 · 2020年1月16日
    3倍加速CPU上的BERT模型部署
    ApacheMXNet
    11+阅读 · 2020年7月13日
    如何区分并记住常见的几种 Normalization 算法
    极市平台
    19+阅读 · 2019年7月24日
    不用重新训练,直接将现有模型转换为 MobileNet
    极市平台
    6+阅读 · 2019年3月4日
    如何训练你的ResNet(三):正则化
    论智
    5+阅读 · 2018年11月13日
    [深度学习] AlexNet,GoogLeNet,VGG,ResNet简化版
    机器学习和数学
    20+阅读 · 2017年10月13日
    Layer Normalization原理及其TensorFlow实现
    深度学习每日摘要
    32+阅读 · 2017年6月17日
    Stagnation Detection with Randomized Local Search
    Arxiv
    0+阅读 · 2021年1月28日
    Self-Attention Graph Pooling
    Arxiv
    13+阅读 · 2019年6月13日
    Arxiv
    19+阅读 · 2018年6月27日
    Arxiv
    7+阅读 · 2018年3月22日
    VIP会员
    相关资讯
    3倍加速CPU上的BERT模型部署
    ApacheMXNet
    11+阅读 · 2020年7月13日
    如何区分并记住常见的几种 Normalization 算法
    极市平台
    19+阅读 · 2019年7月24日
    不用重新训练,直接将现有模型转换为 MobileNet
    极市平台
    6+阅读 · 2019年3月4日
    如何训练你的ResNet(三):正则化
    论智
    5+阅读 · 2018年11月13日
    [深度学习] AlexNet,GoogLeNet,VGG,ResNet简化版
    机器学习和数学
    20+阅读 · 2017年10月13日
    Layer Normalization原理及其TensorFlow实现
    深度学习每日摘要
    32+阅读 · 2017年6月17日
    Top
    微信扫码咨询专知VIP会员