ICLR 2020 | 可提速3000倍的全新信息匹配架构(附代码复现)

2020 年 4 月 4 日 PaperWeekly


©PaperWeekly 原创 · 作者|周树帆

学校|上海交通大学硕士生

研究方向|自然语言处理


今天聊一篇 FAIR 发表在 ICLR 2020 上的文章:Poly-encoders: Transformer Architectures and Pre-training Strategies for Fast and Accurate Multi-sentence Scoring



论文标题:Poly-encoders: Transformer Architectures and Pre-training Strategies for Fast and Accurate Multi-sentence Scoring

论文来源:ICLR 2020

论文链接:https://arxiv.org/abs/1905.01969



和一些花里胡哨但是没有卵用的论文不同,这篇文章可谓大道至简。该文用一种非常简单但是有效的方式同时解决了 DSSM 式的 Bi-encoder 匹配质量低的问题和 ARC-II、BERT 等交互式的 Cross-encoder 匹配速度慢的问题。



背景


众所周知,常见的搜索、检索式问答、自然语言推断等任务,它们本质上都是一种相关性匹配任务:给定一段文本作为 query,然后匹配出最为相关的文档或答案然后返回给用户。

目前主流的文本相关性匹配架构有两大类:以 DSSM 为代表的 Siamese Network 架构、以及形如 ARC-II、ABCNN 或 BERT(基于 Self-Attention)的交互式匹配架构。

1.1 Siamese Network

如图 1 所示,Siamese Network 式(本篇文章又称其为 Bi-encoder)的匹配方案会利用 2 个网络分别将 query 和 candidates 编码成   和   ,最后再通过一个相关性判别函数(通常为 cosine)计算两个 vec 之间的相似度。

这种方案的最大特点就是 query 和 candidates 直到最后的相关性判别函数时才发生交互,所以会对模型的匹配性能产生一定的影响。

但是这种完全独立的编码方式使得我们可以离线计算好所有 candidates 的向量,线上运行时只需计算 query 的向量然后匹配已有向量即可。总的来说,这种方案匹配速度极快,但是匹配质量不能达到最佳。


▲ 图1. Siamese Network(本篇论文又称其为Bi-encoder)


1.2 交互式匹配


如图 2 所示,交互式匹配(本文记作 Cross-encoder)的核心思想是则是 query 和 candidates 时时刻刻都应相互感知,相互交融,从而更深刻地感受到相互之间是否足够匹配。

早期的交互方案如 ARC-II、ABCNN 等会计算   和   之间的word embedding相似度、Q、C 分别过  RNN 之后的   、   之间的相似度,最后再用一些 CNN 之类的方法整合结果,然后用 MLP 做二分类判别是相关还是不相关。

▲ 图2. 交互式匹配示意图(图中为ARC-II)


另外在 BERT 兴起之后,如图 3 所示般将 query 和 candidate 拼成一句话,然后利用 self-attention 完成 query 和 candidate 之间的交互的模型也大量涌现,并且取得了非常显著的成果。本篇论文实现的 Cross-encoder 也是基于图 3 的架构。

相较于 Siamese Network,这类交互式匹配方案可以在 Q 和 C 之间实现更细粒度的匹配,所以通常可以取得更好的匹配效果。

但是很显然,这类方案无法离线计算 candidates的表征向量,每处理一个 query 都只能遍历所有 (query, candidate) 的 pairs 依次计算相关性,所以这类方案相当耗时(当然也有很多提速手段,不过那不是本文的重点)。


▲ 图3. Cross-encoder




Poly-Encoder


Bi-encoder (Siamese Network) 和 Cross-encoder(交互式网络)都有各自显著的优点和缺点,而本文提出的 Poly-encoder 架构同时集成了两类方案的优点并避免了缺点。


▲ 图4. Poly-encoder


Poly-encoder 如图 4 所示。Poly-encoder 的思想非常简单(简单到论文里仅用了 2 段文字),按我的个人理解描述:

Bi-encoder 的主要问题在于它要求 encoder 将 query 的所有信息都塞进一个固定的比较 general 的向量中,这导致最后   和   计算相似度时已经为时过晚,很多细粒度的信息丢失了(e.g. query 为“我要买苹果”),所以无法完成更精准的匹配。

这就有点像 word2vec 静态词向量:即使一个词有多种语义,它的所有语义也不得不塞进一个固定的词向量。

为了克服这个问题,Poly-encoder 的方案就是每个 query 产生 m 个不同的   ,接着再根据   动态地将 m 个   集成为最终的   (其实有点像封面图那样,有一点用 m 个向量组合出最终的 Low Poly(baike.baidu.com/item/Lo)化向量的味道),最后再计算   和   的匹配度。


