Soft prompt learning has recently emerged as one of the methods of choice for adapting V&L models to a downstream task using a few training examples. However, current methods significantly overfit the training data, suffering from large accuracy degradation when tested on unseen classes from the same domain. To this end, in this paper, we make the following 4 contributions: (1) To alleviate base class overfitting, we propose a novel Language-Aware Soft Prompting (LASP) learning method by means of a text-to-text cross-entropy loss that maximizes the probability of the learned prompts to be correctly classified with respect to pre-defined hand-crafted textual prompts. (2) To increase the representation capacity of the prompts, we propose grouped LASP where each group of prompts is optimized with respect to a separate subset of textual prompts. (3) We identify a visual-language misalignment introduced by prompt learning and LASP, and more importantly, propose a re-calibration mechanism to address it. (4) We show that LASP is inherently amenable to including, during training, virtual classes, i.e. class names for which no visual samples are available, further increasing the robustness of the learned prompts. Through evaluations on 11 datasets, we show that our approach (a) significantly outperforms all prior works on soft prompting, and (b) matches and surpasses, for the first time, the accuracy on novel classes obtained by hand-crafted prompts and CLIP for 8 out of 11 test datasets. Code will be made available at https://www.adrianbulat.com/lasp
翻译:软提示学习近来成为了V&L模型适应下游任务的选择之一,仅使用少量的训练样本即可完成。然而,当前的方法在测试相同领域中的未知类别样本时,存在严重的过度拟合问题,导致准确率显著降低。
为此,本文提出了以下 4 点贡献:(1) 为了减轻基类别过拟合问题,我们采用了一种新颖的语言感知软提示(Language-Aware Soft Prompting,LASP)学习方法,采用基于文本交叉熵损失函数的方法将学习到的提示与预定义的手工提示进行分类并优化。(2) 为了提高提示的表示能力,我们提出了一种分组LASP方法,每个提示组针对不同的文本提示子集进行优化。(3) 我们确定了提示学习和LASP引入的视觉-语言不对准问题,并采用重新校准机制进行解决。(4) 我们表明LASP本质上适用于包括虚拟类别(即没有可用视觉样本的类别名称)的训练,进一步提高了学习提示的鲁棒性。
通过对 11 个数据集的评估,我们证明了我们的方法(a) 在软提示上显著优于所有之前的工作,以及 (b) 在 8 个测试数据集中,与手工提示和 CLIP 制作出来的准确率相匹配或超过。我们将此方法的代码放在 https://www.adrianbulat.com/lasp 上以供使用。