加入极市专业CV交流群,与6000+来自腾讯,华为,百度,北大,清华,中科院等名企名校视觉开发者互动交流!更有机会与李开复老师等大牛群内互动!
同时提供每月大咖直播分享、真实项目需求对接、干货资讯汇总,行业技术交流。关注 极市平台 公众号 ,回复 加群,立刻申请入群~
来源:诺亚实验室
华为诺亚方舟实验室联合北京大学和悉尼大学发布论文《DAFL:Data-Free Learning of Student Networks》,提出了在无数据情况下的网络蒸馏方法(DAFL),比之前的最好算法在MNIST上提升了6个百分点,并且使用resnet18在CIFAR-10和100上分别达到了92%和74%的准确率(无需训练数据),该论文已被ICCV2019接收。
许多研究表明,训练好的判别器具有提取图像特征的能力,提取到的特征可以直接用于分类任务,所以,由于待压缩网络使用真实图片进行训练,也同样具有提取特征的能力,从而具有一定的分辨图像真假的能力。于是,我们把待压缩网络作为一个固定的判别器,以此来训练我们的生成网络。
然而,在传统GAN中,传统的判别器的输出是判定图片是否真假,只要让生成网络生成在判别器中分类为真的图片即可训练,但是,我们的待压缩网络为分类网络,其输出是分类结果,所以,我们需要重新设计生成网络的目标。通过观察真实图片在分类网络的响应,我们提出了以下损失函数。
在图像分类任务中,神经网络的训练采用的是交叉熵损失函数,在训练完成后,真实图片在网络中的输出将会是一个one-hot的向量,即分类类别对应的输出为1,其他的输出为0。于是,我们希望生成图片也具有类似的性质,我们的交叉熵损失函数定义为:
(1)
其中就是标准的交叉熵函数,由于生成图片并没有一个真实的标签,我们直接将其输出最大值对应的标签设定为它的伪标签。
在神经网络的训练中,由卷积核提取的特征也是输入图片的一种重要表示。先前的许多工作表明,卷积核提取的特征包含着图片的许多重要信息,将训练数据输入训练好的深度网络中,卷积核会产生更大的响应(相比于噪声或与此网络无关的数据),基于此,我们提出了特征激活损失函数定义为:
(2)
目标是让生成图像在待压缩网络中的特征响应值更大,这里我们采用了1范数来优化,原因是1范数相比于2范数会产生更加稀疏的值,而神经网络的响应也常常是稀疏的。
此外,为了让神经网络更好的训练,真实的训练数据对于每个类别的样本数目通常都保持一致,例如MNIST每个类别都含有6000张图片。于是,为了让生成网络产生各个类别样本的概率基本相同,我们引入信息熵,并定义了信息熵损失函数:
(3)
其中为标准的信息熵,信息熵的值越大,对于生成的一组样本来说,每个类别的数目就越平均,从而保证了生成样本的类别平均。
最后,我们将这三个损失函数组合起来,就可以得到我们生成器总的损失函数:
(4)
通过优化以上的损失函数,训练得到的生成器可以和真实的样本在待压缩网络具有类似的响应,从而更接近真实样本。
除了训练样本的缺失,需要被压缩的神经网络常常是只提供了输入和输出的接口,网络的结构和参数都是未知的。另外,本发明提出的生成网络生成的训练样本是无标注的,基于这两点,我们引入了教师学生网络学习范式,利用蒸馏算法实现利用未标注生成样本对黑盒网络的压缩。
蒸馏算法最早由Hinton提出,待压缩网络(教师网络)为一个具有高准确率但参数很多的神经网络,初始化一个参数较少的学生网络,通过让学生网络的输出和教师网络相同,学生网络的准确率在教师的指导下得到提高。
于是,我们使用交叉熵损失来使得学生网络的输出符合教师网络的输出,具体的损失函数为:
(5)
图1 Data-free Learning
我们在MNIST、CIFAR、CelebA三个数据集上分别进行了实验。
-完-
*延伸阅读
添加极市小助手微信(ID : cv-mart),备注:研究方向-姓名-学校/公司-城市(如:目标检测-小极-北大-深圳),即可申请加入目标检测、目标跟踪、人脸、工业检测、医学影像、三维&SLAM、图像分割等极市技术交流群,更有每月大咖直播分享、真实项目需求对接、干货资讯汇总,行业技术交流,一起来让思想之光照的更远吧~
△长按添加极市小助手
△长按关注极市平台
觉得有用麻烦给个在看啦~