GraphSAGE论文阅读笔记
论文: Inductive Representation Learning on Large Graphs
1 Motivation
大多数graph embedding架构遵循transductive模式(直推式),仅能针对固定不变的图产生嵌入表示。相比之下,GraphSAGE则通过inductive方法实现了基于节点属性信息的有效嵌入生成。
统计机器学习可以将其划分为两类: transductive学习与inductive学习。从统计机器学习角度来看, 可以将其划分为两类: transductive 学习与 inductive 学习. transductive 学习:针对特定(测试)案例而言, 指的是测试集为固定样本; inductive 学习:测试集并非固定或非特定. 通常我们的目标是进行 inductive 学习. GNN中经典的 DeepWalk 和 GCN 方法都属于 transductive 学习.
2 前向传播
论文中提出的方法称为graphSAGE, 其中SAGE(Sample and Aggregate)是其核心概念。其中的主要步骤包括采样与聚合两部分。

该过程用于为红色的目标节点生成嵌入(embedding)。其中k代表从目标节点出发进行搜索的距离层级。当k等于1时,则为目标节点的所有直接邻居;而当k等于2时,则是指这些邻居所连接的所有非目标相关点。
首先进行采样操作:当k=1时,我们选取了与目标相关联的具体实例;接着在信息融合阶段(aggregation),通过聚合邻居信息来计算出目标 node 的嵌入表示;最后,在预测环节中(即利用之前得到的目标 node 的嵌入向量),来进行后续的任务推断。

