©作者 | 牟宇滔
单位 | 北京邮电大学
研究方向 | 自然语言理解
文章来源:
文章链接:
聚焦的问题:最近微调预训练语言模型来捕捉句子 embedding 之间相似性的方法已经取得了 SOTA 的效果,比如 SimCSE。具体地,它们首先定义了一个句子 embedding 的相似性分数(常用的比如余弦相似度),然后利用 NLI 或者 STS 的数据集微调 BERT 模型,这里的句子 embedding 常常是通过 BERT 最后一层平均池化或者直接取 [CLS] token 的 embedding 得到。
1. 作者认为这种方法可解释性不足,通常来说,如果能够从 token 层面找到 cross-sentence 对齐以及计算出各个对齐部分的重要程度是有利于分析句子相似性的,目前方法都没有做到这种 token 层面的显式 cross-sentence 对齐。关于 cross-sentence 对齐,可以通俗理解为找到两个句子间语义互相匹配的 token pairs。
2. 目前的方法采用平均 token embedding 得到句子表示,用于计算句子相似性,这是 sentence-level 层面的建模,无法从 token-level 层面显式地融合语义对齐的 token pairs 之间的距离。
提出的方法:
1. 作者首先提出了一个基于最优传输理论的分析方法,用来分析现有的一些 STS 方法比如 SimCSE。发现目前平均池化+余弦相似度的方法存在传输矩阵 rank=1 的问题,这使模型无法有效地将语义对齐的 token pairs 的相似性融合到整体的句子相似性中。
2. 为了解决上述问题,作者提出了一个基于最优传输的距离度量,RCMD。
3. 此外,作者还提出了 CLRCMD,一个对比学习框架,用于优化句子对的 RCMD,有助于增加句子相似性的质量以及可解释性。
首先简要解释一下最优传输问题以及如何将总传输成本解释为距离度量。
最优传输(OT)问题有三大组件:传输前的状态 d1,传输后的状态 d2,代价矩阵 M。直观的说,最优传输问题就是求解状态 d1 转移到状态 d2,使得总代价最小的函数 T,这个 T 称为最优传输矩阵。
回到 STS 任务,d1 和 d2 可以看成全 1 向量,长度分别为 sentence1 和 sentence2 的 token 长度,代价矩阵的元素 Mij 表示 sentence1 的 token i 和 sentence2 的 token j 的 token embedding 的余弦距离。要求解的最优传输矩阵T就是各个 token pair 对于句子相似性的贡献程度,也就是一个权重矩阵(这里有个限制,行向量和列向量求和都为 1)。
数学形式如下:
优化得到的最优解称为最优传输,总代价称为 earth mover's distance(EMD)或者 Wasserstein 距离:
这个方法主要优势在于分析句子相似性的时候可以显式地将各个 token pairs 的相似性融合进来。
我们将余弦相似度+平均池化的方法表示为一个最优传输问题,并根据得到的传输矩阵来分析该方法的缺陷。
传统的余弦相似度计算公式如下,S1 和 S2 分别由各自的 token embedding 平均池化得到
将余弦相似度转换为余弦距离(1-cos),带入平均池化公式,可以转化为关于 token embedding 的距离度量:
从上面等式(1)的角度来看,这个距离可以被解释为一个传输问题的朴素解,其中代价矩阵和传输矩阵如下:
我们看到,传输矩阵 T 由 token embedding 的 norms 决定。理论上,传输矩阵的秩被限制为 1,这阻碍了 token 距离与 sentence 距离的有效整合。这个分析中指出基于平均池化的相似度不足以有效地捕捉句子之间的 token 对应。
为了解决上述分析的问题,本文引入了一种基于最优传输的新颖的距离度量。首先定义一个考虑上下文嵌入空间中语义相关性的传输问题:
1. 给定来自预训练语言模型的两个句子的 token embedding,我们构建了一个代价矩阵 M(维度为 L1*L2,L1 是句子 1 的 token 长度,L2 是句子 2 的 token 长度),该矩阵用余弦距离编码 token 相似性。
2. 将两个句子的状态向量定义为由各自句子长度归一化的全 1 向量。
我们考虑将该最优传输问题的最优解作为距离度量,称为 contextualized token mover's distance(CMD)。
然而之前关于最优传输理论的研究表明,找到最优传输矩阵 T 的计算复杂度非常高,因此可以考虑将最优传输问题的两个约束条件放松。
假设我们将约束 1 移除,可以推导出如下的最优解:
这种方式得到的传输矩阵只有在语义对齐的 token pairs 的位置会有一个非零值,因此该矩阵的秩大于 1,这意味着它可以表达两个句子之间更复杂的标记级语义关系。(之前平均池化+余弦相似度的方法得到的传输矩阵在所有 token pairs 的位置都可能有非零值,理论上秩等于 1)
此外,这种方法可以显式找到语义对齐的 token pairs,增加了模型的可解释性。
作者进一步提出了一个基于 RCMD 这种新颖距离度量的对比学习框架(CLRCMD),它将 RCMD 距离度量整合到最先进的对比学习框架中。简单地说就是将 SimCSE 的余弦相似度用 RCMD 替换。
这里首先需要把距离转换为相似性分数,RCMD1 和 RCMD2 计算方式相同:
采用这种相似性度量,训练批次中第 i 个句子对的对比学习目标定义如下:
作者认为,这样一个对比学习框架有两个优势:
1. 缓解了微调过程中预训练语义的灾难性遗忘。CLRCMD 更新参数以提高句子相似性的质量,同时没有破坏预训练 checkpoints 中的 token-level 语义。
2. CLRCMD 直接将句子对的相关性提炼成语义对齐的标记对的相关性。从这个意义上说,我们的上下文嵌入空间有效地从训练句子对中捕获了标记级的语义相关性,从而为其句子相似性提供了更好的解释。
特别是对于 STS14、STS15、SICK-R 数据集,CLRCMD-BERTbase 实现了与 SimCSEcls-RoBERTabase 相当的性能,后者的骨干语言模型与 BERTbase 相比使用了 10 倍大的数据进行了预训练。这意味着使用 CLRCMD 的 token-level 监督进行微调可以达到与使用昂贵的预训练 checkpoints 相当的性能。
接下来,我们衡量我们的方法在可解释 STS(iSTS)任务上的性能,以验证 CLRCMD 嵌入了足够水平的可解释性,即使没有任何关于语义对齐块对的监督(即标记的训练数据)。
我们需要找到一种方法来衡量人类判断(句子之间的黄金语义对齐)和所有 token pairs 对句子相似度的贡献之间的一致性。其中一个挑战是将 token pairs 贡献矩阵/传输矩阵转换为 chunk-level 对齐【一个 chunk 包含多个连续 token】。在这里我们使用黄金标准 chunk 信息,也就是说不去识别哪些 token 可以合并为一个 chunk,只聚焦于 chunk 对齐的评估。
使用每种方法获得的对齐块对,我们计算对齐 F1 分数作为评估指标,它表示人类判断与块贡献之间的一致性,结果如下表所示:
1. CLRCMD 有效地突出了语义相关 token-pairs 的贡献并排除了其他贡献;
2. 相反,SimCSEavg 未能代表句子相似性的有意义的令牌级相关性,rank-1 限制阻碍了模型在两个句子之间获得任何合理的对齐,同时它只是一次调整所有可能的标记对的贡献。
本文研究的问题其实是所有句子级任务普遍存在的问题,忽略了 token-level 的可解释性。引入最优传输理论,将 token-level 信息融合到句子表示中,这个思路是比较具有创新性的。提出了一个新的对比学习框架,改变了以往清一色的基于余弦相似度的对比目标,改用一个新颖的 RCMD 作为距离度量,这个比较有意思。
更多阅读
#投 稿 通 道#
让你的文字被更多人看到
如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。
总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。
PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析、科研心得或竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。
📝 稿件基本要求:
• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注
• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题
• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算
📬 投稿通道:
• 投稿邮箱:hr@paperweekly.site
• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者
• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿
△长按添加PaperWeekly小编
🔍
现在,在「知乎」也能找到我们了
进入知乎首页搜索「PaperWeekly」
点击「关注」订阅我们的专栏吧