48小时单GPU训练DistilBERT!这个检索模型轻松达到SOTA

2021 年 12 月 7 日 PaperWeekly


©PaperWeekly 原创 · 作者 | Maple小七

单位 | 北京邮电大学

研究方向 | 自然语言处理




论文标题: 

Efficiently Teaching an Effective Dense Retriever with Balanced Topic Aware Sampling

收录会议:

SIGIR 2021

论文链接:

https://arxiv.org/abs/2104.06967

代码链接:

https://github.com/sebastian-hofstaetter/tas-balanced-dense-retrieval


基于 BERT 的稠密检索模型虽然在 IR 领域取得了阶段性的成功,但检索模型的训练、索引和查询效率一直是 IR 社区关注的重点问题,虽然超越 SOTA 的检索模型越来越多,但模型的训练成本也越来越大,以至于要训练最先进的稠密检索模型通常都需要 8×V100 的配置。而采用本文提出的 TAS-Balanced 和 Dual-supervision 训练策略,我们仅需要在单个消费级 GPU 上花费 48 小时从头训练一个 6 层的 DistilBERT 就能取得 SOTA 结果,这再一次证明了当前大部分稠密检索模型的训练是缓慢且低效的。



绪言


在短短的两年时间内,当初被质疑是 Neural Hype 的 Neural IR 现在已经被 IR 社区广泛接受,不少开源搜索引擎也逐渐支持了基于 BERT 的稠密检索(dense retrieval),基本达到了开箱即用的效果。其中,DPR 提出的 是当前最主流的稠密检索模型,然而众所周知的是, 的可迁移性远不如 BM25 这类 learning-free 的传统检索方法,想要在具体的业务场景下使用 并取得理想的结果,我们通常需要准备充足的标注数据进一步训练检索模型。

因此,如何高效地训练一个又快又好的 一直是 Neural IR 的研究热点。目前来看,改进 主要有两条路线可走,其中一条路线是改变 batch 内的样本组合,让模型能够获取更丰富的对比信息:

  • 优化模型的训练过程:这类方法的代表作是 ANCE 提出的动态负采样策略,其基本思路是在训练过程中定期刷新索引,从而为模型提供更优质的难负样本,而不是像 DPR 那样仅从 BM25 中获取负样本。在此基础上,LTRe 指出目前的检索模型其实是按 learning to rank 来训练的,因为训练过程中模型仅能看到一个 batch 内的样本,但如果我们只训练 query encoder,冻结 passage embedding,我们就可以按照 learning to retrieve 的方式计算全局损失,而不是仅计算一个 batch 的损失。除此之外,RocketQA 提出了 Cross Batch 技巧来增大 batch size,由于检索模型采用对比损失训练,因此理论上增大 batch size 带来的基本都是正收益。


然而,这三种策略都在原始的 的基础上增加了额外的计算成本,并且实现都比较复杂。除此之外,我们也可以利用知识蒸馏(knowledge distillation)为模型提供更优质的监督信号:

  • 优化模型的监督信号:  我们可以将表达能力更强但运行效率更低的 当作 teacher model 来为 提供 soft label。 在检索模型的训练中,知识蒸馏的损失函数有很多可能的选择,本文仅讨论 pairwise loss 和 in-batch negative loss,其中 in-batch negative loss 在 pairwise loss 的基础上将 batch 内部其他 query 的负样本也当作当前 query 的负样本,这两类蒸馏 loss 的详细定义后文会讲。

本文同样是在上述两个方面对 做出优化,在训练过程方面,作者提出了 Balanced Topic Aware Sampling(TAS-Balanced)策略来构建 batch 内的训练样本;在监督信号方面,作者提出了将 pairwise loss 和 in-batch negative loss 结合的 dual-supervision 蒸馏方式。



Dual Supervision


越来越多的证据表明知识蒸馏能够带来稠密检索模型性能的提升,本文将 提供的 pairwise loss 和 提供的 in-batch negative loss 结合起来为 提供监督信号,下面先简单介绍一下 teacher model 和 student model。

