知识蒸馏的过程是怎样的?与迁移学习的区别在哪里?

知识蒸馏中,训练teacher net和蒸馏训练student net时的数据集是否应该为同一个?总感觉知识蒸馏和迁移学习有莫名的相似。
关注者
156
被浏览
223,349
登录后你可以
不限量看优质回答私信答主深度交流精彩内容一键收藏

转载于:

有人说过:“神经网络用剩的logits不要扔,沾上鸡蛋液,裹上面包糠...” 这两天对知识蒸馏(Knowledge Distillation)萌生了一点兴趣,正好写一篇文章分享一下。这篇文章姑且算是一篇小科普。

1. 从模型压缩开始

各种模型算法,最终目的都是要为某个应用服务。在买卖中,我们需要控制收入和支出。类似地,在工业应用中,除了要求模型要有好的预测(收入)以外,往往还希望它的「支出」要足够小。具体来说,我们一般希望部署到应用中的模型使用较少的计算资源(存储空间、计算单元等),产生较低的时延。

深度学习的背景下,为了达到更好的预测,常常会有两种方案:

1. 使用过参数化的深度神经网络,这类网络学习能力非常强,因此往往加上一定的正则化策略(如dropout);2. 集成模型(ensemble),将许多弱的模型集成起来,往往可以实现较好的预测。这两种方案无疑都有较大的「支出」,需要的计算量和计算资源很大,对部署非常不利

这也就是模型压缩的动机:我们希望有一个规模较小的模型,能达到和大模型一样或相当的结果。当然,从头训练一个小模型,从经验上看是很难达到上述效果的,也许我们能先训练一个大而强的模型,然后将其包含的知识转移给小的模型呢?如何做到呢?

* 下文统一将要训练的小模型称为新模型,将以及训练的大模型称为原模型。

Rich Caruana等人在[1]中指出,可以让新模型近似(approximate)原模型(模型即函数)。注意到,在机器学习中,我们常常假定输入到输出有一个潜在的函数关系,这个函数是未知的:从头学习一个新模型就是从有限的数据中近似一个未知的函数。如果让新模型近似原模型,因为原模型的函数是已知的,我们可以使用很多非训练集内的伪数据来训练新模型,这显然要更可行。

这样,原来我们需要让新模型的softmax分布与真实标签匹配,现在只需要让新模型与原模型在给定输入下的softmax分布匹配了。直观来看,后者比前者具有这样一个优势:经过训练后的原模型,其softmax分布包含有一定的知识——真实标签只能告诉我们,某个图像样本是一辆宝马,不是一辆垃圾车,也不是一颗萝卜;而经过训练的softmax可能会告诉我们,它最可能是一辆宝马,不大可能是一辆垃圾车,但绝不可能是一颗萝卜[2]。

2. 为什么叫「蒸馏」?

接续前面的讨论,我们的目标是让新模型与原模型的softmax输出的分布充分接近。直接这样做是有问题的:在一般的softmax函数中,自然指数 e 先拉大logits之间的差距,然后作归一化,最终得到的分布是一个arg max的近似 ,其输出是一个接近one-hot的向量,其中一个值很大,其他的都很小。这种情况下,前面说到的「可能是垃圾车,但绝不是萝卜」这种知识的体现是非常有限的。相较类似one-hot这样的硬性输出,我们更希望输出更「软」一些。

一种方法是直接比较logits来避免这个问题。具体地,对于每一条数据,记原模型产生的某个logits是 v_i ,新模型产生的logits是 z_i ,我们需要最小化

\frac{1}{2}(z_i-v_i)^2\tag{1}

文献[2]提出了更通用的一种做法。考虑一个广义的softmax函数

q_i=\frac{\exp(z_i/T)}{\sum_j\exp(z_j/T)}\tag{2}

其中 T 是温度,这是从统计力学中的玻尔兹曼分布中借用的概念。容易证明,当温度 T 趋向于0时,softmax输出将收敛为一个one-hot向量(证明可以参考我之前的文章:浅谈Softmax函数,将 \beta 替换为 1/T 即可);温度 T 趋向于无穷时,softmax的输出则更「软」。因此,在训练新模型的时候,可以使用较高的 T 使得softmax产生的分布足够软,这时让新模型(同样温度下)的softmax输出近似原模型;在训练结束以后再使用正常的温度 T=1 来预测具体地,在训练时我们需要最小化两个分布的交叉熵(Cross-entropy),记新模型利用公式 (2) 产生的分布是 q ,原模型产生的分布是 p ,则我们需要最小化

C=-p^\top\log q\tag{3}

