论文标题:Graph-Less Neural Networks: Teaching Old MLPs New Tricks Via Distillation
作者:Shichang Zhang, Yozen Liu, Yizhou Sun, Neil Shah
单位:University of California & Snap Inc.
论文链接:https://arxiv.org/abs/2110.08727
代码链接:https://github.com/snap-research/graphless-neural-networks
作者:知乎@张三岁 (转载已获授权)
原文链接:https://zhuanlan.zhihu.com/p/521909814
今天给大家带来的是ICLR 2022的文章Graph-Less Neural Networks: Teaching Old MLPs New Tricks Via Distillation。文章提出了一个非常落地的科研问题,即如何解决GNN的推理时间过长无法在现实世界应用上使用。它选择使用了知识蒸馏(knowledge distillation)的方法来解决这个问题,即把GNNs的知识转换到MLPs上,使得MLP可以达到与GNN相匹敌的性能。这篇文章的模型简单且有效,通过大量的实验从各个角度验证所提模型的有效性。
图神经网络(Graph Neural Networks GNNs)最近在图机器学习(Graph Machine Learning GML)研究中很火,并在节点分类(node classification)任务上表现很好。但是,对于大规模工业界的应用来说,主流的模型仍然是多层感知机(Multilayer Perceptron MLP)(PS: 一般而言,MLP的表现要比GNN差很多)。造成这种学术届和工业届差距的原因之一是GNNs中的图依赖性(graph dependency)。这使得GNNs难以部署在需要快速推理(fast inference)或者对延迟(latency)敏感的应用中。
由图依赖性(graph dependency)引起的邻居节点获取(neighborhood fetching)是导致GNN延迟性高的主要来源之一。对目标节点的推理/预测(inference)需要获取许多邻居节点的拓扑结构(topology)和特征(feature/attribute),所有目前有工作关注在如何做GNN的推理加速(inference acceleration)。常见的推理加速模型有剪枝(pruning)和量化(quantization)。它们是通过减少乘加运算(Multiplication-and-ACcumulation MAC)来在一定程度上加快GNN的推理速度。然而,由于来自GNN本质的图依赖性没有得到解决,剪枝(pruning)和量化(quantization)对于提升GNN的推理速度是有限的(limited)。
与GNNs不同,MLPs因为输入仅是节点的特征(node attribute),所以它对图数据没有图依赖性(graph dependency),且比GNNs更容易部署。同时,MLPs还有另一个好处,即避开了图数据在线预测过程中经常发生的冷启动问题。这意味着,即使新节点的邻居信息不能立即获得,MLP也能合理地推断出新节点的表征。然而,也因为MLP不依赖于邻居信息,导致MLP的表现通常比GNN差。这里我们提出一个问题:我们能否在GNN和MLP之间架起一座桥梁,既能享受MLP的低延迟和无图依赖性,又能达到和GNN表现得一样好?即,这篇文章想要构建一个模型,这个模型既有MLP的低延迟和无图依赖性的优点,又可以达到和GNN相同的表现/预测准确率,从而解决GNN延迟性高的问题。具体的分析如下:
由于图依赖性(graph dependency),GNNs有很大的推理延迟(inference latency)。GNN的层数多一层,就意味着要多获取一跳(1-hop)的邻居信息。在一个平均度数为R的图上(平均度数R意味着一个节点平均有R个邻居节点),使用L层的GNN来推断一个节点需要获取 的邻居信息。由于R在现实世界的图中通常很大,比如在Twitter数据集中R为208,且邻居信息必须逐层获取,所以GNN的推理时间(inference time)会随着L的增长迅速上升。图1展示了GNNs随着层数的增加所需的邻居节点个数和推理时间呈指数级上升。反观MLP,在相同的层数下,MLP的推理时间会少很多,且只是线性增长。这一推理时间的差距可以解释MLP比GNN在业界用的更广泛的原因。进一步,我们发现有两个因素加剧了邻居节点获取(node-fetching)的延迟:1. GNN的架构有越来越深的趋势,从64层到甚至1001层。2. 工业场景下的图很大,无法装入单台机器的内存,所以邻居节点获取需要从内存和硬盘的交互,这进一步导致了延迟。
另一方面,MLPs无法利用图的拓扑结构(graph topology),这降低了MLPs在节点分类(node classification)任务上的性能。例如,GraphSAGE在ogbn-Products数据集上的准确率为78.61%,而在相同层数的MLP上只有62.47%。然而,最近在CV和NLP领域的研究表明,大型的或是轻微修改过的MLPs可以达到与CNN和Transformer差不多的表现。
所以,这篇文章想要将GNN和MLP的优点结合起来,以获得高准确率且低延迟的模型。
为了结合GNN和MLP的优点来搭建一个高准确率且低延迟的模型, 这篇文章提出了一个模型叫做 Graph-less Neural Network (GLNN)。具体来说, GLNN是一种涉及从教师GNN到学生MLP的知识 蒸馏(knowledge distillation)模型。经过训练后的学生MLP即为最终的GLNN, 所以GLNN在训练 中享有图拓扑结构 (graph topology)的好处, 但在推理中没有图依赖性(graph dependency)。知识蒸馏(knowledge distillation KD)是在Hinton等人在2015年提出的一个范式, 可以将知识从一 个繁琐复杂的教师转移到一个更简单的学生。这篇文章就想要通过知识蒸馏(knowledge distillation)从复杂的教师GNN中训练一个优质的MLP。具体来说, 我们用教师GNN为每个节点 生成它的软目标(soft targets) , 然后我们用真实标签(true label/ground truth) 和教师GNN 生成的软目标 来训练学生MLP。GLNN的目标函数如公式1所示, 其中 是权重参数, 是真实标签 和学生MLP预测出的标签 的交叉熵(cross-entropy)函数, 是KL散度 (KL-divergence)。
经过KD训练完后的学生模型,即GLNN,本质上是一个MLP。因此,GLNN在推理过程中没有图依赖性(graph dependency),并且和MLP一样快。另一方面,通过KD,GLNN的参数会被优化到和GNN具体同样的预测和泛化的能力,并有更快的推理和更容易部署的额外好处。图2展示了GLNNs的框架图。
由于GLNN的模型非常简单,所以实验部分是这篇文章的重点,从各个角度证明并阐述了GLNN的有效性且有效的原因。实验部分主要是回答如下6个问题
我们首先将GLNN与具有相同层数和隐藏维度(hidden dimension)的MLPs和GNNs,在标准的直推设置(transductive setting)下进行比较。
实验结果如表1所示,所有GLNN的性能都比MLPs有很大的提高。在较小的数据集上(前5行),GLNNs甚至可以超过教师GNNs。 换句话说,在相同的条件下,存在一组MLP参数,其性能可以与GNN相匹敌。对于大规模(large-scale)OGB数据集(最后2行),GLNN的性能比MLP有所提高,但仍然比教师GNN差。然而这种GLNN和教师GNN在大规模数据集上的这种差距可以通过增加MLP的层数来缓解,如表2所示。由图3(右图)所示,一方面逐渐增加GLNN的层数可以使其性能接近于SAGE。另一方面,当SAGE的层数减少时,准确率会迅速下降到比GLNNs更差。
直推设置(transductive setting) 是对节点分类任务常见的设置,但在这个设置下模型只能对见过的节点进行进行预测(即会使用完整的邻接矩阵adjacency matrix)。为了更好地评估GLNN的表现,这篇文章还考虑了在归纳设置(inductive setting)下进行实验(即训练时用的邻接矩阵不是完整的,而是剔除了测试集的节点及相应的边)。具体的实验设置请参考原文。实验结果如表3所示。
在表3中,我们可以看到GLNN在归纳设置(inductive setting)下的性能仍能比MLP提升许多。在6/7个数据集上,GLNN的性能可以和GNNs的接近。 在大规模的Arxiv数据集上,GLNN的性能明显低于GNNs,文章中给出的原因是Arxiv数据集的数据分割(data split)比较特殊,会导致测试节点和训练节点之间的分布转移(distribution shift),从而使GLNN很难通过KD学习到邻居信息。
常见的推理加速(inference accerleration)方法有剪枝(pruning)和量化(quantization)。这两个方法通过减少模型的乘加运算(Multiplication-and-ACcumulation MAC)来进行加速。但本质上它们没有解决因为需要获取邻居信息(neighbor-fetching)而导致的延迟(latency)。所以,这一类方法在GNNs上加速的提升没有在NNs上的那么多。对于GNNs来说,邻居采样(neighbor sampling)也被用来减少延迟。所以在这一实验中,我们的基线(baseline)有SAGE、QSAGE(从FP32到INT8的量化SAGE模型)、PSAGE(有50%权重剪枝的SAGE)、Neighbor Sample(采样15个邻居 fan-out 15)。实验结果如表4所示,GLNN要比所有的基线快得多。
另两种被视为推理加速的方法是GNN-to-GNN的KD(即教师和学生网络都是GNN),比如TinyGNN和Graph Augmented-MLPs(GA-MLPs),比如SGC和SIGN。GNN-to-GNN KD的推理时间会比相同层数下的普通GNNs慢,因为通常会有一些额外的开销,比如TinyGNN中的Peer-Aware Module(PAM)。GA-MLPs通常需要预先计算增强的节点特征并对其应用MLPs。因为有预计算(预计算不计入推理时间),所以GA-MLPs的推理时间和MLPs相同。因此,对于GLNN和GNN-to-GNN和GA-MLPs的推理时间比较,可以等价为GLNN和GNN和MLP进行比较。实验结果如图3左图所示,GNN比MLP在推理上慢得多。而由于GA-MLPs无法对归纳节点进行完全的预计算,GA-MLPs仍然需要获取邻居节点,这会让它在归纳设置(inductive setting)中比MLP慢得多,甚至要比剪枝过的GNN和TinyGNN还要慢。所以,GLNN要比GNN-to-GNN和GA-MLPs的基线快得多。
通过损失曲线的图证明KD可以通过正则化(regularization)和归纳偏见的转移来找到使得MLP达到和GNN类似表现的参数,证明GLNN受益于教师输出中包含的图拓扑结构信息(graph topology knowledge)。具体细节,有兴趣的读者可以自行阅读原文。
直观来说,在节点分类任务中,利用邻居信息的GNN表现比MLP更强大。因此,MLPs是否与GNNs有同样的表现里来代表图数据呢?这篇文章给出的结论是:在节点特征(node attribute)非常丰富的前提下,MLPs和GNNs具有相同的表现力。具体细节,有兴趣的读者可以自行阅读原文。
当每个节点都由其的度(degree)或是否形成一个三角形来标记。那么MLPs就无法拟合即GLNN失效。但这种情况是非常罕见的。 对于实际的图机器学习任务,节点特征和该节点在拓扑结构中的角色往往是高度相关的。因此MLPs即使只基于节点特征也能取得合理的结果,而GLNNs则有可能取得更好的结果。
我们对GLNN的有噪声的节点特征、归纳设置下的分割率(inductive split rates)和教师GNN结构做了消融研究。下述仅展现结论,实验细节有兴趣的读者可自行阅读原文。
我们通过在节点特征(node features)中加入不同程度的高斯噪声以减少其与标签的相关性。我们用 表示添加噪声的程度。实验结果如图4-左图所示,随着 的增加,MLP和GLNN的性能比GNN下降得更快。当 较小时,GLNN和GNN的性能仍然相当。
在图5-中图,我们展示了不同分割率下MLP、GNN、GLNN的性能。根据实验结果,随着分割率的增加即归纳部分的增加,GNN和MLP的性能基本保持不变,而GLNN的归纳性能略有下降。 当遇到大量的新数据时,从业者可以选择在部署前对根据所有的数据重新训练模型。
在图5-右图中,我们展示了使用其他各种GNN作为教师的结果,比如GCN、GAT和APPNP。我们看到GLNN可以从不同的老师那里学习,并都比MLPs有所提高。 四个老师中,从APPNP中提炼出来的GLNN比其他老师的表现要稍差一些。一个可能的原因是,APPNP的第一步是利用节点自身的特征进行预测(在图上传播之前),这与学生MLP的做法非常相似,因此提供给MLP的额外信息比其他教师少,导致表现较差。
这篇文章研究了是否可以结合GNN和MLP的优点,以实现准确且快速的图机器学习模型部署。我们发现,从GNN到MLP的KD有助于消除推理图的依赖性,从而使GLNN比GNN快146×-273x倍且性能不会降低。 我们对GLNN的特性做了全面研究。在不同领域的7个数据集上取得的好结果表明,GLNN可以成为部署延迟约束模型的一个好选择。