Advertisement

【论文阅读-ASAP】ASAP: Adaptive Structure Aware Pooling for Learning Hierarchical Graph Representations

阅读量:

论文链接:https://www.aaai.org/Papers/AAAI/2020GB/AAAI-RanjanE.8336.pdf 代码仓库:https://github.com/malllabiisc/ASAP

图神经网络(GNN)已被广泛应用于建模复杂网络数据,并已在多种任务中取得显著成效。
近年来,在图池化概念的定义方面取得了显著进展。
目前的研究者们主要致力于通过信息聚合和降维技术来生成有效的图级表示。
然而,
现有的池化方法主要存在以下两种局限性:
其一是无法有效捕捉复杂的子结构;
其二是缺乏良好的扩展性。

模型架构

ASAP

与该论文中提出的SAGPool架构具有相似的分层池化机制

(a.) 将图表输入至ASAP系统中。
(b.) 基于1-hop邻域的ASAP初始聚类过程会考虑所有节点作为初始medoid中心点。为了简化操作过程, 我们将节点2和6分组表示为medoid类, 并通过M2T注意力机制计算各簇的隶属度(参考第4.2节内容)。
(c.) 集群的评分将通过LEConv方法进行计算。
(d) 从合并图中选取具有最高得分的一组簇, 并基于所选簇成员之间的连接权重重新构建邻接矩阵。
(e) ASAP输出结果。
(f) 关于层次图分类架构的概述完成。

卷积层

GCN

池化层

P首先考察所有具有固定接受域的局部聚类的可能性。接着通过注意力机制计算各节点的簇归属。随后运用图神经网络评估各集群的质量。将合并图中质量分数最高的几个集群选作新的节点,并在此基础上重新计算相邻集群之间的连接权重。

Master2Token (M2T)

基于簇ch(vi)的条件下

首先创建一个代表集群中所有节点的 m i :

主查询
复制代码
    X_q, _ = scatter_max(x_pool_j, edge_index[0], dim=0)
    # NxF
    M_q = self.lin_q(X_q)    
    # ExF
    # M_q->mi
    M_q = M_q[edge_index[0].tolist()]

计算注意力分数:
具体而言,在传统的自注意力机制中,默认情况下查询、键和值是等价的。即通过将簇的代表节点mi作为查询输入到自注意力层中,并利用其对其他节点进行加权聚合以生成最终的表示向量。

注意力权值
复制代码
    score = self.gat_att(torch.cat((M_q, x_pool_j), dim=-1))
    score = F.leaky_relu(score, self.negative_slope)
    score = softmax(score, edge_index[0], num_nodes=num_nodes.sum())

聚合:

在这里插入图片描述
复制代码
    # Sample attention coefficients stochastically.
    score = F.dropout(score, p=self.dropout_att, training=self.training)
    # 公式7
    # ExF
    v_j = x_j * score.view(-1, 1)
    #---Aggregation---
    # NxF
    out = scatter_add(v_j, edge_index[0], dim=0)

LEConv

通过随后的节点embedding进行score calculation,其中score function employs LEConv,该方法基于节点间增量进行评估

在这里插入图片描述
复制代码
    num_nodes = x.shape[0]
    h = torch.matmul(x, self.weight)
    
    if edge_weight is None:
    edge_weight = torch.ones((edge_index.size(1), ),
                             dtype=x.dtype,
                             device=edge_index.device)
    edge_index, edge_weight = remove_self_loops(edge_index=edge_index, edge_attr=edge_weight)
    deg = scatter_add(edge_weight, edge_index[0], dim=0, dim_size=num_nodes) #+ 1e-10
    
    h_j = edge_weight.view(-1, 1) * h[edge_index[1]]
    aggr_out = scatter_add(h_j, edge_index[0], dim=0, dim_size=num_nodes)
    out = ( deg.view(-1, 1) * self.lin1(x) + aggr_out) + self.lin2(x)
    edge_index, edge_weight = add_self_loops(edge_index=edge_index, edge_weight=edge_weight, num_nodes=num_nodes)

topk

复制代码
    fitness = torch.sigmoid(self.gnn_score(x=out, edge_index=edge_index)).view(-1) #这里的gnn_score即LEConv层
    perm = topk(x=fitness, ratio=self.ratio, batch=batch)
    x = out[perm] * fitness[perm].view(-1, 1)

更新图

复制代码
    batch = batch[perm]
    edge_index, edge_weight = graph_connectivity(
    device = x.device,
    perm=perm,
    edge_index=edge_index,
    edge_weight=edge_weight,
    score=score,
    ratio=self.ratio,
    batch=batch,
    N=N)
复制代码
    def graph_connectivity(device, perm, edge_index, edge_weight, score, ratio, batch, N):
    r"""graph_connectivity: is a function which internally calls StAS func to maintain graph connectivity"""
    
    kN = perm.size(0)
    perm2 = perm.view(-1, 1)
    
    # mask contains bool mask of edges which originate from perm (selected) nodes
    mask = (edge_index[0]==perm2).sum(0, dtype=torch.bool)
    
    # create the S
    S0 = edge_index[1][mask].view(1, -1)
    S1 = edge_index[0][mask].view(1, -1)
    index_S = torch.cat([S0, S1], dim=0)
    value_S = score[mask].detach().squeeze()
    
    # relabel for pooling ie: make S [N x kN]
    n_idx = torch.zeros(N, dtype=torch.long)
    n_idx[perm] = torch.arange(perm.size(0))
    index_S[1] = n_idx[index_S[1]]
    
    # create A
    index_A = edge_index.clone()
    if edge_weight is None:
        value_A = value_S.new_ones(edge_index[0].size(0))
    else:
        value_A = edge_weight.clone()
    
    fill_value=1
    index_E, value_E = StAS(index_A, value_A, index_S, value_S, device, N, kN)
    index_E, value_E = remove_self_loops(edge_index=index_E, edge_attr=value_E)
    index_E, value_E = add_remaining_self_loops(edge_index=index_E, edge_weight=value_E, 
        fill_value=fill_value, num_nodes=kN)
    
    
    return index_E, value_E

Readout

复制代码
    def readout(x, batch):
    x_mean = scatter_mean(x, batch, dim=0)
    x_max, _ = scatter_max(x, batch, dim=0) 
    return torch.cat((x_mean, x_max), dim=-1)
复制代码
    xs += readout(x, batch)

分类

复制代码
     x = F.relu(self.lin1(xs))
     x = F.dropout(x, p=0.5, training=self.training)
     x = self.lin2(x)
     out = F.log_softmax(x, dim=-1)

总结

主要是在池化层丢弃节点之前执行了注意力聚合操作,并保存了被丢弃节点的相关信息。同时基于节点的增量特征改进得分计算机制。

全部评论 (0)

还没有任何评论哟~