Advertisement

什么是图神经网络?

阅读量:

一、概念

将GNN定义为一类专注于处理图结构数据的神经网络,并将其应用于多个现实应用场景中。该技术通过其节点与边之间的信息传递与融合操作,在捕捉复杂关联与特征特性方面展现出显著优势。

一般来说,在GNN中输入的结构是一个图 G=(V,E),其中 V 代表节点集合而 E 代表边集合。对于每一个属于 V 的元素 v∈V ,它可能具备一个特征向量

​,每条边 (u,v)∈E 可能有一个特征向量

​。

二、核心算法

GNN的核心思想是通过迭代更新节点表示来捕获图数据中的关键信息。在每一层中包含两个主要步骤:首先,在一个子步骤中对每个节点进行操作;其次,在另一个子步骤中对相邻节点进行操作以促进信息传播。

  • 信息传播(Message Passing):网络中的每个节点会从直接相连的相邻节点获取信息。
    • 网络更新(Node Update):每个网络单元会整合自身特征与周围环境的数据并进行相应调整。

假设我们有一个图 G=(V,E),每个节点 v∈V 的特征向量为

,每条边 (u,v)∈E 的特征向量为

​。GNN的计算公式可以表示为:

1、消息传递

其中,N(v)表示节点 v 的邻居节点集合,M是消息传递函数,

是节点 v 在第 k 层接收到的消息。

2、节点更新

其中,U是节点更新函数,

是节点 v 在第 k 层的表示。

三、python实现

在此基础上,我们开发了一个create_graph函数来构造一张代表空手道俱乐部团体关系网的图表,并赋予每个参与者独特的身份特征向量以及相应的群体类别标签。通过导入Karate Club数据集资源包,在这张社交网络图表中我们能够清晰地观察到34名成员及其之间的78条互动关系线。随后我们将这些成员划分为两个群体类别:一部分由活跃参与者Mr. Hi组成另一部分则由负责日常事务的Officer构成这样我们就建立了完整的数据库结构以便后续开展相关分析研究

复制代码
 import torch

    
 import torch.nn as nn
    
 import torch.optim as optim
    
 import torch.nn.functional as F
    
 import networkx as nx
    
 import numpy as np
    
 import matplotlib.pyplot as plt
    
  
    
 # 生成一个小的图数据集
    
 def create_graph():
    
     # 加载 Karate Club 图数据集,这是一个社交网络图,包含 34 个节点和 78 条边。
    
     G = nx.karate_club_graph()
    
     features = np.eye(G.number_of_nodes())
    
     # 为每个节点生成标签(0 或 1),表示节点属于哪个社区(Mr. Hi 或 Officer)。
    
     labels = np.array([G.nodes[i]['club'] == 'Mr. Hi' for i in range(G.number_of_nodes())], dtype=int)
    
     return G, features, labels
    
  
    
 # 定义原始GNN模型
    
 class GNN(nn.Module):
    
     def __init__(self, input_dim, hidden_dim, output_dim):
    
     super(GNN, self).__init__()
    
     self.fc1 = nn.Linear(input_dim, hidden_dim)
    
     self.fc2 = nn.Linear(hidden_dim, output_dim)
    
  
    
     def forward(self, x, adj):
    
     h = F.relu(self.fc1(x))
    
     # 使用邻接矩阵 adj 聚合邻居节点的信息。
    
     h = torch.matmul(adj, h)
    
     h = self.fc2(h)
    
     return F.log_softmax(h, dim=1)
    
  
    
 # 训练和测试函数
    
 def train(model, optimizer, features, labels, adj, train_mask, epochs=10):
    
     model.train()
    
     for epoch in range(epochs):
    
     optimizer.zero_grad()
    
     output = model(features, adj)
    
     # 计算负对数似然损失
    
     loss = F.nll_loss(output[train_mask], labels[train_mask])
    
     loss.backward()
    
     optimizer.step()
    
     print(f'Epoch: {epoch + 1}, Loss: {loss.item():.4f}')
    
  
    
 def test(model, features, labels, adj, mask):
    
     model.eval()
    
     with torch.no_grad():
    
     output = model(features, adj)
    
     pred = output[mask].max(1)[1]
    
     acc = pred.eq(labels[mask]).sum().item() / mask.sum().item()
    
     return acc
    
  
    
 # 主函数
    
 # 创建图数据集
    
 G, features, labels = create_graph()
    
 adj = nx.adjacency_matrix(G).todense()
    
 adj = torch.FloatTensor(adj)
    
 features = torch.FloatTensor(features)
    
 labels = torch.LongTensor(labels)
    
  
    
 # 训练和测试掩码,前 30 个节点用于训练
    
 train_mask = torch.BoolTensor([True if i < 30 else False for i in range(len(labels))])
    
 test_mask = ~train_mask
    
  
    
 # 初始化模型和优化器
    
 model = GNN(input_dim=features.shape[1], hidden_dim=16, output_dim=2)
    
 optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    
  
    
 # 训练模型
    
 train(model, optimizer, features, labels, adj, train_mask)
    
  
    
 # 测试模型
    
 train_acc = test(model, features, labels, adj, train_mask)
    
 test_acc = test(model, features, labels, adj, test_mask)
    
 print(f'Train Accuracy: {train_acc:.4f}, Test Accuracy: {test_acc:.4f}')
    
  
    
 # 可视化结果
    
 def plot_graph(G, labels, pred=None):
    
     pos = nx.spring_layout(G)
    
     plt.figure(figsize=(8, 8))
    
     nx.draw(G, pos, with_labels=True, node_color=labels, cmap=plt.cm.rainbow, node_size=500, font_color='white')
    
     if pred is not None:
    
     nx.draw_networkx_nodes(G, pos, node_color=pred, cmap=plt.cm.rainbow, node_size=200, alpha=0.5)
    
     plt.show()
    
  
    
 plot_graph(G, labels.numpy(), pred=model(features, adj).max(1)[1].numpy())

四、总结

GNN可以直接处理图结构数据。采用端到端的方式进行训练后,GNN不仅可以从原始图数据中自动提取特征和表示信息,并且还能够有效解决社交网络分析、分子结构预测以及知识图谱推理等问题。然而,在实际应用中若仅依赖传统的计算资源可能难以满足需求。由于每一轮迭代都需要完成消息传递和节点更新过程,在大规模数据环境下计算开销较大导致训练与推理速度较慢。特别是在深层次的GNN模型中,节点的表示可能会变得过于相似从而引发过平滑现象;此外,在实际应用中若图数据存在噪声或不完整性时其性能也会受到影响。

全部评论 (0)

还没有任何评论哟~