导读:在 NeurIPS 2020 上,清华大学联合微众银行、微软研究院以及博世人工智能中心提出了 Graph Random Neural Network (GRAND),一种用于图半监督学习的新型图神经网络框架。在模型架构上,GRAND 提出了一种简单有效的图数据增强方法 Random Propagation,用来增强模型鲁棒性及减轻过平滑。基于 Random Propagation,GRAND 在优化过程中使用一致性正则(Consistency Regularization)来增强模型的泛化性,即除了优化标签节点的 cross-entropy loss 之外,还会优化模型在无标签节点的多次数据增强的预测一致性。GRAND 不仅在理论上有良好的解释,还在三个公开数据集上超越了 14 种不同的 GNN 模型,取得了 SOTA 的效果。
这项研究被收入为 NeurIPS 2020 的 Oral paper (105/9454)。
论文名称:GraphRandom Neural Network for Semi-Supervised Learning on Graphs
ArXiv: https://arxiv.org/abs/2005.11079
Github: https://github.com/THUDM/GRAND
图是用于建模结构化和关系数据的一种通用的数据结构。在这项工作中,我们重点研究基于图的半监督学习问题,这个问题的输入是一个节点带属性的无向图,其中只有一小部分节点有标签,我们的目的是要根据节点属性,图的结构去预测无标签节点的标签。近几年来,解决这个问题一类有效的方法是以图卷积神经网络(GCN)[1] 为代表的图神经网络模型(GNN)。其主要思想是通过一个确定性的特征传播来聚合邻居节点的信息,以此来达到对特征降噪的目的。
但是,最近的研究表明,这种传播过程会带来一些固有的问题,例如:
1) 过平滑,图卷积可以看做是一种特殊形式的拉普拉斯平滑,叠加多层之后节点之间的 feature 就会变得不可区分。
2)欠鲁棒,GNN 中的特征传播会使得节点的预测严重依赖于特定的邻居节点,这样的模型对噪音的容忍度会很差,例如 KDD’18 的 best paper [2] 就表明我们甚至可以通过间接攻击的方式通过改变目标节点邻居的属性来达到攻击目标节点的目的。
3)过拟合,在半监督节点分类的任务中,有标签的节点很少,而一般 GNN 仅仅依靠这些少量的监督信息做训练,这样训练出来的模型泛化能力会比较差。
为了解决这些问题,在这个工作中我们提出了图随机神经网络(GRAND),一种简单有效的图半监督学习方法。与传统 GNN 不同,GRAND 采用随机传播 (Random Propagation)策略。具体来说,我们首先随机丢弃一些节点的属性对节点特征做一个随机扰动,然后对扰动后的节点特征做一个高阶传播。这样一来,每个节点的特征就会随机地与其高阶邻居的特征进交互,这种策略会降低节点对某些特定节点的依赖,提升模型的鲁棒性。
除此之外,在同质图中,相邻的节点往往具有相似的特征及标签,这样节点丢弃的信息就可以被其邻居的信息补偿过来。因此这样形成的节点特征就可以看成是一种针对图数据的数据增强方法。基于这种传播方法,我们进而设计了基于一致性正则(consistency regularization)的训练方法,即每次训练时进行多次 Random Propagation 生成多个不同的节点增强表示,然后将这些增强表示输入到一个 MLP 中,除了优化交叉熵损失之外,我们还会去优化 MLP 模型对多个数据增强产生预测结果的一致性。这种一致性正则损失无需标签,可以使模型利用充足的无标签数据,以弥补半监督任务中监督信息少的不足,提升模型的泛化能力,减小过拟合的风险。
我们对 GRAND 进行了理论分析,分析结果表明,这种 Random propagation + Consistency Regularization 的训练方式实际上是在优化模型对节点与其邻居节点预测置信度之间的一致性。
我们在 GNN 基准数据集中的实验结果对 GRAND 进行了评测,实验结果显示GRAND 在 3 个公开数据集中显著超越了 14 种不同种类的 GNN 模型,取得了 SOTA 的效果。实验结果(图三):
图三
此外我们还对模型泛化性,鲁棒性,过平滑等问题进行了分析,实验结果显示 1)Consistency Regularization 和 Random Propagation 均能提升模型的泛化能力(图四);2)GRAND 具有更好的对抗鲁棒性(图五);3)GRAND 可以减轻过平滑问题(图六)。
图四
图五
图六
参考文献:
[1] Kipf T N, Welling M. Semi-supervised classification with graph convolutional networks[J]. arXiv preprint arXiv:1609.02907, 2016.
[2] Zügner D, Akbarnejad A, Günnemann S. Adversarial attacks on neural networks for graph data[C]//Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. 2018: 2847-2856.
喜欢本篇内容,请分享、点赞、在看