从ICML 2022看域泛化(Domain Generalization)最新进展

2022 年 8 月 13 日 PaperWeekly


©PaperWeekly 原创 · 作者 | 张一帆

单位 | 中科院自动化所博士生

研究方向 | 计算机视觉


Domain Generalization(DG:域泛化)一直以来都是各大顶会的热门研究方向。DA 假设我们有多个个带标签的训练集(源域),这时候我们想让模型在另一个数据集上同样表现很好(目标域),但是在训练过程中根本不知道目标域是什么,这个时候如何提升模型泛化性呢?核心在于如何利用多个源域带来的丰富信息。ICML 2022 域泛化相关的文章来研究最新的进展。





DNA


论文标题:
DNA: Domain Generalization with Diversified Neural Averaging

论文链接:

https://proceedings.mlr.press/v162/chu22a.html


目前绝大多数 DG 方法都基于一个不切实际的假设,即训练时的 hypothesis space 包含一个最优分类器,因此源域与目标域的的联合损失可以降到最小。这个假设对神经网络来说是非常难以满足的。当对源数据进行训练时,分类器倾向于只记住所见训练数据的鉴别特征,而忘记任何其他信息,包括那些可能是 target domain 分类所需要的的信息。训练过程中目标数据的不可达性意味着深度分类器的假设空间倾向于支持低源域风险的子空间,而不一定支持低目标域风险的子空间。总之,理想分类器可能脱离训练阶段假设空间。

解决这个问题的一个方法是 classifier ensemble,即对分类器进行集成。本文从理论和实验角度讨论了 ensemble 与 DG 任务的 connection。

理论上,本文首先引入一个剪枝的 jensen - shannon(PJS)损失,证明了 ρ 集合(由 quasi-posteriorρ 加权的平均分类器)在目标域上的 PJS 损失受 Gibbs 分类器的在源域的平均平方根风险的限制,即前者被后者 bound。通过对分类器集合的多样性进行约束,得到了一个更紧密的 DG bound。根据这个 bound,本文提出了 diversified neural averaging(DNA)method。

接下来简单介绍该方法,首先本文引入的 PJS 损失如下:



这个损失是基于 PJS divergence 提出的。



PJS 继承了 JS divergence 的优秀性质,比如三角不等式,除此之外,其次,PJS 的平方根对分类器而言具有凸性,这是原 JS 散度不具备的特性。本文提出的 bound 形式如下:



其中第一项源域与目标域  的 JS 散度是无法测量的,视作常数,最后一项也视作常数,最后模型的算法如下所示。与 ERM 的不同之处在于,这里会随机 sample 个分类器然后分别计算 PJS 损失,最后减去各个分类器 PJS 损失方差的平均作为正则约束。最后用作 inference 的模型会在每个训练阶段进行更新。



最终的实验结果如下所示,classifier ensemble 能带来相当不错的提升。





MAPLE

论文标题:
Model Agnostic Sample Reweighting for Out-of-Distribution Learning

论文链接:

https://proceedings.mlr.press/v162/zhou22d.html


本文的关键思想是找到一种有效的训练样本加权方式,以便在加权训练数据上对大型模型进行标准的经验风险最小化训练,从而获得更好的 OOD 泛化性能。

为了防止模型依赖于 spurious correlation,对其进行正则化是最常用的方法,常见的策略包括 distributionallyrobust optimization(DRO)以及 IRM。DRO 的目标是在与原始训练分布一定距离内的一组分布中优化最差情况的性能,而 irm 则试图学习一种丢弃虚假特征的不变表示。

DRO 和 IRM 由于在小模型和数据集上具有良好的性能。但是应用于过参数化的深度神经网络中却不太理想,其中主要的原因是过参数化的 DNN 可以很容易地将 DRO 或 IRM 的正则化项降至零,同时仍然依赖于伪特征,即所谓的 over-fitting。

另一条研究路线是基于包括重要性抽样在内的重新加权,即首先对样本进行重新加权,然后再加权样本上根据 ERM 进行训练。因为加权的过程与 DNN 的模型大小无关,因此这类方法不像 DRO 和 IRM 那样存在模型过参数化导致过拟合的漏洞。然而,这些基于重加权的方法中的需要更严格的先验知识,比如域注释,才能很好地执行,这使得它们在实践中与基于正则化的方法相比缺乏竞争力。

