Hinton论文知识蒸馏里面的温度T到底是干什么用的?

实际上反向求导这里我也有点看不懂,相对熵的梯度是p/q,其中p是复杂模型的预测概率,q是简单模型的预测概率。当p=q的时候简单模型就不应该继续变化了,…
关注者
29
被浏览
41,698
登录后你可以
不限量看优质回答私信答主深度交流精彩内容一键收藏

温度T其实就是用来软化标签的。所谓硬标签指的是one hot向量,这样的标签里面只有一个信息,就是这个分类是A而不是B。但teacher网络的预测并不是one hot向量而是一个概率值,如果这个概率中有类似这样的[0.6, 0.3, 0.05, 0.05]那么说明这个图片的分类是第0类但与第1类更接近,与第2和第3类并不接近。这就是软标签的概念,为了给网络提供更多的信息。但一个训练好的网络,它对于很多物体的分类是很确定的,它的输出大部分都会是[0.98, 0.01, 0.006, 0.0004] 这样的,这样的标签更接近硬标签。而温度T加在这样的标签上,就会使teacher的标签没有那么硬,让置信度高的降低一些,低的升高一些,将刚才的输出对应转化为[0.7, 0.2, 0.12, 0.08]这样的(只是举例,数值没有仔细计算),从而提供更多的信息给student。从另一个角度讲,加入了温度T的teacher是一个弱化了的teacher,避免teacher过强导致学习学不进去。

另外蒸馏的过程中需要对T逐渐降温,最终的使T=1,p=q的时候没有loss。如果一直保持较高的温度是不能得到好的结果的。