在伪代码的第2至第7行中存在两层循环结构,在第一层循环中变量k从1遍历到最大深度K;第二层循环针对图中的每一个节点v执行操作。在这一过程中,
\mathcal{N}(v)
被定义为节点v的邻居集合,并且该集合是由伪代码中的neighborhood function进行计算得到的结果。具体而言,在第4行中,
\left\{\mathbf{h}_{u}^{k-1}, \forall u \in \mathcal{N}(v)\right\}
表示该集合包含了所有与节点v直接相连(即其邻居)的所有节点的信息;而
\mathbf{h}_{\mathcal{N}(v)}^{k}
则是一个向量形式的数据结构。需要注意的是,在这里k-1并不代表相邻关系的存在性问题,
而是仅仅表明当前迭代次数k-1
所对应的邻居信息集。
此外,在第5行中,
我们将上一层循环获取到的信息与当前层信息进行融合,
即将相邻节点的信息与自身信息进行拼接处理。
通过特征初始化的方式生成每个节点的初始表示 h_{v}^{(0)}. 在迭代步骤中设置 k 值为 1 后, 经过内层 for 循环遍历所有顶点 v ∈ \mathcal{V} 的过程, 计算得到新的表征向量 h_{v}^{(1)}. 其中每个节点对应的表征向量能够反映其直接邻居的信息状态. 接着设置 k 值为 2, 经过这一层循环迭代后更新到新的表征向量 $h_{v}^{(2)}. 注意到是在 k 值递增的过程中, 每一步都能够在当前层累积邻居的影响程度. 因此在后续迭代中能够捕获更深层的网络结构信息.
同样的,每个节点的表征向量 h_v^k包含了深度为k的相邻节点的信息。
对图中的每个节点v执行K次循环操作后会生成一个表征向量集合这些集合能够汇总出与该节点深度不超过K层邻居的信息从而最终生成该节点的表示向量
通过分析网络拓扑结构的变化规律后发现:对于刚接入网络的新节点a而言,在已知其自身属性及其直接邻居信息的前提下即可推断出该节点对应的向量表示;这使得无需为其他所有节点重新计算向量表示成为可能;然而若希望进一步提升模型性能,则可以选择进行额外计算;但必须存储所有节点在深度k处的表示信息以便从h_a^{(0)}逐步推导出h_a^{(1)}, h_a^{(2)},一直到h_a^{(K)}
正如论文中所说的 :
The fundamental concept underlying Algorithm 1 lies in its approach during each search depth, where nodes collect information from their immediate neighbors. As the process continues to iterate, these nodes gradually accumulate increasingly more information from increasingly remote regions of The graph.
节点采样 Neighborhood definition
nodeSAGE并非全部采用所有相邻节点进行计算,而是采用了固定的采样尺寸进行处理.
3 参数训练
该论文采用无监督损失函数,在图嵌入模型中旨在使相邻节点具有类似的嵌入表示,并要求不相关节点之间的嵌入表示存在显著差异(类似于Word2Vec模型,在其学习过程中完全基于无监督的学习框架,并且与下游任务之间并无关联)。
J_{\mathcal{G}}\left(\mathbf{z}_{u}\right)被定义为-\log \sigma\left(\mathbf{z}_{u}^{\top} \mathbf{z}_{v}\right)减去Q乘以\mathbb{E}_{v_{n} \sim P_{n}(v)}中的\log \sigma\left(-\mathbf{z}_{u}^{\top} \mathbf{z}_{v_{n}}\right)
\mathbf{z}_{u}(其中\forall u \in \mathcal{V})表示图中的任意一个节点的embedding向量)。
v 可以表示为在一次固定长度随机游走过程中与节点u共同出现过的其他节点)。
P_n属于负采样分布)。
Q 表示的是负采样的总个数)。
这个loss本身并不算特别,在生成节点表示\mathbf{z}_u的过程中,并非简单地依赖于直接连接到该节点的边信息,在这种情况下就显得有些不同寻常了。其关键点就在于前向传播机制的独特性。
这种无监督设置忽视了下游任务的需求,在于与 downstream tasks中包含节点特征的设置类似。
如果一个节点的表示z仅用于特定下游任务时,在训练过程中无监督损失可被替换为该特定任务的损失函数(例如交叉熵损失)。这使得模型能够专注于优化目标相关的表示学习。
4 Aggregator Architectures
这篇文章尝试了多种aggregator function。
(aggregator 的作用是把一个向量的集合转换成向量,,也就是聚合。 )
与大多数机器学习任务中处理的数据不同,在图结构中(nodes' neighbors lack a natural order),aggregator function的作用域是无序向量集合 \left\{\mathbf{h}_{u}^{k-1}, \forall u \in \mathcal{N}(v)\right\}。因此该aggregator函数必须具备对称性特性(symmetric),或者在输入排列下保持不变(invariant to permutations of its input)。简而言之就是无论输入顺序如何都不会影响最终结果。
所以aggregator function有两个性质:
- 可微 differentiable
- symmetric
Mean aggregator
显然对向量集合,对应元素取均值是最直接的想法。
该文认为取均值与图卷积具有等价性,并推导出了一种GCN变体。(我对GCN还不够熟悉,目前尚未想到如何进一步推广)
将下述公式替换至伪代码中的第4至5行位置:
\mathbf{h}_{v}^{k} \leftarrow \sigma\left(\mathbf{W} \cdot \operatorname{MEAN}\left(\left\{\mathbf{h}_{v}^{k-1}\right\} \cup\left\{\mathbf{h}_{u}^{k-1}, \forall u \in \mathcal{N}(v)\right\}\right)\right.
\mathbf{h}_{\mathcal{N}(v)}^{(k)} 表示在第 k 层时节点 v 的邻居集 \mathcal{N}(v) 的聚合结果。
\mathbf{h}_{v}^{(k)} 则表示在第 k 层时节点 v 的表示向量。
该层通过聚合操作将所有邻居节点的表示进行融合。
具体而言,
\mathbf{h}_{\mathcal{N}(v)}^{(k)} = \textit{\texttt{AGGREGATE}}_{~k}\bigl(\bigl\{\,\bm h_{u}^{\, (k-1)}} , ~u\in~\bm{\Gamma(v)}~\bigr\}\bigr)
随后,
\bm h_{~v}^{\, (k)}} = ~σ~~~(~ ~W_{~{} }^{\, (~~ )}} · ~CONCAT~~( ~ h_{~{} }^{\, (~~ )}}, ~ h_{~{} }^{\, (~~ )}} ))
其中,
W^ {( ~~ )}}
是一个 learnable 矩阵参数。
经过替换后,在计算阶段首先对节点v的前一层特征h_v^{k-1}以及其邻居节点的前一层特征集合\left\{\mathbf{h}_{u}^{k-1}, \forall u \in \mathcal{N}(v)\right\}进行求并集操作;随后将这些结果进行平均值计算;最后将该平均值乘以对应的权重系数
Long Short-Term Memory (LSTM) aggregator
Pooling aggregator
采用了池化机制作为聚合器,在计算时所有相邻节点的特征向量采用共享权重矩阵进行计算。经过一个非线性全连接层后,在其输出结果上执行最大池化操作以获取关键特征信息。
池化层上的聚合操作_{k}^{\text { Pool layer }}=\max作用于集合体\left\{\sigma\left(\mathbf{W}_{\text { pool }} \mathbf{h}_{u_{i}}^{k}+\mathbf{b}\right), \forall u_{i} \in \mathcal{N}(v)\right\}
5 实验结果
实验给了三个图,效果,效率,采样数量对效果和性能的影响 。
基于三个不同数据集的实验结果显示,在大多数情况下,LSTM或池化方法通常表现出较好的效果.相比无监督学习方法而言,有监督学习方法通常表现更为优异.
相较于仅使用特征的逻辑回归而言的效果都得到了提升。不确定是否与LightGBM相比表现如何

在79534个测试集上的推理时间是秒级的。(感觉上还行,也不是特别快)
右图表明随着采样数量的增加,运行时间呈现出近似线性的增长趋势.然而,反观实验结果表明效果并未呈现预期的线性提升.
代码
作者在论文中使用了TensorFlow,并因此开源了一个简单且易于扩展的PyTorch版本。
PyTorch版本中的两个数据集规模较小,并非论文中的数据集来源。
这两个数据集参考了Kipf等人于2016年发表的经典GCN论文。
其中节点数量分别为约2700和约2万。
Cora作为一个机器学习领域的论文引用数据库,在其中包含了2708篇论文及其引文关系网络。每个条目中的分类标签则标识该研究对象所处的具体学科领域。这些分类涉及七个主要研究方向:遗传算法、神经网络以及强化学习等七个重要领域。特征部分则是通过分词技术和去除停用词的方式得到的词汇集合,在此集合中每个项都标识了一个特定词汇的存在与否情况。