Teacher Model:  

是当前应用最为广泛的排序模型,它简单地将 query 和 passage 的拼接作为 的输入序列,然后对 输出向量做一个线性变换得到相关性 打分:


是一个经典的多向量表示模型,它将 query 和 passage 之间的交互简化为 max-sum 来克服 无法缓存 passage 向量的问题,其基本思路是首先对 query 和 passage 分别编码


然后计算每个 query term 和每个 passage term 的点积相似度,按 doc term 做 max-pooling 并按 query term 求和获取 query 和 passage 的相似度:


虽然理论上 可以对 passage 建立离线索引,但存储 passage 多向量表示的资源开销是非常大的,并且该存储成本随着语料库的 term 数量呈线性增长,再加上 max-sum 的操作也会带来额外的计算成本,因此这里我们将 当作 的 teacher。

Student Model:  

DPR 提出的 仅使用二元标签和 BM25 生成的负样本训练模型, 首先将 query 和 passage 独立编码为单个向量:


然后计算 的点积相似度:


在检索阶段, 首先对 query 编码,然后利用 faiss 做最大内积检索,下表展示了在单个消费级 GPU 上 6 层 DistilBERT 在 800 万 passage 集合上的检索速度。


2.1 Dual-Teacher Supervision


如果仅看监督信号的质量, 提供的 in-batch negative loss 当然是最优质的。然而, 虽然在表达能力上比 更强,但它实际上很少用于计算 in-batch negative loss,因为 需要单独编码每个 query-passage 样本对,所以其计算开销随着 batch size 二次增长,而 解耦了 query 和 passage 的表示,因此它的开销是随着 batch size 线性增长的,其 in-batch negative loss 的计算效率要高得多。

因此这里我们只让 提供 pairwise loss,具体来说,我们首先利用训练好的 对训练集中所有的 query-passage 样本对打分,然后计算 的蒸馏损失,蒸馏损失的具体形式有很多选择,这里作者选择了 Margin-MSE loss 作为 pairwise loss:


其中 分别为

我们同时让 提供 in-batch negative loss:


in-batch negative loss 中的 其实也可以替换成别的 loss,作者在后续实验中也尝试了一些看起来更有效的 listwise loss,然而实验结果表明 Margin-MSE loss 依旧是最佳的选择。因此,作者最终提出的蒸馏 loss 是 pairwise loss 和 in-batch negative loss 的加权平均,在后续实验中,作者设加权系数



Balanced Topic Aware Sampling


在原始的 的训练中,我们首先随机地从 query 集合 中采样 ,然后再为每个 随机采样一个正样本 和一个负样本 组成一个 batch:



其中 表示从集合 无放回地采样 个样本。由于训练集是非常大的,每个 batch 中的 几乎都是没有相关性的,但是当我们计算 in-batch negative loss 时,query 不仅和自身的 交互,也和别的 query 对应的 交互,然而,由于 对模型来说大概率是简单样本,因此它所能提供的信息增益是非常少的,这也导致了每个 batch 所能提供的信息量偏少,使得检索模型需要长时间的训练才能收敛。

3.1 TAS

针对这个问题,作者提出了 Topic Aware Sampling(TAS)策略来构建 batch 内的训练样本,具体来说,在训练之前,我们先利用 k-means 算法将 query 聚类到 k 个 cluster 中:


其中 query 的表示 由基线模型 提供, 的聚类中心,这样,每个 cluster 中的 query 都是主题相关的,在构建 batch 的时候,我们可以先从 cluster 的集合 中随机抽样 个 cluster,然后在每个 cluster 上随机抽样 个 query


在后续的实验中,作者为 40 万个 query 创建了  k=2000  个 cluster,并设 batch size 大小为 b=32,组建 batch 时随机抽样的 cluster 数量为 n=1,这样,每个 batch 中的样本都来自于同一个 cluster。如下图所示,相比于在整个 query 集合上随机抽样,TAS 策略生成的 batch 内部的 query 有更高的主题相似性。


