图神经网络的局域数据增强方法
Local Augmentation for Graph Neural Networks
本文由腾讯 AI Lab 主导,与斯坦福大学,香港科技大学,宾夕法尼亚州立大学合作完成,提出了一种全新的基于条件生成的图神经网络的数据增强方法,可以作为即插即用的模块嵌入到任意图神经网络的建模流程中,从而显著提高模型性能;适用于药物发现、电商推荐、社交网络等广泛的应用场景。
在图结构的数据及任务上,图神经网络(GNN)已取得了引人注目的性能。GNN的关键设计思路在于通过将每个节点的邻域信息进行聚合,来得到对该节点信息量更为丰富的表征。然而,对于仅有少量邻居的节点,如何将其邻域信息进行有效聚合从而得到最优的表征,目前尚未有定论。
针对该问题,本文提出了一种简单而有效的数据增强方法,局域数据增强,即通过学习邻域节点关于中心节点表征的条件概率分布,生成更丰富的特征,来增强GNN的表达能力。局域数据增强是一个具有广泛适用性的框架,可以被即插即用地嵌入到任意的GNN模型中。本方法从学习到的条件概率中采样得到额外的关于每个节点的特征向量,并作为扩充后的数据用于模型的训练。
通过大量实验和分析,我们证明了本方法可以在多种图结构数据和不同图神经网络上带来一致性的效果提升。举例来说,在Cora,Citeseer和Pubmed数据集上,加入了局域数据增强的图卷积神经网络(GCN)和图注意力网络(GAT)在测试时的平均准确率可以分别提升3.4%及1.6%。此外,在大型图数据集OGB上的实验也证明了,我们的方法相比其他在图的特征、结构层面进行数据增强的方法,在图节点分类任务上具有更优的效果。