Diffusion-Convolutional Neural Networks
Diffusion-Convolutional Neural Networks
-
- 模型详述
- Node Classification
- Graph Classification
- Edge Classification and Edge Features
模型详述

- A_t表示图的邻接矩阵,P_t表示度归一化转移矩阵,(P_t)_{ij}表示由节点i转移到j的概率,可以由A_t计算得到,可以认为是权重矩阵。
- A_t矩阵有一个性质 :A_t矩阵的幂级数 A^n_t中的一个元素(A^n_t)_{ij},表示节点i到节点j长度为n的游走(英文为walk)的数量,当不存在这样的游走时,该值为0,之后将该矩阵归一化后,就可以表示长度为n时,节点i转移到j的概率。这种性质对应与公式:
P^*_{tijk}=P^j_{tik}\tag1
其中,P^*_t\in R^{N_t\times H\times N_t}表示由P_t组成的幂级数 ;i表示节点i ;j表示跳(hop)为j,也就是游走的长度为j;k表示节点的第k个特征。从公式(1)可以看出来,P^*_t的元素(P^*_t)_{ijk}表示:游走长度为j时,节点i转移到节点k的概率。
- 所谓的hop,按照字面意思是“跳”,对于某一节点 n ,其H-hop的节点就是,从 n节点开始,跳跃H次所到达的节点,比如对于 n 节点的1-hop的节点,就是 n 节点的邻居节点。这里对于节点representation并不是采用一个向量来表示,而是采用一个矩阵进行表示,矩阵的第 i行就表示i-hop的邻接信息。


Node Classification
假设P^*_t为P_t的power\ serirs,既{P_t,P^1_t,P^2_t,…},其中,P^*_t\in R^{N_t\times H\times N_t},表示由P_t组成的幂级数
扩散卷积表示为 :
Z_{tijm}=f(W_{jm}^c\cdot \sum_{l=1}^{N_t}{P^*_{tijl}X_{tlm}})\tag2
其中,m表示所有节点的第m个特征,W_{jm}表示权重,Z_{tijm}表示:以节点i为中心,在第m个特征上,游走长度为j的节点信息的聚合值;\sum_{l=1}^{N_t}{P^*_{tijl}X_{tlm}}部分的意义是以概率方式对节点i的j跳节点的一个信息聚合,f为非线性激活函数。式(2)的张量表示形式为:
Z_t=f(W^c\bigodot P^*_tX_t)\tag{3}
其中,\bigodot表示逐元素相乘,W^c \in R^{H\times F},为训练权重;P^*_tX_t\in R^{N_t\times H\times F}表示每个节点的各个跳[0,H-1]的聚合信息;在计算W^c\bigodot P^*_tX_t,存在广播机制 ,会将W^c复制N_t遍,然后逐元素相乘;Z_t \in R^{N_t\times H\times F}。
在得到节点的扩散卷积表示Z_t之后,可以直接将Z_t送入全连接层;
P(Y|X)=softmax(f(W^dZ))\tag4
其中,在送入全连接层之前需要将Z_t展平,变成二维矩阵Z\in R^{N_t\times (HF)},W^d\in R^{(HF)\times C},C表示分类种数。
Graph Classification
图的扩散卷积表示:
Z_t=f(W^c\bigodot \frac{(1_{N_t})^TP^*_tX_t}{N_t})\tag{5}
其中P^*_tt的意义不变,1_{N_t}\in R^{N_t\times 1}表示将各个节点信息\in R^{H\times F}聚合的权重,是全为1的的向量;除以N_t得到平均值。W^c训练得到的加权权值。
图分类与节点分类的原理一致:
在得到图的扩散卷积表示Z_t之后,可以直接将Z_t送入全连接层;
P(Y|X)=softmax(f(W^dZ))\tag6
其中,在送入全连接层之前需要将Z_t展平,变成一维向量Z\in R^{(HF\times 1)},W^d\in R^{(HF)\times C},C表示分类种数。
Edge Classification and Edge Features
通过将每一条边转化为一个节点来进行训练和预测,这个节点与原来的边对应的首尾节点相连,转化后的图的邻接矩阵 A'_t可以直接从原来的邻接矩阵增加一个incidence\ matrix得到:

之后,使用 A'_t来计算 P'_t,并用来替换 P_t 来进行分类。
对于模型训练,使用梯度下降法,并采用early-stop方式得到最终模型。
Diffusion-Convolutional Neural Networks
github源码:https://github.com/jcatw/dcnn
