Advertisement

GraphSage 代码阅读笔记

阅读量:

relation也就是边 没有embedding

supervised_train.py中使用了基于节点分类的label来进行损失训练,并且不生成任何嵌入层。

注:改写说明

该代码基于图中各节点及其邻接关系进行损失计算,并通过这一过程生成表示数据。经过训练后能够生成每个节点对应的嵌入表示,并通过特定迭代器实现批量处理功能。

NodeMinibatchIterator__init__方法最后加上

复制代码
    train_node_set = set(self.train_nodes)
    valid_node_set = set(self.val_nodes)
    print("train_node_set size", len(train_node_set))
    print("valid_node_set size", len(valid_node_set))
    print("train_node_set valid_node_set intersect size",len(train_node_set.intersection(valid_node_set)))

打印结果

复制代码
    train_node_set size 9716
    valid_node_set size 1825
    train_node_set valid_node_set intersect size 0

EdgeMinibatchIterator__init__方法最后加上

复制代码
    train_edge_set = set(self.train_edges)
    valid_edge_set = set(self.val_edges)
    print("train_edge_set size", len(train_edge_set))
    print("valid_edge_set size", len(valid_edge_set))
    print("train_edge_set valid_edge_set intersect size", len(train_edge_set.intersection(valid_edge_set)))

打印结果

复制代码
    train_edge_set size 1336764
    valid_edge_set size 75407
    train_edge_set valid_edge_set intersect size 0

EdgeMinibatchIterator__init__方法最后改成

复制代码
    train_nodes = [n for n in G.nodes() if not G.node[n]['test'] and not G.node[n]['val']]
    print(len(train_nodes), 'train nodes')
    test_nodes = [n for n in G.nodes() if G.node[n]['test'] or G.node[n]['val']]
    print(len(test_nodes), 'test nodes')
    print("train test node intersect number", len(set(test_nodes).intersection(set(train_nodes))))

打印结果

复制代码
    9716 train nodes
    5039 test nodes
    train test node intersect number 0

总结

初始化每个节点的初始嵌入表示是基于Glove等词向量方法获得的;
该模型通过提取训练数据中各节点之间的关联网络特征;
这些关联网络特征能够推广到测试数据集上;
当训练集与验证集完全不重叠时,
从而为未见测试样本生成最终嵌入表示。

more understanding https://discuss.dgl.ai/t/graphsage-question-the-training-dataset-and-validation-dataset-have-no-overlap-then-how-do-these-validation-datasets-obtain-their-embeddings-for-the-downstream-task/539/3

全部评论 (0)

还没有任何评论哟~