3.2 TAS-balanced


在组建 batch 的时候,我们还需要为每个采样到的 query 配置正负样本对 。不难想到,几乎所有 query 对应的 都比 少得多,如果用独立随机抽样的方式获取 ,那么组成的 的 margin(也就是 )大概率是很大的,因此大部分 对模型来说是简单样本,因为模型很容易将 分开。

因此,我们可以在 TAS 策略的基础上进一步均衡 batch 内正负样本对的 margin 分布以减少 high margin(low information)的正负样本对。具体来说,针对每个 query,我们首先计算它对应的样本对集合的最小 margin 和最大 margin,然后将该区间分割为 个子区间,在为 query 配置 时,我们首先从这 个子区间中随机选择一个子区间,然后从 margin 落在该子区间内的 集合中随机采样并组成一个训练样本:


这样,在构建一个 batch 的时候,我们首先需要采样一个 cluster,然后采样 b  个 query,接下来为每个 query 采样一个 margin 子区间,最后在该子区间上采样一个正负样本对,这整套流程就是所谓的 TAS-balanced batch sampling:



需要注意的是,TAS-balanced 策略不会影响模型的训练速度,因为 batch 的构建是可以并行处理或者预先处理好的。TAS-balanced 策略组建的 batch 对模型来说整体的难度更大,因此为模型提供了更多的信息量,即使采用较小的 batch size,模型也能很好地收敛。如下表所示,我们可以在消费级显卡上(11GB 内存)高效地训练 而不需要昂贵的 8×V100 的配置,因为该方法不需要像 ANCE 那样重复刷新索引,也不需要像 RocketQA 那样进行超大批量的训练。


3.3 Experiment

作者选择 MSMACRO-Passage 官方提供的 4000 万正负样本对作为检索模型的训练集,并选择 MSMACRO-DEV(sparsely-judged,包含 6980 个 query)和 TREC-DL 19/20(densely-judged,包含 43/54 个 query)作为验证集。同时  均采用 6 层的 DistilBERT 初始化,且没有使用预训练的检索模型。



Results


4.1 Source of Effectiveness


首先我们对作者提出的 Dual-supervision 做消融实验,如下表所示。对于基于 pairwise loss 的知识蒸馏,Margin-MSE loss 的优越性已经被之前的论文证明,所以这里仅讨论 in-batch negative loss 的有效性。作者对比了基于 listwise loss 的 KL Divergence、ListNet 和 Lambdarank,实验结果表明这些损失的效果都不如 Margin-MSE loss,尤其是在 R@1K 上面。


为什么 pairwise 的 Margin-MSE 比 listwise loss 更好呢?因为 Margin-MSE 不仅仅是让模型去学习 teacher 所给出的排序,同时还学习 teacher score 的分布,由于 batch 内部样本的 order 实际上是有偏的,它并不能准确刻画样本间的真实距离,因此比起学习 order,学习 score 分布其实是一种更精确的方式。另外,由于 teacher 和 student 在训练阶段所使用的损失是一致的,这也会让 student 更容易学习到 teacher 的 score 分布。

接下来我们对 TAS-Balanced 策略做消融实验,如下表所示。总体来说,TAS-balanced 策略加上 Dual-supervision 蒸馏可以在各个数据集上取得最优性能。值得关注的是,在单独的 pairwise loss 的监督下使用 TAS 策略其实并不能带来明显的提升,这是因为 TAS 是面向 in-batch negative loss 设计的,使用 pairwise loss 训练时,batch 内的样本是没有交互的,因此 TAS 也就不会起作用。而 TAS-balanced 策略会影响正负样本对的组成方式,因此会对 pairwise loss 产生一定的影响。


4.2 Comparing to Baselines


下表对比了作者的模型和其他模型的性能,对比最后三行,我们可以发现一个有趣的现象:增大 batch size 在 TREC-DL 这类 densely-judge 的数据集上没有带来提升,但在 MSMACRO-DEV 这类 sparsely-judge 的数据集上会带来持续的提升。 因此作者猜想增大 batch size 会导致模型在 sparsely-judge 的 MSMACRO 上过拟合,RocketQA 的 SOTA 表现可能仅仅是因为它的 batch size 够大。


