域泛化(Domain Generalization)中有很多工作是用 meta learning 做的。Meta learning 在 few shot 中很常用,它的目的也是提升模型的泛化性,所以我们来看看 DG 中采用 meta learning 的工作。
Revisit Meta Learning
Meta learning 的motivation 就是让模型学会学习。一个学会了如何学习的模型,自然就有好的泛化性。 以 few shot learning 背景为例,我们只有少量的样本来训练一个任务。直接用少量的样本训练模型显然会过拟合,那怎么办?Meta learning 给出的策略就是采用公用大型数据集和已有的少样本共同训练模型。它将数据集分成两类,大型数据集的样本称为 support sets,少样本称为 query sets;将训练分成两个阶段,一次学习称为一个 epoch(整个数据集),首先在 support sets 上训练并更新一次梯度,接着用 query sets 基于 support sets 更新的模型再求一次梯度,本轮 epoch 的梯度更新与 query sets 上梯度更新方向一致。 可以这么理解,support sets 的作用就是让模型有一个好的初始化,接着再用 query sets 对模型进行 fine-tune,使模型真正适用于任务场景。显然,大型数据集和拥有的少样本数据来自不同 domain,存在 distribution shift,大型数据集训练的模型在任务上只能得到次优的效果。而通过一次次 query sets 的"fine-tune",模型就能很好地适应任务场景。 这么一看,是不是跟 DG 要做的很像?所以,DG 也这么干了。但是 DG 的场景会更困难一些,因为 DG 在训练时根本不知道目标域数据,就没法用目标域数据作为 query sets。因此 DG 退而求其次的策略是将源域数据划分成 support sets 和 query sets(DG 的论文里一般称为 meta-training sets 和 meta-testing sets),核心依然是模拟 distribution shifts,训练出对 distribution shift robust 的模型,就认为模型拥有了泛化到目标域的能力。
Meta Learning与Domain Alignment对比
Domain Alignment 专注于特征的学习,学到 domain agnostic 的特征。因此它会通过 loss 或者是 domain 判别器等其他各种手段对提取的特征施加约束,认为成功实现分布对齐的模型就是泛化性好的模型。它只是简单通过不同源域的训练数据来模拟 distribution shift。 Meta learning 主要是对输入数据的设计,强调数据的 distribution shift,并通过两次梯度更新使模型 robust,认为学到 distribution shift 的模型就是泛化性好的模型。但没有对数据作显式对齐。 其实,meta learning 可以看做是一个训练 trick,它可以和所有 DG 方法结合使用。因为 meta learning 对模型结构,loss 都没有任何要求(也称为 model agnostic),只需要对训练数据和训练过程做简单的调整,就可以套在任何模型上了。因此,要是你发现自己的 DG 模型效果不够满意,可以考虑叠加这个 buff(感觉我在教坏人-_-
DG中的Meta Learning
下面就来看几篇 DG 中的论文,了解它们是怎么使用这个 trick 的。
3.1 Meta Learning实现DG
本文给出的方法很简单,但是它对 meta learning 的 insight 做了很好的解释。
论文标题:
Learning to Generalize: Meta-Learning for Domain Generalization
论文链接:
https://arxiv.org/abs/1710.03463
训练时共有 个源域,每次训练采用一个源域作为 meta-testing set,另外的源域作为 meta-training set,得到目标函数: 有意思的点是作者对上述目标函数做 Taylor 展开,得到了以下的形式:
这揭示了目标函数一是要最小化在 meta-training set 和 meta-testing set 上的误差(上式第一第二项),二是使 meta-training set 和 meta-testing set 的优化方向最大程度地相似(上式第三项)。显然,如果目标函数是 ,模型很可能偷懒,找一个容易使该式最小化的源域的梯度方向进行优化,从而过拟合这个源域。而 meta leanring 的目标函数函数加上了这个正则化约束,就促使模型考虑所有源域的梯度方向。因此作者还给出下面两种改进的 meta learning 目标函数,可以替代上式的点积计算相似度。
第一种改进是将点积替换成余弦相似度。第二种是退化为用 meta-training set 的方向优化 meat-testing set,这种方式关键是需要模型有好的初始化。 3.2 解决DG中的Batch Normalization问题
论文标题:
MetaNorm: Learning to Normalize Few-Shot Batches Across Domains
Recall that in this setting, we have access to target labeled data for only half of our categories. We use soft label information from the source domain to provide information about the held-out categories which lack labeled target examples.
一个好的特征空间自然是不同 domain 的数据尽量混在一起难以区分,不同 class 的数据尽量形成良好的聚簇。作者就此分别对语义空间和特征空间采用了不同的操作。 首先是语义空间。对于每个 domain,计算特征空间中属于同一 class 的样本的均值,作为这个 class 的 'concept',并通过 softmax 得到这个 class 的软标签。 接着聚合同一个 domain 的所有软标签向量,得到软标签混淆矩阵。我们希望训练过程中不同 domain 的 inter-class 关系能够被保持,因此操作还是进行 domain 的对齐,也就是最小化不同 domain 混淆矩阵的对称 KL 散度。 接着是特征空间对齐。同样是借鉴对比损失的思想,计算下面的 triplet loss,使 positive sample 与 anchor 的距离小于 negative sample 与 anchor 的距离。 本文的训练数据同样被分为 meta-training set 和 meta-testing set 来模拟 distribution shift。
总结 Meta learning 就是通过对已有的数据作简单的划分模拟 distribution shift,使模型学得更 robust。它是一种训练的思路,可以和任何 DG 的模型结构结合来增强泛化性。 但是 meta learning 同样存在一些缺陷。一是虽然可能训练得到的模型对 distribution shift 不那么敏感,但仍不能避免模型对源域数据过拟合。二是模型每一层更新都要求两次梯度,计算效率自然会慢。