用论文里的话来说:


论文中的 ctxt 指代 context,相当于 query;cand 指代 candidate。上面这段论文建议我们可以随机初始化 m 个通过 dot product 计算 attention,从而将长度为 N 的 context 编码成 m 个向量   (即   )。

接着:


我们再用 candidate 对应的向量   计算 m 个   的 attention,进而得到最终的 

很显然,Poly-encoder 架构在实际部署时是可以离线计算好所有 candidates 的向量的,所以只需要计算 query 对应的 m 个   向量,再通过简单的 dot product 就可以快速计算好对应每个 candidate 的“动态的”  向量。

看起来 Poly-encoder 享有 Bi-encoder 的速度,同时又有实现更精准匹配的潜力。我们通过实验来一探究竟。



实验


本文选择了检索式对话数据集 ConvAI2、DSTC 7、Ubuntu v2 数据集以及 Wikipedia IR 数据集进行实验。

训练 Bi-encoder 和 Poly-encoder 时由于这两类模型的特性,负采样方式为:在训练过程中,使用同一个 batch 中的其他 query 对应的 response 作为负样本(如果难以理解,可以稍后结合复现代码来理解)。

而 Cross-encoder 的负采样方式为:在开始训练之前,随机采样 15 个 responses 作为负样本。


3.1 检索质量


图5给出了一些 baseline 模型以及本文的基于预训练 BERT 的 Bi-encoder、Poly-encoder 以及 Cross-encoder 在各个数据集上的表现。

当然我们很容易发现,本文的所有模型由于以预训练的BERT为基础,他们的表现都要显著超出不使用 BERT的那些 baseline 们。所以我们只需要关注 Bi、Poly 和 Cross 三种架构之间的表现差异即可。

实验结果表明即使仅增设少数几个 code(用于计算 attention 产生向量),Poly-encoder 的表现也要远优于 Bi-encoder。

实验结果还表明,Poly-encoder 的表现会随着 code 个数的增加而逐渐增加,并且慢慢逼近 Cross-encoder 的结果(个人认为 Cross-encoder 的表现应该是 Poly-encoder 的上界,不过偶尔也可能会因为一些偶然因素导致 Poly-encoder 反超 Cross-encoder 的情况)。

另外,为了体现 Cross-encoder 在速度上的局限性,作者还很有意思地跳过了 Cross-encoder 在 Wikipedia IR 上的测评并写到:“In addition, Cross-encoders are also too slow to evaluate on the evaluation setup of that task, which has 10k candidates”。

▲ 图5. 模型表现汇总


3.2 检索速度


图 5 的实验结果已经表明 Poly-encoder 的检索质量明显优于 Bi-encoder 架构,且能逼近 Cross-encoder 架构的效果。剩下的关键问题就是 Poly-encoder 是否会显著增加检索耗时?

图 6 给出了各模型在 ConvAI2 数据集上的检索耗时。

Bi-encoder 理所当然是最快的架构,当 candidates 为 100k 时,在 CPU 和 GPU 环境下其检索耗时分别为 160ms 和 22ms;而 Cross-encoder 显然是最慢的一个:同样实验条件下其检索耗时分别约为 2.2M (220 万) ms 和 266K (26.6 万) ms。

反观 Poly-encoder,以 Poly-encoder 360 为例,该模型可以达到远超 Bi-encoder、接近甚至反超 Cross-encoder 的检索质量,但其检索速度确比 Cross-encoder 足足快了约 2600-3000 倍!

▲ 图6. 各模型在ConvAI2数据集上的检索耗时


论文小结


总的来说,本文的出发点就是希望找到一个速度快但质量不足的 Bi-encoder 架构和质量高但速度慢的 Cross-encoder 架构的折中。

本文提出的 Poly-encoder 的核心思想虽然非常简单,但是却十分有效(亲测),确实在很多场景下可以作为 Bi-encoder 的替代,甚至在一些对速度要求较高的场景下可以作为 Cross-encoder 的替代。

方案简洁固然是本文的一大优点,不过这也给未来的研究留下了空间。相信未来很快就会有许多基于 Poly-encoder 的改进版出现。



复现结果分享


在读完论文后的第一时间,我就尝试了复现工作。我的复现结果表明,Poly-encoder 不管是收敛速度还是模型上限,都要显著优于 Bi-encoder,且 Poly-encoder 几乎不增加额外的显存负担,对训练速度的影响也几乎可以忽略。完整代码位于


https://github.com/sfzhou5678/PolyEncoder

5.1 关键代码分析

