↑↑↑关注后"星标"机器学习与推荐算法
无论是商品推荐,还是广告推荐,都大致可以分为召回,预排序(粗排),精排等阶段,如上篇<淘宝搜索中基于embedding的召回>的图所示:
召回最重要的就是要快,高召回率,对准确率可以不必要求太高,所以召回模型往往就是双塔模型,最经典的就是youtube双塔:
后面各种论文大多都说基于样本构造,模型结构,增加特征等方面去优化召回模型,但是不得不提到的是,蒸馏也是个提高召回侧模型效果的一个好方法。
由于受限于线上性能,在广告/商品召回阶段,我们通常采用深度学习双塔模型结构,离线先计算保存好ad/item embedding,线上实时预测出user embedding再通过近邻检索召回相似广告。user塔和ad塔是两个独立的神经网络,而user侧特征和ad侧特征没有交互,损失了很多有用信息,且因为user embedding线上实时inference,这就限制了user塔的特征规模及模型结构复杂度。对于这两个问题,蒸馏模型提供了一种解决方法。以下是蒸馏模型的特点:
-
由于training阶段不要求实时操作,允许训练一个复杂的模型,蒸馏模型可以在training阶段用复杂度高的网络(teacher network)-学到的知识指导较为简单的网络(student network)学习,在serving阶段以较小的计算代价来使用简单网络,同时保持一定的网络预测能力。
对于一些线上serving阶段无法获取的但又对目标有实际意义的特征,如用户与广告或商品的交互特征等,可以在training阶段将这类特征都加入teacher network学习,而线上serving阶段只需获取用于训练student network的基本特征,serving过程只使用student network结构。
可以将集成的知识压缩在简单的模型中。对于一个已经训练好的复杂的模型,如果要集成的话要带来很大的计算开销,而使用蒸馏模型可以用复杂模型指导一系列简单模型学习,根据复杂的大网络和一系列简单模型的输出作为目标,训练一个最终的模型,可不用对复杂模型进行集成。
当然,蒸馏用在召回,更重要的意义是保证召回,预排序(粗排),精排一致性,而不是蒸馏一堆看似高大上的特点。为啥要保证一致性呢?召回侧最终服务于排序,选出排序认可的才是最重要的,如果召回的都不是排序认可的,那排序模型也只能矮子里挑高的选,这样会影响整体的收益。如果召回模型在训练阶段增加对精排的拟合,是不是可以近似达到精排模型在全库搜索的效果呢?
那么推荐系统中蒸馏应该怎么做呢?其实最简单的就是改loss,除了交叉熵损失,可以增加和teacher预估不一致而带来的损失,辅助学习。
其中,L_hard是分类问题中经典的交叉熵损失,是真实标签与模型预测概率之间的交叉熵损失,记为hard loss;λ是超参数,控制teacher模型对student的指导程度;L_soft是teacher模型输出概率与student模型输出概率的交叉熵,记为soft loss,形式如下所示:
也可以用带温度的softmax函数控制teacher信号的传输:
Lsoft也可以用logit直接的mse loss进行学习。大致框架如下图所示:
训练大家可以尝试teacher和student同时训练,也可以先训练好teacher,再蒸馏到student上。在实际使用上,AUC和GAUC都是可以涨一些的。
说到这肯定有人要问了,召回可以学精排,预排序(粗排)可以学精排吗?,当然可以,而且肯定也会有收益。那召回为啥不学预排序(粗排)?毕竟召回直接相连的就是预排序模型。当然也是可以的尝试的。
总结一下,无论是做哪个阶段的模型,只单独优化某个阶段的模型很容易到达瓶颈,尽管每年关于推荐的论文层出不穷,但是真正用上了有效果的却很少。有时要从系统的角度出发去思考模型比单纯去堆砌模型结构效果要大得多。
喜欢的话点个在看吧👇