4.3 TAS-Balanced Retrieval in a Pipeline


为了进一步证明方法的有效性,作者尝试将 TAS-Balance 训练的检索模型应用到召回-排序系统中。众所周知,稠密检索和稀疏检索是互补的,且融合稀疏检索几乎不会影响召回速度,因此作者考虑将稀疏检索的 docT5query 的检索结果和 TAS-balanced 稠密检索模型的结果融合,然后使用最先进的 mono-duo-T5 排序模型对检索结果做重排。


选择不同的召回模型、排序模型和不同大小的候选集,我们可以得到不同延迟水平的检索系统。如上表所示,作者提出的模型在各个延迟水平上均取得了优异的表现。值得注意的是,在高延迟系统中,排序模型 mono-duo-T5 是在 BM25 的召回结果上训练的,这实际上会导致训练测试分布不一致的问题,所以 TAS-B+mono-duo-T5 甚至没能超越 BM25+mono-duo-T5,为了取得更好的性能,我们应该先训召回模型,然后在召回模型的给出召回结果上训练排序模型,这其实也间接反映了当前的排序模型泛化性不足的问题。



Discussion


本篇论文最大的亮点是 TAS-Balanced 策略的高效性,使用作者的模型,我们仅需要在单个消费级 GPU 上从头训练 48 小时就能取得 SOTA 结果,极大地降低了检索模型的训练成本,这在之前是无法想象的。实际上,比起 NLP 社区,IR 社区更加强调模型和数据的 Efficiency,这一课题在将来也一定会受到持续的关注。


特别鸣谢

感谢 TCCI 天桥脑科学研究院对于 PaperWeekly 的支持。TCCI 关注大脑探知、大脑功能和大脑健康。


更多阅读




#投 稿 通 道#

 让你的文字被更多人看到 



如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。


总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 


PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。


📝 稿件基本要求:

• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注 

• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题

• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算


📬 投稿通道:

• 投稿邮箱:hr@paperweekly.site 

• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者

• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿


△长按添加PaperWeekly小编




🔍


现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧



·



登录查看更多
0

相关内容

【WSDM2022】基于约束聚类学习离散表示的高效密集检索
专知会员服务
26+阅读 · 2021年11月16日
专知会员服务
26+阅读 · 2021年4月22日
专知会员服务
44+阅读 · 2020年3月6日
模型压缩究竟在做什么?我们真的需要模型压缩么?
专知会员服务
27+阅读 · 2020年1月16日
【SIGIR2021】使用难样本优化向量检索模型
专知
4+阅读 · 2021年4月22日
【GitHub】BERT模型从训练到部署全流程
专知
34+阅读 · 2019年6月28日
国家自然科学基金
5+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2015年12月31日
国家自然科学基金
3+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2015年12月31日
国家自然科学基金
11+阅读 · 2014年12月31日
国家自然科学基金
0+阅读 · 2014年12月31日
国家自然科学基金
0+阅读 · 2014年12月31日
国家自然科学基金
2+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2011年12月31日
国家自然科学基金
0+阅读 · 2009年12月31日
Arxiv
19+阅读 · 2021年4月8日
Arxiv
23+阅读 · 2020年9月16日
Heterogeneous Graph Transformer
Arxiv
27+阅读 · 2020年3月3日
VIP会员
相关基金
国家自然科学基金
5+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2015年12月31日
国家自然科学基金
3+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2015年12月31日
国家自然科学基金
11+阅读 · 2014年12月31日
国家自然科学基金
0+阅读 · 2014年12月31日
国家自然科学基金
0+阅读 · 2014年12月31日
国家自然科学基金
2+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2011年12月31日
国家自然科学基金
0+阅读 · 2009年12月31日
Top
微信扫码咨询专知VIP会员