Poly-encoder 的实现非常简单,只需在 Bi-encoder 的基础上略加修改即可。接下来我将介绍实现 Poly-encoder 的核心代码。
我们首先用 nn.embedding 来作为 m 个 poly_codes 的值, 然后 forward 的时候根据m的值产生对应个数的 poly_codes,这些 codes  将用于计算不同的  attention weights,以产生多个 vec_ctxt(即 vec_q)。
这里我令 poly_code_ids+=1 是为了让 context_encoder 和 response_encoder 对称,所以把 0 号 id 留给了 response_encoder。
self.poly_code_embeddings = nn.Embedding(self.poly_m + 1, config.hidden_size)

poly_code_ids = torch.arange(self.poly_m, dtype=torch.long, device=context_input_ids.device)
poly_code_ids += 1
poly_code_ids = poly_code_ids.unsqueeze(0).expand(batch_size, self.poly_m)
poly_codes = self.poly_code_embeddings(poly_code_ids)


接着,我们用这些 poly_codes 和 bert 的输出做 attention 得到 context_vecs:


def dot_attention(q, k, v, v_mask=None, dropout=None):
  attention_weights = torch.matmul(q, k.transpose(-1-2))
  if v_mask is not None:
    attention_weights *= v_mask.unsqueeze(1)
  attention_weights = F.softmax(attention_weights, -1)
  if dropout is not None:
    attention_weights = dropout(attention_weights)
  output = torch.matmul(attention_weights, v)
  return output

state_vecs = self.bert(context_input_ids, context_input_masks, context_segment_ids)[0]  # [bs, length, dim]
context_vecs = dot_attention(poly_codes, state_vecs, state_vecs, context_input_masks, self.dropout) #[bs, m, dim]


得到 response_vec 的方式类似,不再赘述。最后,只需根据 response_vec 给 context_vecs 做一次 attention 得到 final_context_vec 即可:


if labels is not None:
      responses_vec = responses_vec.view(1, batch_size, -1).expand(batch_size, batch_size, self.vec_dim)

final_context_vec = dot_attention(responses_vec, context_vecs, context_vecs, None, self.dropout)


在 loss function 方面,虽然我们可以在准备数据的时候就为每个样本做 N 次负采样,但是在 Bi-encoder 或 Poly-encoder 这种产生 response_vec 和 query 完全独立的场景下,可以将同一个 batch 内的其他 response 作为负样本来避免重复计算,有效提升训练效率。

具体实现时,我们计算 context_vec_i 和 response_vec_j 的点乘,从而产生一个 [bs, bs] 的余弦相似度矩阵,这个相似度矩阵就是 context_vec_i 和 batch 内的每一个 response_vec 的相似度。

由于我们的目标是最大化 context_vec_i 和对应的正样本,即 response_vec_i 的相似度,所以我们可以做一个 [bs,bs] 的单位矩阵作为 label,最后应用交叉熵产生训练用的 loss。

我的代码中在 dot_product 后面还乘了系数 5,这就是一个用于缓和 softmax 取值的参数,其具体取值通常需要实验来确定,这里的 5 只是我的经验值。


# 因为要算余弦相似度,所以给向量都归一化一下,之后直接点乘即可
context_vec = F.normalize(context_vec, 2-1)
responses_vec = F.normalize(responses_vec, 2-1)

responses_vec = responses_vec.squeeze(1)
dot_product = torch.matmul(context_vec, responses_vec.t())  # [bs, bs]
mask = torch.eye(context_input_ids.size(0)).to(context_input_ids.device)
loss = F.log_softmax(dot_product * 5, dim=-1) * mask
loss = (-loss.sum(dim=1)).mean()


5.2 实验结果


我使用的实验数据是论文中所用的 Ubuntu V2,实验设备是我笔记本上的一个 1066 显卡。当然为了实验跑得更快,我没有使用论文中所用的 bert-base,而是一个预训练过的仅 4 层的 bert-small。

另外,此实验中所用的 batchsize、文本长度、历史对话信息等都限制的比较小(不然实验实在是跑得太慢了),因此实验结果整体会较原论文中偏低。

