点击上方 蓝字关注我们
VQ-GNN:使用矢量量化扩展图神经网络的通用框架
作者
Mucong Ding (Department of Computer Science, University of Maryland, College Park)
Kezhi Kong (University of Maryland, College Park)
Jingling Li (University of Maryland, College Park)
Chen Zhu (University of Maryland, College Park)
John P Dickerson (Carnegie Mellon)
Furong Huang (University of Maryland)
最新的图神经网络(GNNs)可以定义为一种图卷积的形式,它可以通过在直接邻居之间或更远的地方进行消息传递
比如下面的形式
对应的梯度计算
从梯度的公式可以看出来,梯度的计算也是一种message passing. 注: 第二个括号里面是通过非线性后向传播的梯度。
为了将这种GNN扩展到大图,也就是节点和边数量非常多的图。目前的工作主要包括相邻图、层图或子图采样技术,简单来说就是通过考虑传递小批量节点的消息来缓解“邻居爆炸”的问题。然而,基于采样的方法很难应用于GNN。
主要原因是:当图很大时,我们通过每次迭代中采样b个节点(b<<n=|V|)的消耗很多,怎么说呢?
假设节点索引是i_1,...,i_b,并且一小批节点特征由XB=X_<ib>,: 表示。为了有效地对任何模型进行小批量处理,我们希望(1)将Θ(b)信息提取到训练设备,(2)在每次迭代中花费Θ(Lb)训练时间(L:层数),(3)同时花费(n/b)次迭代来遍历整个数据集。然而,对于大多数GNN来说,同时满足这三个可扩展性要求本质上是困难的。别慌,继续往下看。L层图形卷积的感受野由
递归给出,其大小与L的大小成指数增长。因此,要在最小批量的b个节点上进行优化,每次迭代需要Ω(b*d^l)个输入和训练时间。对每层中的每个节点的邻居的子集进行采样不会改变对L的【指数】依赖。尽管层和子图采样可能在每次迭代中仅需要Ω(b)输入和Ω(Lb)训练时间,但是与全图训练相比,它们仅能够考虑成指数级小比例的消息。最重要的是,所有现有的采样方法都不支持具有O(n^2)个非零项的稠密卷积矩阵. 下面是目前方法时间复杂度一览,最后是作者的方法,看起来很优秀。
作者提出了一种新的方法,VQ-GNN,一个通用的框架,可以在不影响性能的情况下,使用矢量量化(VQ)来扩展目前基于卷积的GNN。与基于采样的技术相比,作者的方法通过学习和更新全局节点表示的少量量化参考向量,在每个GNN层内使用VQ,可以有效地保存传递到小批量节点的所有消息。
作者的框架使用量化表示和低秩的图卷积矩阵相结合,避免了GNN的“邻居爆炸”问题。
并且证明了这种紧凑的卷积矩阵的低秩形式在理论和实验上都是足够的。
结合矢量量化,作者设计了一个新的近似消息传递算法和一个非平凡的反向传播规则。
在不同类型的GNN骨干上的实验表明,该框架在大图节点解密和链接预测基准上具有良好的可扩展性和性能。
图1:在框架VQ-GNN中,每个小批量消息传递(左)由一个VQ码本更新(中)和一个近似消息传递(右)近似。当前小批量中传递给节点的所有消息都会得到有效保存。圆是节点,矩形是VQ码字。双圆表示当前小批次中的节点。颜色表示码字分配。在VQ码本更新期间,刷新小批量中节点的码字分配(节点1),并且使用分配的节点更新码字。在近似消息传递期间,来自小批次外节点的消息被来自相应码字的消息近似,来自分配给相同码字的节点的消息被合并(a和b),并且小批次内消息不改变(c和d)
那么什么是VQ: 简单来说就是数据降维
向量量化[28](VQ)是一种广泛使用的方法,用于以确定性且identity-preserving的方式对数据进行降维,这是一种经典的数据压缩算法,它可以表示为以下优化问题.
这是经典的通过k-均值求解的方法[28]。在这里,\tilde{X}的特征sketch称为特征“码字(codewords)”。R被称为分配矩阵,它的行是单位向量,即R_{i,v}=1 if and on if 第i个节点被分配给k-Means中的第v个cluster。公式(5)称为簇内平方和(Within-Cluster Sum of SquaresWcss),我们可以将VQ的相对误差定义为=‖X−RX‖_F/‖X‖_F.
tilde{X}的行是k个码字(即k-Means中的质心),并且可以计算,一般来说,VQ为我们提供了一个principled的框架来学习低维X,以一种确定性的和节点身份保持的方式。但是,要使用VQ对GNN进行小批量训练和推理,还需要回答三个问题:·
如何利用学习到的码字逼近节点的前向小批量特征?
·如何通过矢量量化反向传播并估计节点的小批量梯度?
·如何在GNN的训练过程中更新码字和分配矩阵?
具体的做法请参考原文,论文主要解决上面的几个问题。
除此之外,作者还证明了这种方法的error bound.