题目: 图神经网络在分布外图上的泛化 论文链接: http://www.shichuan.org/doc/157.pdf 论文代码: https://github.com/googlebaba/StableGNN
近日,北京邮电大学GAMMA Lab与清华大学合作的论文“Generalizing Graph Neural Networks on Out-Of-Distribution Graph”被人工智能顶级期刊IEEE TPAMI (影响因子23.6)接收,图神经网络的分布外泛化能力决定了其在实际应用中的稳定性,是近年来的研究热点,该论文的初始版本于2021年11月放于arXiv(https://arxiv.org/abs/2111.10657),是早期将因果方法与图神经网络结合解决图分布外泛化问题的文章之一。本文将介绍图分布外泛化的关键问题、解决方法、以及未来研究工作。
目前提出的图神经网络(GNN)方法没有考虑训练图和测试图之间的不可知偏差,从而导致GNN在分布外(OOD)图上的泛化性能变差。导致GNN方法泛化性能下降的根本原因是这些方法都是基于IID假设。在此条件下,GNN模型倾向于利用图数据中的虚假相关进行预测。但是,这样的虚假相关可能在未知的测试环境中改变,从而导致GNN的性能下降。因此,消除虚假相关的影响对于实现稳定的GNN模型至关重要。为了实现此目的,在本文中,我们强调对于图级别任务虚假相关存在于子图级别单元,并且用因果视角来分析GNN模型性能下降的原因。基于因果视角的分析,我们提出了一个统一的因果表示框架用于稳定GNN模型,称之为StableGNN。这个框架的主要思想是首先利用一个可微分的图池化层提取图的高层语义特征,然后借助因果分析的区分能力来帮助模型摆脱虚假相关的影响。因此,GNN模型可以更加专注于有区分性的子结构和标签之间的真实相关性。我们在具有不同偏差程度的仿真数据和8个真实的OOD图数据上验证了我们方法的有效性。此外,可解释性实验也验证了StableGNN可以利用因果结构做预测。 本质上说,对于一般的机器学习方法,当遭受分布偏移问题时,准确率下降的根本原因是不相关特征和类别标签之间的虚假相关导致的。这个虚假相关根本上是由不相关特征和相关特征的意外的相关性导致的。而对于本文研究的图级别任务,由于图的性质通常由子图单元决定 (比如,在分子图中,原子和化学键团表示其功能单元),所以我们定义一个子图单元可以是一个对于标签相关的或者不相关的特征单元。如图1所示,以’‘房子’‘模体分类任务为例,其中图的标签表示一个图中是否有“房子”模体。GCN模型是在“房子”模体和“星星”模体高度相关的训练图上训练的。在这个数据上,“房子”模体和“星星”模体将会高度相关。这个意料之外的相关性将会导致“星星”模体的结构特征和“房子”标签的虚假相关。图1的第二列展示了用于GCN预测的最重要的子图可视化结果 (由GNNExplainer产生)。由结果可知,GNN倾向于利用星星模体做预测。然而当遭遇没有“星星”模体的图,或者其他模体(比如,"钻石"模体)和星星模体在一起时, GCN模型被证明容易产生错误的结果。
图1 "房子"模体分类例子 为了去除虚假相关对于GNN模型泛化性的影响,我们提出了一个新颖的用于图的因果表示框架,称之为StableGNN, 其结合了GNN模型灵活的表示学习和因果学习方法对于区分虚假相关能力的两方面优势。对于表示学习部分,我们提出了一个图高层语义学习模块,其利用了一个图池化层来映射相近的节点为簇,其中每一个簇为原始图中一个紧密连接的子图单元。此外,我们理论证明了不同图的簇的语义含义可以通过一个有序的拼接操作实现匹配。给定了匹配的高层语义变量,我们用因果视角分析GNN的性能退化并且提出了一个新颖的因果变量区分正则化项通过学习一套样本权重来去除每一个高维变量对之间的相关性。这两个模块在我们的模型中联合训练。此外,如图1所示,StableGNN可以有效的排除不相关子图的影响(“星星”模体)并且利用真实的相关子图("房子"模体)做预测。
所提出框架的基本想法是设计一个因果表示学习方法来抽取有意义的图高层语义变量然后估计他们对于图级别任务的真实因果效应。如图2所示,所提出的模型框架主要分为两个部分:图高层语义表示学习模块和因果变量区分模块。
图2 StableGNN的模型框架
高层变量池化 为了学习节点表示同时映射紧密连接的子图到几个簇中,我们采用DIffPool 层学习一个簇分配矩阵来实现这个目标。具体而言,由于GNN可以平滑节点表示同时使得表示更加具有区分性,给定输入邻接矩阵, 同时表示为, 和节点属性 ,我们首先使用一个embedding GNN模块来得到平滑的节点表示: 其中 是一个3层的GNN模块,GNN层可以是GCN,GraphSAGE或者其他的。然后我们构建了一个池化层基于这个平滑的特征表示。具体地,我们产生一层的节点表示如下: 由于在相同子图的节点可能会有相似的表示和邻居结构,GNN模型可以把具有相似特征和结构的节点映射到相同的表示上,因此我们把节点表示 和邻接矩阵 输入到一个池化GNN模块来产生第一层的的簇分配矩阵: 其中 和 表示第i个节点的簇分配向量, 同样是一个3层的GNN模块,其输出维度为预设的在层1的最大的簇个数,是一个参数,可以通过端到端的方法学习到。在得到分配矩阵 之后,我们可以知道每个节点在预设的簇上的分配概率。因此,新的节点的表示可以通过如下方法计算: 这个公式根据簇分配矩阵聚合了节点表示 ,产生了 个簇的表示特征。同样的,我们产生了一个粗化的邻接矩阵如下所示: 我们可以通过堆叠多个Diffpool层来实现层次化的聚类结构。 高层表示匹配 在堆叠L层图池化层之后,我们可以得到最高层的高层簇表示 。由于我们的目标是编码子图结构到图表示中并且 已经显示的编码每个紧密连接的子结构信息到 的行向量中,因此,我们提出使用 的表示矩阵来表示高层语义。然而,由于图数据的非欧特性,对于两个图 和 的第i个学习到的高层表示的 和 的语义含义可能会不同,比如, 可能代表苯环, 可能代表NO2团。为了匹配不同图之间学习到的高层语义表示我们提出通过高层矩阵的行索引来拼接高层变量: 并且我们堆叠在一个mini-batch中的m个图的高维表示得到embedding矩阵: 我们通过GNN层的置换不变性,证明了在经过DiffPool层之后,不同图之间对应的高层语义是匹配的:
到目前为止变量学习部分学习的变量可能是虚假相关,在本节中,我们首先分析以因果视角分析导致GNN模型性能下降的原因,然后提出一个因果变量区分正则化器(CVD)。 以因果视角重视GNN
我们的目标是学习到一个分类器 基于相关的特征Z。为了达到这个目的,我们需要区分学习到的表示 哪个是属于稳定特征Z哪个是属于不稳定特征M。Z和M的主要区别是对Y有没有因果效应。对于一个图级别的分类任务,在学习到节点表示之后,他将会被送到一个分类器里来预测他的标签。这个预测过程可以表示为图3(a),其中 T 是处理变量,Y是输出预测值,X是混淆变量。路径 表示GNN 的目标是估计一个学习到的表示变量T到Y 的因果效应。同时其他变量将会被视为混淆变量X。由于子图之间存在虚假相关,因此他们学习到的表示之间也存在虚假相关。因此,存在一个在X和T之间的路径。并且由于GNN同样也使用confounder做预测,所以存在一条路径 。因此,这两条路径形成一个X到T的后门路径(i.e., ),从而导致T和Y之间的虚假相关。这个虚假相关将会改变处理变量和标签的之间真实相关性,并且在测试的时候会改变。在这种情景下,目前的GNN方法不能准确的评估子图的因果效应,因此GNN的性能可能会衰减。混淆平衡技术通常被用来评估变量的因果效应,但是他们通常针对某一变量是由单个维度的特征组成的数据,我们要处理的数据是是多个高维变量组成的,因此,我们提出一种多变量多维度的混淆变量平衡技术,如图3b所示: 其中 是第k个混淆变量的embedding矩阵。
图 3 GNN的因果视角 重加权HSIC
但是上述的混淆变量平衡技术主要针对的是二元处理变量,我们需要处理的高维处理变量。基于混淆平衡技术主要目的是去除处理变量和混淆变量之间的关联,我以我们考虑采用HSIC来度量高维变量之间的关联,同时提出采用样本加权的方式去除高维变量之间的关联,方法如下:对于两个变量U和V,我们首先采用随机初始化的样本权重来重加权它们: 然后我们可以得到加权的HSIC: 对于去除所有变量之间的相关性,我们优化如下的全局高维变量去相关项: 2.3 加权的GNN模型
在传统的GNN模型中,对于每一个图数据的权重是相等的。因为学习到的样本权重 可以全局的去除高层变量之间的相关性,我们提出使用这个样本权重来重加权GNN的损失,并且迭代的优化样本权重 和加权GNN的参数如下所示:
我们分别在仿真数据上和真实数据上验证了我们的实验效果。
我们通过控制“房子”模体和“星星”模体的相关性程度,生成了{0.6,0.7,0.8,0.9}四种偏差程度不同的训练数据,更多生成数据的细节请参考原文。我们分别以GCN/GraphSAGE作为基模型实现了我们的模型,所以本节主要和相应的基模型进行了对比。实验结果如表1所示。首先,相较于基模型我们都取得了比较大的提升效果,证明了我们是个有效的框架。其次,在偏差程度越大的数据上提升效果越明显,证明了我们的方法可以有效对抗数据偏移产生的分布外效果下降的问题。最后,我们的模型相较于GCN/GraphSAGE都有明显的提升,证明了我们的方法是一个灵活的框架可以提升现有模型的效果。图4是一些可解释性的例子,也能很好的说明我们的模型可以利用因果结构进行预测。
图4 GCN和StableGCN的可解释性例子
我们在OGB的七个分子图性质预测的数据上展开实验,与常用的数据不同的是,这些数据都采用scaffold splitting 从而使得具有不同结构的图数据划分到训练集和测试集。此外,我们还采用了常用的MUTAG数据集用于解释我们的结果。表2是数据集的统计信息。表3是实验结果。从表中可以看出,我们的方法综合性能排在前两位,远远大于排名第3的方法,证明了现有GNN方法在OOD场景下的图预测任务上表现的都不好而我们的方法可以取得较好的结果。同时在不同类型数据,不同的任务的数据集上我们都取得了较好的效果,证明了我们的方法是一个通用的框架。
图4是MUTAG数据集上的可解释性实验。蓝色,绿色,红色和黄色分别代表N,H,O,C原子。由GNNExplainer产生的最重要的子图被标为黑色。StableGNN正确的确定了功能团NO2和NH2,这些功能团被认为是对Mutagenic 有决定性作用的,而其他方法不能找到有解释性的子图做预测。
图 4 MUTAG数据集上的可解释性实验
在本文中,我们首次研究了图数据在OOD上的泛化问题。我们以因果视角分析了这个问题,认为子图之间的虚假相关会影响模型的泛化性。为了提高现有模型的稳定性,我们提出一个一般化的因果表示学习框架,称之为StableGNN,其有效的结合图高层表示学习和因果效果评估到一个统一的框架里。丰富的实验很好的验证了StableGNN的有效性,灵活性,和可解释性。 此外,我们认为本文开启了一个在图数据上进行因果表示学习的方向。本文的最重要的贡献是提出了一个通用的因果表示框架:图高层变量表示学习和因果变量区分,这两个部分都可以为任务而特殊的设计。比如,多通道的滤波器可以被用来学习图上的不同的信号到子空间里。然后对于一些数据也许在高层变量之间存在这更复杂的因果结构,因此发现这些因果结构对于重构原始数据生成过程将会更有效。
[1] Shaohua Fan, Xiao Wang, Chuan Shi, Peng Cui, Bai Wang. Generalizing Graph Neural Networks on Out-Of-Distribution Graphs. IEEE TPAMI 2023 [2]R. Ying, D. Bourgeois, J. You, M. Zitnik, and J. Leskovec, “Gnnex-plainer: Generating explanations for graph neural networks,” NeurIPS, 2019. [3] X. Zhang, P. Cui, R. Xu, L. Zhou, Y. He, and Z. Shen, “Deep stablelearning for out-of-distribution generalization,” CVPR, 2021, pp.5372–5382 [4]R. Ying, J. You, C. Morris, X. Ren, W. L. Hamilton, and J. Leskovec,“Hierarchical graph representation learning with differentiablepooling,” NeurIPS, 2018. [5] B. Schölkopf, F. Locatello, S. Bauer, N. R. Ke, N. Kalchbrenner,A. Goyal, and Y. Bengio, “Toward causal representation learning,”Proceedings of the IEEE, vol. 109, no. 5, pp. 612–634, 2021.