在化学中,蒸馏是一个有效的分离沸点不同的组分的方法,大致步骤是先升温使低沸点的组分汽化,然后降温冷凝,达到分离出目标物质的目的。在前面提到的这个过程中,我们先让温度 T 升高,然后在测试阶段恢复「低温」,从而将原模型中的知识提取出来,因此将其称为是蒸馏,实在是妙

当然,如果转移时使用的是有标签的数据,那么也可以将标签与新模型softmax分布的交叉熵加入到损失函数中去这里需要将式 (3) 乘上一个 T^2 ,这是为了让损失函数的两项的梯度大致在一个数量级上(参考公式 (9) ),实验表明这将大大改善新模型的表现(考虑到加入了更多的监督信号)。

3. 与直接优化logits差异相比

由公式 (2)(3) ,对于交叉熵损失来说,其对于新模型的某个logit z_i 的梯度是

\begin{align}\frac{\partial C}{\partial z_i}&=\frac{1}{T}(q_i-p_i)\tag{4}\\ &=\frac{1}{T}\left(\frac{\exp(z_i/T)}{\sum_j\exp(z_j/T)}- \frac{\exp(v_i/T)}{\sum_j\exp(v_j/T)}\right)\tag{5} \end{align}

由于 e^x-1 x 等价无穷小(x\rightarrow 0),易知,当 T 充分大时,有

\begin{align}\frac{\partial C}{\partial z_i}&\approx\frac{1}{T}\left(\frac{1+z_i/T}{\sum_j(1+z_j/T)}-\frac{1+v_i/T}{\sum_j(1+v_j/T)}\right)\tag{6}\\ &=\frac{1}{T}\left(\frac{1+z_i/T}{N+\sum_jz_j/T}-\frac{1+v_i/T}{N+\sum_jv_j/T}\right)\tag{7}\\ \end{align}

假设所有logits对每个样本都是零均值化的,即 \sum_jz_j=\sum_jv_j=0 ,则有

\begin{align}\frac{\partial C}{\partial z_i}&\approx\frac{1}{T}\left(\frac{1+z_i/T}{N}-\frac{1+v_i/T}{N}\right)\tag{8}\\ &=\frac{1}{NT^2}\left(z_i-v_i\right)\tag{9} \end{align}

所以,如果:1. T 非常大,2. logits对所有样本都是零均值化的,则知识蒸馏和最小化logits的平方差(公式 (1) )是等价的(因为梯度大致是同一个形式)。实验表明,温度 T 不能取太大,而应该使用某个适中的值,这表明忽略极负的logits对新模型的表现很有帮助(较低的温度产生的分布比较「硬」,倾向于忽略logits中极小的负值)

4. 实验与结论

Hinton等人做了三组实验,其中两组都验证了知识蒸馏方法的有效性。在MNIST数据集上的实验表明,即便有部分类别的样本缺失,新模型也可以表现得很不错,只需要修改相应的偏置项,就可以与原模型表现相当。在语音任务的实验也表明,蒸馏得到的模型比从头训练的模型捕捉了更多数据集中的有效信息,表现仅比集成模型低了0.3个百分点。总体来说知识蒸馏是一个简单而有效的模型压缩/训练方法。这大体上是因为原模型的softmax提供了比one-hot标签更多的监督信号[3]。

知识蒸馏在后续也有很多延伸工作。在NLP方面比较有名的有Yoon Kim等人的Sequence-Level Knowledge Distillation 等。总的来说,对一些比较臃肿、不便部署的模型,可以将其「知识」转移到小的模型上。比如,在机器翻译中,一般的模型需要有较大的容量(capacity)才可能获得较好的结果;现在非常流行的BERT及其变种,规模都非常大;更不用提,一些情形下我们需要将这些本身就很大的深度模型集成为一个ensemble,这时候,可以用知识蒸馏压缩出一个较小的、「便宜」的模型。

另外,在多任务的情境下,使用一般的策略训练一个多任务模型,可能达不到比单任务更好的效果,文献[3]探索了使用知识蒸馏,利用单任务的模型来指导训练多任务模型的方法,很值得参考。

补充

鉴于评论区有知友对公式 (4) 有疑问,简单补充一下这里梯度的推导(其实就是交叉熵损失对softmax输入的梯度,LOL)。

* 这部分有一点繁琐,能接受公式 (4) 的读者可以跳过。

链式法则,有

\frac{\partial C}{\partial z}=\frac{\partial q}{\partial z}\frac{\partial C}{\partial q}\tag{10}

注意到 p 是原模型产生的softmax输出,与 z 无关。

后一项 \partial C/\partial q 比较容易得到,因为 C=\sum_{i=1}^{n}-p_i\log q_i ,所以

\frac{\partial C}{\partial q_i}=-\frac{p_i}{q_i}\tag{11}