本文想要将这两种方法进行结合,同时利用他们的优势:


具体的做法如上所示,给定源域和目标域,从源域分离出一个验证集 。内层优化就是在加权数据集上进行 ERM 的训练,然后外层优化去 优化权重优化权重使得 OOD risk 变得最小 ,这里的 可以是任意的 OOD risk。具体实现如下所示:



该算法在不需要 domain label 的情况下取得了 SOTA 的效果。





SparseIRM


论文标题:

Sparse Invariant Risk Minimization

论文链接:

https://proceedings.mlr.press/v162/zhou22e.html


IRM 是这两年流行起来的一种 OOD 问题的新范式,IRM 的关键思想是学习从多个环境中提取的数据集上的不变特征表示,基于这种表示,人们应该能够学习在所有这些环境中工作良好的通用分类器。由于模型在这些现有环境中取得了一致的良好性能,可以预期在具有看不见的分布转移的新环境中也具有良好的泛化能力。然而 IRM 大多数时候只在小数据集或者小模型上有用,对于过参数化的神经网络而言往往表现不佳。

本文从理论上证明,当过参数化时,与 ERM 可以有良好或更好的泛化性不同, IRM 甚至在简单的线性情况下也可能失败。可以预见,在过度参数化的深度神经网络中,IRM 很容易崩溃,因为参数比简单的线性模型多得多。

本文提出了一个简单而有效的稀疏不变风险最小化(SparseIRM)范式来解决上述矛盾。其中关键思想是利用全局稀疏性约来防止伪特征(spurious correlation)在整个 IRM 过程中泄漏到我们所研究的子模型中。该范式成功地在整个训练过程中通过稀疏约束对伪特征和随机特征设置了障碍,从而获得了更好的泛化性能。

具体来说,在训练过程中,由于稀疏性约束导致所使用的子网络很小,不能包含所有的虚假和随机特征,因为这些特征的数量总是明显大于不变性特征。因此,网络需要识别和关注不变特征,使损失函数最小化。文章通过一个简单的线性情况的理论分析提供了对这一现象的理解。

首先我们简单对 IRM 和神经网络的稀疏性做介绍。
3.1 IRM
假设我们有训练数据集 ,encoder 和 classifier 分别计作 ,训练参数整体计作 ,IRM 通过如下的目标对参数进行优化。



第一项是各个训练环境的 empirical loss 的和。第二项是正则化项,促使分类器 在各个环境中都是最优的。目前比较有代表性的是如下两种约束



第一个是对各个环境中针对分类器的损失的 2 模进行约束,第二项是对各个环境损失的方差进行约束。
3.2 模型稀疏性
近年来,在神经网络中引入了稀疏性,以提高 推理效率 减少模型大小 。关键思想是通过开发一些适当的剪枝规则,在训练期间或之后从神经网络中识别并移除不重要的权重。最典型的规则是基于权重大小。

实证结果表明,该方法可以以较小的速度减小模型的大小,显著提高推理效率甚至性能上的损失可以忽略不计。这使得在计算和内存预算有限的设备上部署现代 dnn 成为可能。现有的大多数方法都是针对在 I.I.D. 场景下由 ERM 训练的神经网络开发的。这项工作将稀疏性引入到 IRM 训练中,以提高泛化性能。。

3.3 IRM遇见DNN
虽然 IRM 在小数据集上效果很好,并且有良好的理论性质保证,但是当模型的参数量增大导致过拟合时,IRM 的效果会下降非常多,如下图所示。


3.4 SparseIRM Framework
下图是本文方法的一个简单说明,SparseIRM(上)和 sparsify-after-training MRM(下)的流程图。填充灰色的块表示选中的特征,未填充的块表示未选中的特征。



