
Notes on the paper "Neural Snowball for Few-Shot Relation Learning"


These notes are mainly based on the AI科技评论 coverage of the paper ("Neural Snowball Mechanism for Few-Shot Relation Learning"): https://mp.weixin.qq.com/s?__biz=MzI5NTIxNTg..."

1. Contributions of the paper

Relation extraction with a growing set of relations

Relation extraction (RE) is a core research topic in natural language processing. It aims to extract structured relational facts from text.

For example, from the sentence "Bill Gates is the founder of Microsoft", we can extract the triple (Bill Gates, founder, Microsoft) and use it in downstream applications such as knowledge graph construction.

There is a large body of existing work in this area, but most of it targets a fixed, predefined set of relations: given a manually annotated relation inventory, researchers focus only on the relation types inside that inventory.

In practice, however, the setting is open and keeps expanding: as new domains emerge and new knowledge accumulates, the set of relation types keeps growing and changing.

Knowledge transfer via a relational metric: the RSN-based approach learns a distance metric on the relation data of a large existing knowledge graph and transfers it to new relations. Concretely, a handful of instances of the new relation serve as the initial seeds, and the model keeps mining useful information from large-scale unlabeled data; by accumulating more and more such instances, it gradually builds a better extraction model for the new relation.

2.1 Three different kinds of data

Current research covers several relation extraction scenarios, which differ in the kinds of relations they target and the data they use:

  • Supervised relation extraction targets predefined relation types, training models on large-scale supervised data;
  • Semi-supervised methods combine limited labeled data with abundant unlabeled data, improving the classifier step by step;
  • Few-shot learning focuses on relation types unseen during training, adapting quickly to new domains on the basis of previously learned knowledge;
  • Bootstrapping methods iteratively mine patterns from large corpora to build a more general relation extractor.

Taken together, these methods involve three kinds of data:

  • large-scale supervised data for existing relations
  • a few labeled instances of the new relation (the few-shot seeds)
  • large-scale unlabeled corpora, which contain unseen or unlabeled instances

The goal is to exploit all three kinds of data. To this end, the authors propose the Neural Snowball method for relation extraction.

2.2 Neural Snowball

The Neural Snowball model built on these three kinds of data is shown below.

(figure from the original post omitted)

The figure shows how the method draws on all three kinds of data; the components are described in the next section.

3. Methodology

Neural Snowball consists of the following components.

(figure from the original post omitted)
  • Initial set S_r: for a new relation, the few initial labeled instances form the starting point of S_r. In each iteration, after selection, a new batch of unlabeled instances from the large-scale unlabeled data U_l is added, and the updated S_r enters the next iteration.
  • Candidate set C1: instances mined by distant supervision (a minimal sketch of this step follows the list). If the new relation is "founder", then for the seed instance "Bill founder Microsoft" distant supervision retrieves every sentence containing the entity pair (Bill, Microsoft). Some of these sentences do not express "founder" at all (e.g., "Bill mentioned Microsoft") and are not what we want, so they still need to be filtered; the retrieved sentences form the candidate set C1.
  • RSN: the module that judges whether the candidates mined by distant supervision are valid, i.e., whether they really express the new relation.
  • Relation classifier g(x): the binary classifier that is the final goal of the model. It is trained in a supervised way on the instances selected from C1 together with the labeled seed instances, and is then used to score further candidates drawn from the unlabeled data, producing candidate set 2.
  • C2: because the data drawn from C1 is still noisy and the binary relation classifier's performance needs further improvement, a second candidate set is built as a supplement. After this additional round of filtering, sentences accepted by the binary classifier (e.g., a sentence about Steve Jobs and Apple for "founder") are considered high-quality and are added to the seed set for the following iterations.
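
To make the distant-supervision step concrete, here is a minimal, self-contained sketch. The corpus index and helper names (corpus_by_entpair, gather_candidate_c1) are hypothetical; in the repository this role is played by distant.get_same_entpair_ins.

    from collections import defaultdict

    # Toy corpus index: entity pair -> sentences that mention both entities.
    corpus_by_entpair = defaultdict(list)
    corpus_by_entpair[("Bill Gates", "Microsoft")] = [
        "Bill Gates founded Microsoft in 1975.",        # expresses "founder"
        "Bill Gates mentioned Microsoft in a talk.",    # same entity pair, different relation
    ]

    def gather_candidate_c1(seed_instances):
        """Collect every sentence that shares an entity pair with a seed instance.
        The result is the noisy candidate set C1; the RSN must later filter out
        sentences that do not actually express the new relation."""
        c1 = []
        for entpair, _seed_sentence in seed_instances:
            c1.extend(corpus_by_entpair.get(entpair, []))
        return c1

    seeds = [(("Bill Gates", "Microsoft"), "Bill Gates founded Microsoft in 1975.")]
    print(gather_candidate_c1(seeds))   # both sentences come back; the RSN prunes the second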

The overall Neural Snowball workflow is as follows.

The input is a new relation type together with a few labeled instances (the seeds).

The goal is to train a classifier for this new relation. A classifier is used because it scales well: as the number of relations grows, multiple classifiers can easily be combined to work together.

Training starts from the seeds and iteratively mines useful information from the unlabeled data.

Each iteration consists of two phases:

(1) use distant supervision to gather candidate sentences;
(2) use the newly trained relation classifier to gather more candidate sentences.

Distant supervision is the mechanism that, given an existing knowledge base, finds all sentences containing an entity pair (h, t) and assumes those sentences express the relation r. After each new training set is obtained, Neural Snowball retrains the relation classifier and uses it to select, from the unlabeled data, instances believed to express relation r. These selected instances in turn improve the classifier.
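
The loop described above can be condensed into the following sketch. All helper callables (train_classifier, distant_candidates, rsn_filter, classifier_filter) are hypothetical placeholders for the phase-1/phase-2 machinery implemented in the Snowball class quoted in Section 5.

    # Minimal sketch of the snowball iterations; the helpers are assumptions, not real repo APIs.
    def neural_snowball(seed_set, unlabeled_corpus, train_classifier,
                        distant_candidates, rsn_filter, classifier_filter, num_iters=3):
        support = list(seed_set)                    # S_r starts from the few seed instances
        classifier = train_classifier(support)      # binary classifier g(x) for the new relation
        for _ in range(num_iters):
            # Phase 1: distant supervision proposes candidates (C1); the RSN keeps reliable ones
            c1 = distant_candidates(support, unlabeled_corpus)
            support += rsn_filter(support, c1)
            classifier = train_classifier(support)

            # Phase 2: the retrained classifier scores further unlabeled instances (C2)
            c2 = classifier_filter(classifier, unlabeled_corpus)
            support += rsn_filter(support, c2)      # double-checked before joining S_r
            classifier = train_classifier(support)
        return classifier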

4. Neural Modules

Neural Snowball has two key components: the RSN and the Relation Classifier.

4.1 RSN s(x, y)
  • Input: two instances, e.g., the two sentences obtained by distant supervision in the figure above: 'Bill founder Microsoft' and 'Bill mentioned Microsoft' (which clearly do not express the same relation).
  • Output: a binary decision (0 or 1) on whether the two instances express the same relation.
Structure of RSN

The RSN consists of two encoders f_s and a distance function; its architecture is shown in the figure below. It takes two sentences as input and outputs whether they express the same relation. The RSN is pre-trained on a large amount of existing labeled relation data and then used inside the Neural Snowball framework: each candidate instance drawn from the unlabeled data is compared against the seed instances, and only candidates whose similarity scores exceed a preset threshold are kept.

(figure from the original post omitted)

In the RSN architecture, each encoder maps an instance to its representation vector. The two encoders share parameters, just as convolution kernels are shared when a CNN is used as the encoder, which improves efficiency and reduces the number of parameters.

The distance function defined in the RSN measures the similarity of two instances: s(x, y) = \sigma(\mathbf{w}_s^T (f_s(x) - f_s(y))^2 + b_s), where the squared difference is taken element-wise.
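
To make the formula concrete, here is a minimal PyTorch sketch of the scoring head alone; the shared sentence encoder f_s is assumed and stubbed with random vectors, and hidden_size=230 simply mirrors the Siamese code quoted later. It corresponds to the euc branch of that code.

    import torch
    import torch.nn as nn

    # Sketch of s(x, y) = sigmoid(w_s^T (f_s(x) - f_s(y))^2 + b_s)
    class RSNHead(nn.Module):
        def __init__(self, hidden_size=230):
            super().__init__()
            self.fc = nn.Linear(hidden_size, 1)     # holds w_s and b_s

        def forward(self, fx, fy):                  # fx, fy: encoder outputs f_s(x), f_s(y)
            dis = (fx - fy) ** 2                    # element-wise squared difference
            return torch.sigmoid(self.fc(dis)).squeeze(-1)

    head = RSNHead()
    fx, fy = torch.randn(4, 230), torch.randn(4, 230)   # stand-ins for encoded sentences
    print(head(fx, fy))                                  # similarity scores in (0, 1)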

4.2 Relation Classifier g(x)

The relation classifier (RC) contains a neural encoder f that maps an input instance x to a real-valued vector; a linear layer then computes the probability that the instance expresses the relation:

g(x) = \sigma(w^T f(x) + b)

where \sigma is the sigmoid function.
When multiple relations need to be distinguished, one binary classifier is used per relation rather than a single N-way classifier: since new relation types keep emerging, handling them one by one keeps the model simple and easy to extend.
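
As an illustration, a minimal sketch of one such binary classifier head is shown below; the encoder output f(x) is stubbed with random vectors and hidden_size=230 follows the code quoted later.

    import torch
    import torch.nn as nn

    # Sketch of g(x) = sigmoid(w^T f(x) + b); one such head is kept per relation.
    class RelationClassifier(nn.Module):
        def __init__(self, hidden_size=230):
            super().__init__()
            self.fc = nn.Linear(hidden_size, 1)     # w and b

        def forward(self, fx):                      # fx: encoder output f(x)
            return torch.sigmoid(self.fc(fx)).view(-1)

    g = RelationClassifier()
    fx = torch.randn(8, 230)                        # stand-ins for encoded instances
    print((g(fx) > 0.5).long())                     # 1 = expresses the relation, 0 = does not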

4.3 Pre-training and Fine-tuning
Pre-training

Pre-training trains the RSN and the RC networks; their parameters are then kept fixed during the later snowball iterations. Pre-training is supervised learning on the existing labeled dataset S_N.

For the RSN, instance pairs are first sampled from S_N (the existing labeled dataset); the pairs include both instances of the same relation and instances of different relations, and the network is trained with a cross-entropy loss.
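
A minimal sketch of this pair sampling, assuming a hypothetical dataset layout that maps each relation name to a list of its (encoded) instances; the repository instead builds pairs by slicing class-grouped batches inside Siamese.forward below.

    import random
    import torch
    import torch.nn as nn

    # Same-relation pairs get label 1, different-relation pairs get label 0;
    # the RSN scores are trained against these labels with (binary) cross-entropy.
    def sample_pairs(dataset, num_pairs):
        relations = list(dataset)
        pairs, labels = [], []
        for _ in range(num_pairs):
            if random.random() < 0.5:                       # positive pair: same relation
                r = random.choice(relations)
                pairs.append((random.choice(dataset[r]), random.choice(dataset[r])))
                labels.append(1.0)
            else:                                           # negative pair: two different relations
                r1, r2 = random.sample(relations, 2)
                pairs.append((random.choice(dataset[r1]), random.choice(dataset[r2])))
                labels.append(0.0)
        return pairs, torch.tensor(labels)

    loss_fn = nn.BCELoss()    # cross-entropy between RSN scores s(x, y) and the pair labels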

For the RC, a minibatch of positive instances is drawn from the seed set S_r and the same number of negative instances from the background set S_N; then the parameters w and b of the RC's linear (fully connected) layer are trained. The objective is: L_{\mathcal{S}_b,\mathcal{T}_b}(g_{w,b}) = \sum_{x\in\mathcal{S}_b}\log g_{w,b}(x) + \mu \sum_{x\in\mathcal{T}_b}\log(1 - g_{w,b}(x)), where \mathcal{S}_b are the positive and \mathcal{T}_b the negative instances in the batch.
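
Below is a minimal sketch of this objective, assuming g_pos and g_neg are the classifier outputs g(x) on the positives from S_b and the negatives from T_b. The value mu=0.2 is an arbitrary illustrative choice (mu plays the role of args.finetune_weight in the code quoted later), and the objective is maximized by minimizing its negation.

    import torch

    def finetune_objective(g_pos, g_neg, mu=0.2):
        # L = sum log g(x) over positives + mu * sum log(1 - g(x)) over negatives
        eps = 1e-8                                   # numerical stability
        ll = torch.log(g_pos + eps).sum() + mu * torch.log(1.0 - g_neg + eps).sum()
        return -ll                                   # loss to minimize = -L

    g_pos = torch.tensor([0.9, 0.8])                 # g(x) on positive instances
    g_neg = torch.tensor([0.2, 0.1, 0.3])            # g(x) on sampled negatives
    print(finetune_objective(g_pos, g_neg))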

Fine-tuning

Because pre-training trains the RSN and RC networks, whose parameters no longer change in later iterations, fine-tuning only optimizes the classifier head, i.e., the linear-layer parameters (new_W and new_bias in the code below).

5. A look at the code

Training and inference in the RSN: encode is the encoding function inside the RSN (Siamese) module; forward_infer computes scores with a threshold, as described in the paper; forward_infer_sort corresponds to method B used in phase 1 of the Snowball class. Both inference paths share the same forward machinery; see the comments for implementation details.

    class Siamese(nn.Module):
    
    def __init__(self, sentence_encoder, hidden_size=230, drop_rate=0.5, pre_rep=None, euc=True):
        nn.Module.__init__(self)
        self.sentence_encoder = sentence_encoder # Should be different from main sentence encoder !!!
        self.hidden_size = hidden_size
        # self.fc1 = nn.Linear(hidden_size * 2, hidden_size * 2)
        # self.fc2 = nn.Linear(hidden_size * 2, 1)
        self.fc = nn.Linear(hidden_size, 1)
        self.cost = nn.BCELoss(reduction="none")
        self.drop = nn.Dropout(drop_rate)
        self._accuracy = 0.0
        self.pre_rep = pre_rep
        self.euc = euc
    
    def forward(self, data, num_size, num_class, threshold=0.5):
        # view: reshape the encoded batch to (num_class, num_size, hidden_size)
        x = self.sentence_encoder(data).contiguous().view(num_class, num_size, -1)
        # The slicing below flattens each part to rows of hidden_size elements; // is integer division.
        #   [:, :num_size//2] takes the first half of the instances of every class,
        #   [:, num_size//2:] takes the second half, so x1/x2 pair up instances of the same class;
        #   [:num_class//2, :] and [num_class//2:, :] split along classes, so y1/y2 pair up instances of different classes.
        x1 = x[:, :num_size//2].contiguous().view(-1, self.hidden_size)
        x2 = x[:, num_size//2:].contiguous().view(-1, self.hidden_size)
        y1 = x[:num_class//2,:].contiguous().view(-1, self.hidden_size)
        y2 = x[num_class//2:,:].contiguous().view(-1, self.hidden_size)
        # y1 = x[0].contiguous().unsqueeze(0).expand(x.size(0) - 1, -1, -1).contiguous().view(-1, self.hidden_size)
        # y2 = x[1:].contiguous().view(-1, self.hidden_size)
    
        label = torch.zeros((x1.size(0) + y1.size(0))).long().cuda()
        label[:x1.size(0)] = 1  # same-class pairs (x1 vs x2) get label 1
        z1 = torch.cat([x1, y1], 0)
        z2 = torch.cat([x2, y2], 0)
    
        if self.euc:
            dis = torch.pow(z1 - z2, 2)
            dis = self.drop(dis)
            score = torch.sigmoid(self.fc(dis).squeeze())
        else:
            z = z1 * z2
            z = self.drop(z)
            z = self.fc(z).squeeze()
            # z = torch.cat([z1, z2], -1)
            # z = F.relu(self.fc1(z))
            # z = self.fc2(z).squeeze()
            score = torch.sigmoid(z)
    
        self._loss = self.cost(score, label.float()).mean()
        pred = torch.zeros((score.size(0))).long().cuda()
        pred[score > threshold] = 1
        self._accuracy = torch.mean((pred == label).type(torch.FloatTensor))
        pred = pred.cpu().detach().numpy()
        label = label.cpu().detach().numpy()
        self._prec = float(np.logical_and(pred == 1, label == 1).sum()) / float((pred == 1).sum() + 1)
        self._recall = float(np.logical_and(pred == 1, label == 1).sum()) / float((label == 1).sum() + 1)
    
    def encode(self, dataset, batch_size=0): 
        if self.pre_rep is not None:
            return self.pre_rep[dataset['id'].view(-1)] 
    
        if batch_size == 0:
            x = self.sentence_encoder(dataset)
        else:
            total_length = dataset['word'].size(0)
            max_iter = total_length // batch_size
            if total_length % batch_size != 0:
                max_iter += 1
            x = []
            for it in range(max_iter):
                scope = list(range(batch_size * it, min(batch_size * (it + 1), total_length)))
                with torch.no_grad():
                    _ = {'word': dataset['word'][scope], 'mask': dataset['mask'][scope]}
                    if 'pos1' in dataset:
                        _['pos1'] = dataset['pos1'][scope]
                        _['pos2'] = dataset['pos2'][scope]
                    _x = self.sentence_encoder(_)
                x.append(_x.detach())
            x = torch.cat(x, 0) #concatenate
        return x
    
    # method A: selection with a threshold
    def forward_infer(self, x, y, threshold=0.5, batch_size=0):
        x = self.encode(x, batch_size=batch_size)
        support_size = x.size(0)
        y = self.encode(y, batch_size=batch_size)
        # unsqueeze(N) inserts a size-1 dimension at position N so x and y can broadcast
        x = x.unsqueeze(1)  # (support_size, 1, hidden_size)
        y = y.unsqueeze(0)  # (1, candidate_size, hidden_size)
    
        if self.euc:
            dis = torch.pow(x - y, 2)   # L2 distance
            score = torch.sigmoid(self.fc(dis).squeeze(-1)).mean(0)    # nn.Linear holds the weight and bias; average the scores over the support set
        else:
            z = x * y
            z = self.fc(z).squeeze(-1)
            score = torch.sigmoid(z).mean(0)
    
        pred = torch.zeros((score.size(0))).long().cuda()
        pred[score > threshold] = 1
        pred = pred.view(support_size, -1).sum(0)
        pred[pred < 1] = 0
        pred[pred > 0] = 1
        return pred
    
    # method B: sort candidates by similarity score
    def forward_infer_sort(self, x, y, batch_size=0):
        x = self.encode(x, batch_size=batch_size)
        support_size = x.size(0)
        y = self.encode(y, batch_size=batch_size)
        x = x.unsqueeze(1)
        y = y.unsqueeze(0)
    
        if self.euc:
            dis = torch.pow(x - y, 2)
            score = torch.sigmoid(self.fc(dis).squeeze(-1)).mean(0)
        else:
            z = x * y
            z = self.fc(z).squeeze(-1)
            score = torch.sigmoid(z).mean(0)
    
        pred = []
        for i in range(score.size(0)):
            pred.append((score[i], i))
        pred.sort(key=lambda x: x[0], reverse=True)
        return pred

The Snowball class contains many helper functions for data handling; the snowball process itself is driven by _forward_train, which ties together the input data and the phase-1 and phase-2 steps. The negative-sampling mechanism used when training the relation classifier (in _train_finetune) also deserves attention.

    class Snowball(nrekit.framework.Model):
    
    def __init__(self, sentence_encoder, base_class, siamese_model, hidden_size=230, drop_rate=0.5, weight_table=None, pre_rep=None, neg_loader=None, args=None):
        nrekit.framework.Model.__init__(self, sentence_encoder)
        self.hidden_size = hidden_size
        self.base_class = base_class
        self.fc = nn.Linear(hidden_size, base_class)
        self.drop = nn.Dropout(drop_rate)
        self.siamese_model = siamese_model
        # self.cost = nn.BCEWithLogitsLoss()
        self.cost = nn.BCELoss(reduction="none")
        # self.cost = nn.CrossEntropyLoss()
        self.weight_table = weight_table
        
        self.args = args
    
        self.pre_rep = pre_rep
        self.neg_loader = neg_loader
    
    # def __loss__(self, logits, label):
    #     onehot_label = torch.zeros(logits.size()).cuda()
    #     onehot_label.scatter_(1, label.view(-1, 1), 1)
    #     return self.cost(logits, onehot_label)
    
    # def __loss__(self, logits, label):
    #     return self.cost(logits, label)
    
    def forward_base(self, data):
        batch_size = data['word'].size(0)
        x = self.sentence_encoder(data) # (batch_size, hidden_size)
        x = self.drop(x)
        x = self.fc(x) # (batch_size, base_class)
    
        x = torch.sigmoid(x)
        if self.weight_table is None:
            weight = 1.0
        else:
            weight = self.weight_table[data['label']].unsqueeze(1).expand(-1, self.base_class).contiguous().view(-1)
        label = torch.zeros((batch_size, self.base_class)).cuda()
        label.scatter_(1, data['label'].view(-1, 1), 1) # (batch_size, base_class)
        loss_array = self.__loss__(x, label)
        self._loss = ((label.view(-1) + 1.0 / self.base_class) * weight * loss_array).mean() * self.base_class
        # self._loss = self.__loss__(x, data['label'])
        
        _, pred = x.max(-1)
        self._accuracy = self.__accuracy__(pred, data['label'])
        self._pred = pred
    
    def forward_baseline(self, support_pos, query, threshold=0.5):
        '''
        baseline model
        support_pos: positive support set
        support_neg: negative support set
        query: query set
        threshold: ins whose prob > threshold are predicted as positive
        '''
        
        # train
        self._train_finetune_init()
        # support_rep = self.encode(support, self.args.infer_batch_size)
        support_pos_rep = self.encode(support_pos, self.args.infer_batch_size)
        # self._train_finetune(support_rep, support['label'])
        self._train_finetune(support_pos_rep)
    
        
        # test
        query_prob = self._infer(query, batch_size=self.args.infer_batch_size).cpu().detach().numpy()
        label = query['label'].cpu().detach().numpy()
        self._baseline_accuracy = float(np.logical_or(np.logical_and(query_prob > threshold, label == 1), np.logical_and(query_prob < threshold, label == 0)).sum()) / float(query_prob.shape[0])
        if (query_prob > threshold).sum() == 0:
            self._baseline_prec = 0
        else:        
            self._baseline_prec = float(np.logical_and(query_prob > threshold, label == 1).sum()) / float((query_prob > threshold).sum())
        self._baseline_recall = float(np.logical_and(query_prob > threshold, label == 1).sum()) / float((label == 1).sum())
        if self._baseline_prec + self._baseline_recall == 0:
            self._baseline_f1 = 0
        else:
            self._baseline_f1 = float(2.0 * self._baseline_prec * self._baseline_recall) / float(self._baseline_prec + self._baseline_recall)
        self._baseline_auc = sklearn.metrics.roc_auc_score(label, query_prob)
        if self.args.print_debug:
            print('')
            sys.stdout.write('[BASELINE EVAL] acc: {0:2.2f}%, prec: {1:2.2f}%, rec: {2:2.2f}%, f1: {3:1.3f}, auc: {4:1.3f}'.format( \
                self._baseline_accuracy * 100, self._baseline_prec * 100, self._baseline_recall * 100, self._baseline_f1, self._baseline_auc))
            print('')
    
    def __dist__(self, x, y, dim):
        return (torch.pow(x - y, 2)).sum(dim)
    
    def __batch_dist__(self, S, Q):
        return self.__dist__(S.unsqueeze(1), Q.unsqueeze(2), 3)
    
    def forward_few_shot_baseline(self, support, query, label, B, N, K, Q):
        support_rep = self.encode(support, self.args.infer_batch_size)
        query_rep = self.encode(query, self.args.infer_batch_size)
        support_rep.view(B, N, K, -1)
        query_rep.view(B, N * Q, -1)
        
        NQ = N * Q
         
        # Prototypical Networks 
        proto = torch.mean(support_rep, 2) # Calculate prototype for each class
        logits = -self.__batch_dist__(proto, query)
        _, pred = torch.max(logits.view(-1, N), 1)
    
        self._accuracy = self.__accuracy__(pred.view(-1), label.view(-1))
    
        return logits, pred
    
    #    def forward_few_shot(self, support, query, label, B, N, K, Q):
    #        for b in range(B):
    #            for n in range(N):
    #                _forward_train(self, support_pos, None, query, distant, threshold=0.5):
    #
    #        '''
    #        support_rep = self.encode(support, self.args.infer_batch_size)
    #        query_rep = self.encode(query, self.args.infer_batch_size)
    #        support_rep.view(B, N, K, -1)
    #        query_rep.view(B, N * Q, -1)
    #        '''
    #        
    #        proto = []
    #        for b in range(B):
    #            for N in range(N)
    #        
    #        NQ = N * Q
    #         
    #        # Prototypical Networks 
    #        proto = torch.mean(support_rep, 2) # Calculate prototype for each class
    #        logits = -self.__batch_dist__(proto, query)
    #        _, pred = torch.max(logits.view(-1, N), 1)
    #
    #        self._accuracy = self.__accuracy__(pred.view(-1), label.view(-1))
    #
    #        return logits, pred
    
    def _train_finetune_init(self):
        # init variables and optimizer
        self.new_W = Variable(self.fc.weight.mean(0) / 1e3, requires_grad=True)
        self.new_bias = Variable(torch.zeros((1)), requires_grad=True)
        self.optimizer = optim.Adam([self.new_W, self.new_bias], self.args.finetune_lr, weight_decay=self.args.finetune_wd)
        self.new_W = self.new_W.cuda()
        self.new_bias = self.new_bias.cuda()
    
    # training (fine-tuning) of the relation classifier; only new_W and new_bias are updated
    def _train_finetune(self, data_repre, learning_rate=None, weight_decay=1e-5):
        '''
        train finetune classifier with given data
        data_repre: sentence representation (encoder's output)
        label: label
        '''
        
        self.train()
    
        optimizer = self.optimizer
        if learning_rate is not None:
            optimizer = optim.Adam([self.new_W, self.new_bias], learning_rate, weight_decay=weight_decay)
    
        # hyperparameters
        max_epoch = self.args.finetune_epoch
        batch_size = self.args.finetune_batch_size
        
        # dropout
        data_repre = self.drop(data_repre) 
        
        # train
        if self.args.print_debug:
            print('')
        for epoch in range(max_epoch):
            max_iter = data_repre.size(0) // batch_size
            if data_repre.size(0) % batch_size != 0:
                max_iter += 1
            order = list(range(data_repre.size(0)))
            random.shuffle(order)
            for i in range(max_iter):            
                x = data_repre[order[i * batch_size : min((i + 1) * batch_size, data_repre.size(0))]]
                # batch_label = label[order[i * batch_size : min((i + 1) * batch_size, data_repre.size(0))]]
                
                # neg sampling
                # ---------------------
                batch_label = torch.ones((x.size(0))).long().cuda()
                neg_size = int(x.size(0) * 1)
                neg = self.neg_loader.next_batch(neg_size)
                neg = self.encode(neg, self.args.infer_batch_size)
                x = torch.cat([x, neg], 0)
                batch_label = torch.cat([batch_label, torch.zeros((neg_size)).long().cuda()], 0)
                # ---------------------
    
                # Relation Classifier
                x = torch.matmul(x, self.new_W) + self.new_bias # (batch_size, 1)
                x = torch.sigmoid(x)
    
                # iter_loss = self.__loss__(x, batch_label.float()).mean()
                weight = torch.ones(batch_label.size(0)).float().cuda()
                weight[batch_label == 0] = self.args.finetune_weight #1 / float(max_epoch)
                iter_loss = (self.__loss__(x, batch_label.float()) * weight).mean()
    
                optimizer.zero_grad()
                iter_loss.backward(retain_graph=True)
                optimizer.step()
                if self.args.print_debug:
                    sys.stdout.write('[snowball finetune] epoch {0:4} iter {1:4} | loss: {2:2.6f}'.format(epoch, i, iter_loss) + '\r')
                    sys.stdout.flush()
        self.eval()
    
    def _add_ins_to_data(self, dataset_dst, dataset_src, ins_id, label=None):
        '''
        add one instance from dataset_src to dataset_dst (list)
        dataset_dst: destination dataset
        dataset_src: source dataset
        ins_id: id of the instance
        '''
        dataset_dst['word'].append(dataset_src['word'][ins_id])
        if 'pos1' in dataset_src:
            dataset_dst['pos1'].append(dataset_src['pos1'][ins_id])
            dataset_dst['pos2'].append(dataset_src['pos2'][ins_id])
        dataset_dst['mask'].append(dataset_src['mask'][ins_id])
        if 'id' in dataset_dst and 'id' in dataset_src:
            dataset_dst['id'].append(dataset_src['id'][ins_id])
        if 'entpair' in dataset_dst and 'entpair' in dataset_src:
            dataset_dst['entpair'].append(dataset_src['entpair'][ins_id])
        if 'label' in dataset_dst and label is not None:
            dataset_dst['label'].append(label)
    
    def _add_ins_to_vdata(self, dataset_dst, dataset_src, ins_id, label=None):
        '''
        add one instance from dataset_src to dataset_dst (variable)
        dataset_dst: destination dataset
        dataset_src: source dataset
        ins_id: id of the instance
        '''
        dataset_dst['word'] = torch.cat([dataset_dst['word'], dataset_src['word'][ins_id].unsqueeze(0)], 0)
        if 'pos1' in dataset_src:
            dataset_dst['pos1'] = torch.cat([dataset_dst['pos1'], dataset_src['pos1'][ins_id].unsqueeze(0)], 0)
            dataset_dst['pos2'] = torch.cat([dataset_dst['pos2'], dataset_src['pos2'][ins_id].unsqueeze(0)], 0)
        dataset_dst['mask'] = torch.cat([dataset_dst['mask'], dataset_src['mask'][ins_id].unsqueeze(0)], 0)
        if 'id' in dataset_dst and 'id' in dataset_src:
            dataset_dst['id'] = torch.cat([dataset_dst['id'], dataset_src['id'][ins_id].unsqueeze(0)], 0)
        if 'entpair' in dataset_dst and 'entpair' in dataset_src:
            dataset_dst['entpair'].append(dataset_src['entpair'][ins_id])
        if 'label' in dataset_dst and label is not None:
            dataset_dst['label'] = torch.cat([dataset_dst['label'], torch.ones((1)).long().cuda()], 0)
    
    def _dataset_stack_and_cuda(self, dataset):
        '''
        stack the dataset to torch.Tensor and use cuda mode
        dataset: target dataset
        '''
        if (len(dataset['word']) == 0):
            return
        dataset['word'] = torch.stack(dataset['word'], 0).cuda()
        if 'pos1' in dataset:
            dataset['pos1'] = torch.stack(dataset['pos1'], 0).cuda()
            dataset['pos2'] = torch.stack(dataset['pos2'], 0).cuda()
        dataset['mask'] = torch.stack(dataset['mask'], 0).cuda()
        dataset['id'] = torch.stack(dataset['id'], 0).cuda()
    
    def encode(self, dataset, batch_size=0):
        if self.pre_rep is not None:
            return self.pre_rep[dataset['id'].view(-1)]
    
        if batch_size == 0:
            x = self.sentence_encoder(dataset)
        else:
            total_length = dataset['word'].size(0)
            max_iter = total_length // batch_size
            if total_length % batch_size != 0:
                max_iter += 1
            x = []
            for it in range(max_iter):
                scope = list(range(batch_size * it, min(batch_size * (it + 1), total_length)))
                with torch.no_grad():
                    _ = {'word': dataset['word'][scope], 'mask': dataset['mask'][scope]}
                    if 'pos1' in dataset:
                        _['pos1'] = dataset['pos1'][scope]
                        _['pos2'] = dataset['pos2'][scope]
                    _x = self.sentence_encoder(_)
                x.append(_x.detach())
            x = torch.cat(x, 0)
        return x
    
    def _infer(self, dataset, batch_size=0):
        '''
        get prob output of the finetune network with the input dataset
        dataset: input dataset
        return: prob output of the finetune network
        '''
        x = self.encode(dataset, batch_size=batch_size) 
        x = torch.matmul(x, self.new_W) + self.new_bias # (batch_size, 1)
        x = torch.sigmoid(x)
        return x.view(-1)
    
    def _forward_train(self, support_pos, query, distant, threshold=0.5):
        '''
        snowball process (train)
        support_pos: support set (positive, raw data)
        support_neg: support set (negative, raw data)
        query: query set
        distant: distant data loader
        threshold: ins with prob > threshold will be classified as positive
        threshold_for_phase1: distant ins with prob > th_for_phase1 will be added to extended support set at phase1
        threshold_for_phase2: distant ins with prob > th_for_phase2 will be added to extended support set at phase2
        '''
    
        # hyperparameters
        snowball_max_iter = self.args.snowball_max_iter
        sys.stdout.flush()
        candidate_num_class = 20
        candidate_num_ins_per_class = 100
        
        sort_num1 = self.args.phase1_add_num
        sort_num2 = self.args.phase2_add_num
        sort_threshold1 = self.args.phase1_siamese_th
        sort_threshold2 = self.args.phase2_siamese_th
        sort_ori_threshold = self.args.phase2_cl_th
    
        # get neg representations with sentence encoder
        # support_neg_rep = self.encode(support_neg, batch_size=self.args.infer_batch_size)
        
        # init
        self._train_finetune_init()
        # support_rep = self.encode(support, self.args.infer_batch_size)
    
        # encode the positive raw support data into representations
        support_pos_rep = self.encode(support_pos, self.args.infer_batch_size)
        # self._train_finetune(support_rep, support['label'])
        self._train_finetune(support_pos_rep)
    
        self._metric = []
    
        # copy
        original_support_pos = copy.deepcopy(support_pos)
    
        # snowball
        exist_id = {}
        if self.args.print_debug:
            print('\n-------------------------------------------------------')
        for snowball_iter in range(snowball_max_iter):
            if self.args.print_debug:
                print('###### snowball iter ' + str(snowball_iter))
            # phase 1: expand positive support set from distant dataset (with same entity pairs)
    
            ## get all entpairs and their ins in positive support set   ins is instance
            old_support_pos_label = support_pos['label'] + 0
            entpair_support = {}
            entpair_distant = {}
            for i in range(len(support_pos['id'])): # only positive support
                entpair = support_pos['entpair'][i]  # entity pair of this instance
                exist_id[support_pos['id'][i]] = 1
                if entpair not in entpair_support:
                    if 'pos1' in support_pos:
                        entpair_support[entpair] = {'word': [], 'pos1': [], 'pos2': [], 'mask': [], 'id': []}
                    else:
                        entpair_support[entpair] = {'word': [], 'mask': [], 'id': []}
                self._add_ins_to_data(entpair_support[entpair], support_pos, i)
            
            ## pick all ins with the same entpairs in distant data and choose with siamese network
            self._phase1_add_num = 0 # total number of snowball instances
            self._phase1_total = 0
            for entpair in entpair_support:
                raw = distant.get_same_entpair_ins(entpair) # ins with the same entpair
                if raw is None:
                    continue
                if 'pos1' in support_pos:   # store the distant candidates for this entity pair as a dict of fields
                    entpair_distant[entpair] = {'word': [], 'pos1': [], 'pos2': [], 'mask': [], 'id': [], 'entpair': []}
                else:
                    entpair_distant[entpair] = {'word': [], 'mask': [], 'id': [], 'entpair': []}
                for i in range(raw['word'].size(0)):
                    if raw['id'][i] not in exist_id: # don't pick sentences already in the support set
                        self._add_ins_to_data(entpair_distant[entpair], raw, i)
                self._dataset_stack_and_cuda(entpair_support[entpair])
                self._dataset_stack_and_cuda(entpair_distant[entpair])
                if len(entpair_support[entpair]['word']) == 0 or len(entpair_distant[entpair]['word']) == 0:
                    continue
    
                # use the siamese RSN to compare support and distant instances sharing this entity pair and decide which C1 candidates to keep
                pick_or_not = self.siamese_model.forward_infer_sort(entpair_support[entpair], entpair_distant[entpair], batch_size=self.args.infer_batch_size)
                
                # pick_or_not = self.siamese_model.forward_infer_sort(original_support_pos, entpair_distant[entpair], threshold=threshold_for_phase1)
                # pick_or_not = self._infer(entpair_distant[entpair]) > threshold
      
                # -- method B: use sort --
                for i in range(min(len(pick_or_not), sort_num1)):
                    if pick_or_not[i][0] > sort_threshold1:
                        iid = pick_or_not[i][1]
                        self._add_ins_to_vdata(support_pos, entpair_distant[entpair], iid, label=1)
                        exist_id[entpair_distant[entpair]['id'][iid]] = 1
                        self._phase1_add_num += 1
                self._phase1_total += entpair_distant[entpair]['word'].size(0)
            '''
            if 'pos1' in support_pos:
                candidate = {'word': [], 'pos1': [], 'pos2': [], 'mask': [], 'id': [], 'entpair': []}
            else:
                candidate = {'word': [], 'mask': [], 'id': [], 'entpair': []}
    
            self._phase1_add_num = 0 # total number of snowball instances
            self._phase1_total = 0
            for entpair in entpair_support:
                raw = distant.get_same_entpair_ins(entpair) # ins with the same entpair
                if raw is None:
                    continue
                for i in range(raw['word'].size(0)):
                    if raw['id'][i] not in exist_id: # don't pick sentences already in the support set
                        self._add_ins_to_data(candidate, raw, i)
    
            if len(candidate['word']) > 0:
                self._dataset_stack_and_cuda(candidate)
                pick_or_not = self.siamese_model.forward_infer_sort(support_pos, candidate, batch_size=self.args.infer_batch_size)
                    
                for i in range(min(len(pick_or_not), sort_num1)):
                    if pick_or_not[i][0] > sort_threshold1:
                        iid = pick_or_not[i][1]
                        self._add_ins_to_vdata(support_pos, candidate, iid, label=1)
                        exist_id[candidate['id'][iid]] = 1
                        self._phase1_add_num += 1
                self._phase1_total += candidate['word'].size(0)
            '''
            ## build new support set
            
            # print('---')
            # for i in range(len(support_pos['entpair'])):
            #     print(support_pos['entpair'][i])
            # print('---')
            # print('---')
            # for i in range(support_pos['id'].size(0)):
            #     print(support_pos['id'][i])
            # print('---')
    
            support_pos_rep = self.encode(support_pos, batch_size=self.args.infer_batch_size)
            # support_rep = torch.cat([support_pos_rep, support_neg_rep], 0)
            # support_label = torch.cat([support_pos['label'], support_neg['label']], 0)
            
            ## finetune
            # print("Fine-tune Init")
            self._train_finetune_init()
            self._train_finetune(support_pos_rep)
            if self.args.eval:
                self._forward_eval_binary(query, threshold)
            # self._metric.append(np.array([self._f1, self._prec, self._recall]))
            if self.args.print_debug:
                print('\nphase1 add {} ins / {}'.format(self._phase1_add_num, self._phase1_total))
    
            # phase 2: use the new classifier to pick more extended support ins
            self._phase2_add_num = 0
            candidate = distant.get_random_candidate(self.pos_class, candidate_num_class, candidate_num_ins_per_class)
    
            ## -- method 1: directly use the classifier --
            candidate_prob = self._infer(candidate, batch_size=self.args.infer_batch_size)
            ## -- method 2: use siamese network --
    
            pick_or_not = self.siamese_model.forward_infer_sort(support_pos, candidate, batch_size=self.args.infer_batch_size)
    
            ## -- method A: use threshold --
            '''
            self._phase2_total = candidate_prob.size(0)
            for i in range(candidate_prob.size(0)):
                # if (candidate_prob[i] > threshold_for_phase2) and not (candidate['id'][i] in exist_id):
                if (pick_or_not[i]) and (candidate_prob[i] > threshold_for_phase2) and not (candidate['id'][i] in exist_id):
                    exist_id[candidate['id'][i]] = 1 
                    self._phase2_add_num += 1
                    self._add_ins_to_vdata(support_pos, candidate, i, label=1)
            '''
    
            ## -- method B: use sort --
            self._phase2_total = candidate['word'].size(0)
            for i in range(min(len(candidate_prob), sort_num2)):
                iid = pick_or_not[i][1]
                if (pick_or_not[i][0] > sort_threshold2) and (candidate_prob[iid] > sort_ori_threshold) and not (candidate['id'][iid] in exist_id):
                    exist_id[candidate['id'][iid]] = 1 
                    self._phase2_add_num += 1
                    self._add_ins_to_vdata(support_pos, candidate, iid, label=1)
    
            ## build new support set
            support_pos_rep = self.encode(support_pos, self.args.infer_batch_size)
            # support_rep = torch.cat([support_pos_rep, support_neg_rep], 0)
            # support_label = torch.cat([support_pos['label'], support_neg['label']], 0)
    
            ## finetune
            # print("Fine-tune Init")
            self._train_finetune_init()
            self._train_finetune(support_pos_rep)
            if self.args.eval:
                self._forward_eval_binary(query, threshold)
                self._metric.append(np.array([self._f1, self._prec, self._recall]))
                if self.args.print_debug:
                    print('\nphase2 add {} ins / {}'.format(self._phase2_add_num, self._phase2_total))
    
        self._forward_eval_binary(query, threshold)
        if self.args.print_debug:
            print('\nphase2 add {} ins / {}'.format(self._phase2_add_num, self._phase2_total))
    
        return support_pos_rep
    
    def _forward_eval_binary(self, query, threshold=0.5):
        '''
        snowball process (eval)
        query: query set (raw data)
        threshold: ins with prob > threshold will be classified as positive
        return (accuracy at threshold, precision at threshold, recall at threshold, f1 at threshold, auc), 
        '''
        query_prob = self._infer(query, batch_size=self.args.infer_batch_size).cpu().detach().numpy()
        label = query['label'].cpu().detach().numpy()
        accuracy = float(np.logical_or(np.logical_and(query_prob > threshold, label == 1), np.logical_and(query_prob < threshold, label == 0)).sum()) / float(query_prob.shape[0])
        if (query_prob > threshold).sum() == 0:
            precision = 0
        else:
            precision = float(np.logical_and(query_prob > threshold, label == 1).sum()) / float((query_prob > threshold).sum())
        recall = float(np.logical_and(query_prob > threshold, label == 1).sum()) / float((label == 1).sum())
        if precision + recall == 0:
            f1 = 0
        else:
            f1 = float(2.0 * precision * recall) / float(precision + recall)
        auc = sklearn.metrics.roc_auc_score(label, query_prob)
        if self.args.print_debug:
            print('')
            sys.stdout.write('[EVAL] acc: {0:2.2f}%, prec: {1:2.2f}%, rec: {2:2.2f}%, f1: {3:1.3f}, auc: {4:1.3f}'.format(\
                    accuracy * 100, precision * 100, recall * 100, f1, auc) + '\r')
            sys.stdout.flush()
        self._accuracy = accuracy
        self._prec = precision
        self._recall = recall
        self._f1 = f1
        return (accuracy, precision, recall, f1, auc)
    
    def forward(self, support_pos, query, distant, pos_class, threshold=0.5, threshold_for_snowball=0.5):
        '''
        snowball process (train + eval)
        support_pos: support set (positive, raw data)
        support_neg: support set (negative, raw data)
        query: query set (raw data)
        distant: distant data loader
        pos_class: positive relation (name)
        threshold: ins with prob > threshold will be classified as positive
        threshold_for_snowball: distant ins with prob > th_for_snowball will be added to extended support set
        '''
        self.pos_class = pos_class 
    
        self._forward_train(support_pos, query, distant, threshold=threshold)
    
    def init_10shot(self, Ws, bs):
        self.Ws = torch.stack(Ws, 0).transpose(0, 1) # (230, 16)
        self.bs = torch.stack(bs, 0).transpose(0, 1) # (1, 16)
    
    def eval_10shot(self, query):
        x = self.sentence_encoder(query)
        x = torch.matmul(x, self.Ws) + self.new_bias # (batch_size, 16)
        x = torch.sigmoid(x)
        _, pred = x.max(-1) # (batch_size)
        return self.__accuracy__(pred, query['label'])

Reference

AI科技评论: AAAI 2020 | Tsinghua University: Neural Snowball Mechanism for Few-Shot Relation Learning