\partial C/\partial q 是一个 n 维向量

\frac{\partial C}{\partial q}=\left[\begin{matrix} -\frac{p_1}{q_1}\\ -\frac{p_2}{q_2}\\ \vdots\\ -\frac{p_n}{q_n} \end{matrix}\right]\tag{12}

前一项 \partial q/\partial z 是一个 n\times n 的方阵,分类讨论可以得到。参考公式 (2) ,记 Z=\sum_k \exp(z_k/T) ,由除法的求导法则,输出元素 q_i 对输入 z_j 的偏导是

\frac{\partial q_i}{\partial z_j}=\frac{1}{Z^2}\left(Z\frac{\partial \exp(z_i/T)}{\partial z_j}-\exp(z_i/T)\boxed{\frac{\partial Z}{\partial z_j}} \right)\tag{13}

注意上面右侧加方框部分,可以进一步展开

\frac{\partial Z}{\partial z_j}=\frac{1}{T}\exp(z_j/T)\tag{14}

这样,代入公式 (13) ,并且将括号展开,可以得到

\begin{align} \frac{\partial q_i}{\partial z_j}&=\frac{1}{Z}\frac{\partial \exp(z_i/T)}{\partial z_j}-\frac{1}{TZ^2}\exp(z_i/T)\exp(z_j/T)\tag{15}\\ &=\frac{1}{Z}\frac{\partial \exp(z_i/T)}{\partial z_j}-\frac{1}{T}\frac{\exp(z_i/T)}{Z}\frac{\exp(z_j/T)}{Z}\tag{16}\\ &=\frac{1}{Z}\boxed{\frac{\partial \exp(z_i/T)}{\partial z_j}}-\frac{1}{T}q_iq_j\tag{17}\\ \end{align}

左侧方框内偏导可以分类讨论得到

\frac{\partial \exp(z_i/T)}{\partial z_j}=\begin{cases} \frac{1}{T}\exp(z_i/T),\ &\text{if }i=j\\ 0,&\text{if }i\neq j \end{cases}\tag{18}

带入式 (17) ,得到

\begin{align} \frac{\partial q_i}{\partial z_j} &=\begin{cases} \frac{1}{T}\left(\frac{\exp(z_i/T)}{Z}-q_iq_j\right) ,\ &\text{if }i=j\\ -\frac{1}{T}q_iq_j,&\text{if }i\neq j \end{cases}\tag{19}\\ &=\begin{cases} \frac{1}{T}\left(q_i-q_iq_j\right) ,\ &\text{if }i=j\\ -\frac{1}{T}q_iq_j,&\text{if }i\neq j \end{cases}\tag{20} \end{align}

所以 \partial q/\partial z 形式如下

\frac{\partial q}{\partial z}=\frac{1}{T}\left[ \begin{matrix} q_1-q_1^2 & -q_1q_2 & \cdots & -q_1q_n\\ -q_2q_1 & q_2-q_2^2 & \cdots & -q_2q_n\\ \vdots & \vdots & \ddots &\vdots\\ -q_nq_1 & -q_nq_2 & \cdots & q_n-q_n^2 \end{matrix} \right]\tag{21}

代入式 (10) ,可得

\begin{align}\frac{\partial C}{\partial z}&=\frac{1}{T}\left[ \begin{matrix} q_1-q_1^2 & -q_1q_2 & \cdots & -q_1q_n\\ -q_2q_1 & q_2-q_2^2 & \cdots & -q_2q_n\\ \vdots & \vdots & \ddots &\vdots\\ -q_nq_1 & -q_nq_2 & \cdots & q_n-q_n^2 \end{matrix} \right] \left[\begin{matrix} -\frac{p_1}{q_1}\\ -\frac{p_2}{q_2}\\ \vdots\\ -\frac{p_n}{q_n} \end{matrix}\right]\tag{22}\\ &=\frac{1}{T}\left[\begin{matrix} -p_1+\sum_kp_kq_1\\ -p_2+\sum_kp_kq_2\\ \vdots\\ -p_n+\sum_kp_kq_n \end{matrix}\right]\tag{23}\\ &=\frac{1}{T}\left[\begin{matrix} -p_1+q_1\\ -p_2+q_2\\ \vdots\\ -p_n+q_n \end{matrix}\right]\tag{24}\\ &=\frac{1}{T}(q-p)\tag{25} \end{align}


所以有公式 (4)\frac{\partial C}{\partial z_i}=\frac{1}{T}(q_i-p_i)


参考

[1] Caruana et al., Model Compression, 2006

[2] Hinton et al., Distilling the Knowledge in a Neural Network, 2015

[3] Kevin Clark et al., BAM! Born-Again Multi-Task Networks for Natural Language Understanding