该方法核心思想是在整个训练过程中采用稀疏性约束作为防御,以防止虚假和随机特征泄露到我们所研究的子网中。相比于传统的先训练模型再做稀疏化,本文训练时同时进行不变风险最小化和稀疏训练。直观上,在训练过程中,由于稀疏性约束,我们所工作的子网太小,无法包含所有的虚假和随机特征,因为这些特征的数量总是明显大于不变特征。因此,为了实现更小的损失,网络必须识别和关注不变特征。本文采用了最新的稀疏训练方法来解决稀疏不变风险最小化问题。

将上述目标公式化为如下形式:



即给参数 增加一个 mask ,然后使用超参数 来控制模型的大小。因为这里 是离散的,所以实现上需要做如下的改变:



即将 mask 建模为一个高维的伯努利分布然后使用带 Gumbel-Softmax 的 SGD 对其进行优化。

实验效果上来看该方法的表现非常不错,在 Coloredmnist 相关的数据集上比 IRM 强出了一大截,同时由于对 overfitting 具有不错的抗性,该方法也更有希望在更大更复杂的数据集上表现的更好。





SDAT


论文标题:
A Closer Look at Smoothness in Domain Adversarial Training

论文链接:

https://arxiv.org/abs/2206.08213

Domain adversarial training,即对神经网络进行对抗学习,使得他对 domain 具有不变性。DANN 这篇文章目前被引量已经超过 4000,该技术被广泛应用于目标识别,目标分类,fairness,域泛化,域自适应和图像翻译等任务。传统的 DAT 如下所示,它使用一个额外的域分类器来分类源域和目标域,然后反传梯度的时候取反,依此达到 feature extractor 对域不变的要求。



训练过程由两个 loss 来 lead,一个是对抗损失,即图中的 adversarial loss,类似于 GAN 的 loss,一个是任务损失,即传统的分类或者回归损失。目前对 DAT 优化的性质进行明确的分析和改进的研究还很少。在优化相关文献中,其中一个方向集中于开发收敛到平滑(或平坦)最小值的算法,但是,本文发现这些技术直接应用于 DAT 时,并没有显著改善目标域上的泛化效果。
4.1 问题定义
本文数学上的 formulation 基于域泛化,其所使用的 loss 为传统的 DAT 的损失:
其中 是域分类损失,整体的问题形式是 min-max 的。
4.2 问题分析
本文的分析从 empirical loss 的海森矩阵入手,即  也就是分类/回归损失的海森矩阵。由于神经网络运行在过度参数化的模型中,在训练数据损失低并不意味着泛化性能好。有很多工作显示了在最小值附近的平滑性很大程度上影响了模型的泛化,本文也是从平滑性入手进行分析。为了衡量海森矩阵的平滑性,本文选择了矩阵的迹  以及矩阵最大的特征值  作为指标,通过实验上的简单验证, 本文发现,越低的 会使得 DAT  训练更加稳定,并且在目标域上效果更好


上图展示了  的特征光谱密度图,从左到右依次是使用了 Adam 的 DAT,使用了 SGD 的 DAT,本文提出的 SDAT,以及训练过程中特征值的变化。这里主要的发现即分类损失越平滑,模型在目标域上的表现更好。那么如何才能使得分类器的损失变得更为平滑呢?这也是本文的核心贡献。
4.3 方法

目前已有工作关注于寻找一个局部平滑的最小值点,Sharpness  Aware  Minimization(Sharpness-aware minimization for efficiently  improv-ing generalization)方法使用如下损失来完成这个目标:


他的核心思想是寻找在  的  邻域上都取得低损失的参数。准确的找到内层解是非常困难的,因此 SAM 的优化目标是如下的一阶近似解:


类似的,本文对分类器做了平滑性的约束,具体使用的损失函数如下所示:


这个损失函数存在一个问题,即为什么我们只对分类器做平滑性的约束,而不对域分类器做约束。作者在文中理论性和实验性的验证了,对域分类器做平滑性的约束,可以让域分类损失更小,但是在目标域上的表现却会更差,即对泛化性并没有好处。上述公式可以很方便的集成在目前已有的 DAT 工作中,伪代码如下所示:

4.4 实验
作者在 Office-Home,DomainNet,VisDA-2017 等常用的 DA 数据集,以及目标检测的数据集上都进行了实验,实验结果显示,对于 DAT 系列的方法,在各种不同的 backbone 上(ViT,resnet 等),SDAT 总能带来稳定的提升。


美中不足的是,作者的 baseline 中并没有目前 SOTA 的其他方法,让人比较疑惑 DANN+SDAT 与目前最优方法之间的差距。




总结

近期各个顶会上都涌现出了非常多的 OOD,DG 问题相关的文章,由于 ICML 对理论的喜爱,这些发表于 ICML 的文章绝大多数有着不错的理论性质,而且并不是一些老套的 A+B 的工作,都有不错的 intuition。

本文介绍的第一篇工作 DNA 从理论和实现上将 classifier ensembling 引入 OOD 问题并给出了一个新的 target error upper bound,第二篇文章尝试着将传统的 sample reweighting 方法和 DRO 这类 bi-level optimization 方法的优势进行结合,使用二层优化的方法寻找更好的 reweighting 权重,第三篇文章针对 IRM 最被人诟病的点,即大数据集不 work 来进行,将稀疏性这一点引入 OOD。最后一篇将平滑性约束引入 DAL 从而提升对抗学习在迁移过程中的效果。

四篇文章用了四种不同的工具,总的来看大家都在找不同的切入点来解决 OOD 问题,实际上目前大多数研究可能都与 OOD 有着千丝万缕的联系,还有更多的研究空间等待探索。


更多阅读



#投 稿 通 道#

 让你的文字被更多人看到 



如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。


总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 


PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。


📝 稿件基本要求:

• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注 

• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题

• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算


📬 投稿通道:

• 投稿邮箱:hr@paperweekly.site 

• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者

• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿


△长按添加PaperWeekly小编




🔍


现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧


·

登录查看更多
0

相关内容

分类是数据挖掘的一种非常重要的方法。分类的概念是在已有数据的基础上学会一个分类函数或构造出一个分类模型(即我们通常所说的分类器(Classifier))。该函数或模型能够把数据库中的数据纪录映射到给定类别中的某一个,从而可以应用于数据预测。总之,分类器是数据挖掘中对样本进行分类的方法的统称,包含决策树、逻辑回归、朴素贝叶斯、神经网络等算法。
领域自适应研究综述
专知会员服务
54+阅读 · 2021年5月5日
【CVPR2021】DAML:针对开放领域泛化的领域增广元学习方法
专知会员服务
41+阅读 · 2020年12月1日
【ICML 2020 】小样本学习即领域迁移
专知会员服务
77+阅读 · 2020年6月26日
ICML'21:剑指开放世界的推荐系统
图与推荐
2+阅读 · 2021年12月30日
从ICCV 2021看域泛化与域自适应最新研究进展
PaperWeekly
0+阅读 · 2021年10月28日
国家自然科学基金
5+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2014年12月31日
国家自然科学基金
1+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2011年12月31日
国家自然科学基金
0+阅读 · 2011年12月31日
国家自然科学基金
2+阅读 · 2008年12月31日
Arxiv
13+阅读 · 2022年1月20日
Arxiv
13+阅读 · 2021年7月20日
Arxiv
16+阅读 · 2021年7月18日
Arxiv
12+阅读 · 2021年6月29日
Arxiv
13+阅读 · 2021年3月29日
AdarGCN: Adaptive Aggregation GCN for Few-Shot Learning
VIP会员
相关基金
国家自然科学基金
5+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2014年12月31日
国家自然科学基金
1+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2012年12月31日
国家自然科学基金
0+阅读 · 2011年12月31日
国家自然科学基金
0+阅读 · 2011年12月31日
国家自然科学基金
2+阅读 · 2008年12月31日
相关论文
Arxiv
13+阅读 · 2022年1月20日
Arxiv
13+阅读 · 2021年7月20日
Arxiv
16+阅读 · 2021年7月18日
Arxiv
12+阅读 · 2021年6月29日
Arxiv
13+阅读 · 2021年3月29日
AdarGCN: Adaptive Aggregation GCN for Few-Shot Learning
Top
微信扫码咨询专知VIP会员