最终的实验设置和结果如下:

  • Dataset: Ubuntu V2

  • Device: GTX 1060 6G x1

  • Pretrained model: BERT-small-uncased (https://storage.googleapis.com/bert_models/2020_02_20/all_bert_models.zip)

  • Batch size: 32

  • max_contexts_length: 128

  • max_context_cnt: 4

  • max_response_length:64

  • lr: 5e-5

  • Epochs: 3


Results:

▲ 复现实验结果汇总


从上表中明显可以看出,Poly-encoder 的效果要远优于 Bi-encoder 的,当使用 16 个 codes 时,poly 较 bi 的提升可得到 2.24 个点,而使用 64、360 个 codes 时提升分别可达 3.12 和 3.52 个点。而且模型的训练速度几乎没有受到影响,同时对显存的负担也非常小。



总结


本文提出的 Poly-encoder 思路非常清晰,实现难度不高,而且实验效果非常理想,我个人非常喜欢!

Poly-encoder 架构还有一个突出优点在于,它可以很轻松地拓展到大量信息检索相关的领域,无论是搜索、推荐,或是 CV 领域的 ReID 等,只要可以产生 query 和 candidates 的向量 vec_q 和 vec_c,那么都有可能成功应用 Poly-encoder。

我自己十分看好 Poly-encoder,相信在未来它会成为和 DSSM 一样的经典必读论文。




点击以下标题查看更多往期内容: 




#投 稿 通 道#

 让你的论文被更多人看到 



如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。


总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 


PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学习心得技术干货。我们的目的只有一个,让知识真正流动起来。


📝 来稿标准:

• 稿件确系个人原创作品,来稿需注明作者个人信息(姓名+学校/工作单位+学历/职位+研究方向) 

• 如果文章并非首发,请在投稿时提醒并附上所有已发布链接 

• PaperWeekly 默认每篇文章都是首发,均会添加“原创”标志


📬 投稿邮箱:

• 投稿邮箱:hr@paperweekly.site 

• 所有文章配图,请单独在附件中发送 

• 请留下即时联系方式(微信或手机),以便我们在编辑发布时和作者沟通



🔍


现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧



关于PaperWeekly


PaperWeekly 是一个推荐、解读、讨论、报道人工智能前沿论文成果的学术平台。如果你研究或从事 AI 领域,欢迎在公众号后台点击「交流群」,小助手将把你带入 PaperWeekly 的交流群里。



登录查看更多
6

相关内容

基于多头注意力胶囊网络的文本分类模型
专知会员服务
76+阅读 · 2020年5月24日
专知会员服务
44+阅读 · 2020年3月6日
Transformer文本分类代码
专知会员服务
116+阅读 · 2020年2月3日
49篇ICLR2020高分「图机器学习GML」接受论文及代码
专知会员服务
61+阅读 · 2020年1月18日
BERT进展2019四篇必读论文
专知会员服务
67+阅读 · 2020年1月2日
六篇 EMNLP 2019【图神经网络(GNN)+NLP】相关论文
专知会员服务
71+阅读 · 2019年11月3日
【论文笔记】基于LSTM的问答对排序
专知
12+阅读 · 2019年9月7日
论文浅尝 | 基于多模态关联数据嵌入的知识库补全
开放知识图谱
12+阅读 · 2018年12月13日
如何匹配两段文本的语义?
黑龙江大学自然语言处理实验室
7+阅读 · 2018年7月21日
论文解读 | 基于递归联合注意力的句子匹配模型
论文浅尝 | 基于RNN与相似矩阵CNN的知识库问答
开放知识图谱
8+阅读 · 2018年5月29日
论文浅尝 | 利用 RNN 和 CNN 构建基于 FreeBase 的问答系统
开放知识图谱
11+阅读 · 2018年4月25日
从Encoder到Decoder实现Seq2Seq模型(算法+代码)
量化投资与机器学习
8+阅读 · 2017年7月9日
Arxiv
3+阅读 · 2017年8月15日
VIP会员
相关VIP内容
基于多头注意力胶囊网络的文本分类模型
专知会员服务
76+阅读 · 2020年5月24日
专知会员服务
44+阅读 · 2020年3月6日
Transformer文本分类代码
专知会员服务
116+阅读 · 2020年2月3日
49篇ICLR2020高分「图机器学习GML」接受论文及代码
专知会员服务
61+阅读 · 2020年1月18日
BERT进展2019四篇必读论文
专知会员服务
67+阅读 · 2020年1月2日
六篇 EMNLP 2019【图神经网络(GNN)+NLP】相关论文
专知会员服务
71+阅读 · 2019年11月3日
相关资讯
【论文笔记】基于LSTM的问答对排序
专知
12+阅读 · 2019年9月7日
论文浅尝 | 基于多模态关联数据嵌入的知识库补全
开放知识图谱
12+阅读 · 2018年12月13日
如何匹配两段文本的语义?
黑龙江大学自然语言处理实验室
7+阅读 · 2018年7月21日
论文解读 | 基于递归联合注意力的句子匹配模型
论文浅尝 | 基于RNN与相似矩阵CNN的知识库问答
开放知识图谱
8+阅读 · 2018年5月29日
论文浅尝 | 利用 RNN 和 CNN 构建基于 FreeBase 的问答系统
开放知识图谱
11+阅读 · 2018年4月25日
从Encoder到Decoder实现Seq2Seq模型(算法+代码)
量化投资与机器学习
8+阅读 · 2017年7月9日
Top
微信扫码咨询专知VIP会员