Advertisement

【论文阅读】GRAPH-BASED RECURRENT RETRIEVER

阅读量:

GRAPH-BASED RECURRENT RETRIEVER

a new graph-based recurrent retrieval method

查找证据文档作为回答复杂问题的推理路径。

w_i=BERT_{CLS}(q,p_i) \in \R^d

P(p_i|h_t)=\sigma(w_i·h_t+b)

h_{t=1}=RNN(h_t,w_i) \in \R^d

b:偏置项

  • 使用RNN建模问题Q的推理路径。
  • 给定问题q,在时间步t时,模型从候选段落集C_t中找出p_i ,与q拼接计算p_i的概率。
  • 遇到[EOE]时结束推理,允许它在给定每个问题的情况下捕获具有任意长度的推理路径。

本文BERT结构:

image-20210108150704041

RNN结构:

P(p_i|h_t):表示在时间步t选择段落p_i的概率。
image-20210108151355616

最终得到推理路径【p_1,p_2

beam search

  • 通过束搜索得到给定时间步长的有限数量的最可能推理路径,减小输入BERT的数据量,减小计算量。
  • C_1是用在输入问题上 TF-IDF 得分最高的段落。
  • C_t是在C_1基础上,拓展的连接段落,用输入到BERT。
  • 推理路径E = [p_i, . . . , p_k]乘段落概率P(p_i|h_1) . . . P(p_k|h_{|E|})得到beam search 的输出,即得到top B 推理路径 E = {E_1, . . . , E_B}作为BERT输入,再将BERT输出作为RNN输入。

BERT相关

Bidirectional Encoder Representations from Transformers

是Google以无监督的方式利用大量无标注文本训练的的语言代表模型,其架构为Transformer中的Encoder。

BERT 里5个特殊tokens:

  1. [CLS]:在做分类任务时其最后一层的repr. 会被视为整个输入序列的repr。

repr指的都是一个可以用来代表某词汇(在某个语境下)的多维连续向量(continuous vector)。

  1. [SEP]:有两个句子的文本会被串接成一个输入序列,并在两句之间插入这个token 以做区隔。
  2. [UNK]:没出现在BERT 字典里头的字会被这个token 取代。
  3. [PAD]:zero padding 遮罩,将长度不一的输入序列补齐方便做batch 运算。
  4. [MASK]:未知遮罩,仅在预训练阶段会用到。

代码实现

加载BERT预训练模型

复制代码
    model = BertForGraphRetriever.from_pretrained(args.bert_model,cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(-1),graph_retriever_config=graph_retriever_config)

默认从缓存中加载,下载之后源码中替换自己本地路径即可。

  • any() 函数用于判断给定的可迭代参数 iterable 是否全部为 False,则返回 False,如果有一个为 True,则返回 True。

使用BertAdam自定义Adam优化器

复制代码
    optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=t_total,
                             max_grad_norm=1.0)

在前10%的steps中,lr从0线性增加到 init_learning_rate,这个阶段又叫 warmup,然后,lr又从 init_learning_rate 线性衰减到0(完成所有steps)。

对问题和段落加上[CLS],[SEP]

复制代码
    def tokenize_question(question, tokenizer):
    tokens_q = tokenizer.tokenize(question)
    tokens_q = ['[CLS]'] + tokens_q + ['[SEP]']
    
    return tokens_q
    
    
    def tokenize_paragraph(p, tokens_q, max_seq_length, tokenizer):
    tokens_p = tokenizer.tokenize(p)[:max_seq_length - len(tokens_q) - 1]
    tokens_p = tokens_p + ['[SEP]']
    
    padding = [0] * (max_seq_length - len(tokens_p) - len(tokens_q))
    
    input_ids_ = tokenizer.convert_tokens_to_ids(tokens_q + tokens_p)
    input_masks_ = [1] * len(input_ids_)
    segment_ids_ = [0] * len(tokens_q) + [1] * len(tokens_p)
    
    input_ids_ += padding
    input_masks_ += padding
    segment_ids_ += padding
    
    assert len(input_ids_) == max_seq_length
    assert len(input_masks_) == max_seq_length
    assert len(segment_ids_) == max_seq_length
    
    return input_ids_, input_masks_, segment_ids_

RNN初始化

复制代码
    self.rw = nn.Linear(2 * config.hidden_size, config.hidden_size)

通过beam search 找出top B 推理路径

复制代码
    b = 0
    while b < beam:
    s, p = torch.max(score.view(score.size(0) * score.size(1)), dim=0)
    s = s.item()
    p = p.item()
    row = p // score.size(1)
    col = p % score.size(1)
    
    if j == 0:
        score[:, col] = 0.0
    else:
        score[row, col] = 0.0
    
    p = [[index for index in pred_[row][0]] + [col],
         output[row].topk(k=2, dim=0)[1].tolist(),
         s]
    new_pred_.append(p)
    
    p = [[p_ for p_ in prb] for prb in prob_[row]] + [output[row].tolist()]
    new_prob_.append(p)
    
    state_tmp[b].copy_(state_[row])
    b += 1

实验

下载程序:

复制代码
    !git clone https://github.com/AkariAsai/learning_to_retrieve_reasoning_paths.git
    %cd /content/learning_to_retrieve_reasoning_paths
    !pip install -r requirements.txt

下载数据集

复制代码
    %cd /content/learning_to_retrieve_reasoning_paths
    !mkdir data
    %cd data
    !mkdir hotpot
    %cd hotpot
    !gdown https://drive.google.com/uc?id=1AIRo66I2Izs80nNLt4MaLu7kqhTuIQ0u
    !unzip hotpotqa_new_selector_train_data_db_2017_10_12_fix.zip.zip____
    !rm hotpotqa_new_selector_train_data_db_2017_10_12_fix.zip.zip____

训练模型

复制代码
    %cd /content/learning_to_retrieve_reasoning_paths/graph_retriever
复制代码
    !python3 run_graph_retriever.py \--task hotpot_distractor \--bert_model bert-base-uncased --do_lower_case \--train_file_path /content/learning_to_retrieve_reasoning_paths/data/hotpot/hotpotqa_new_selector_train_data_db_2017_10_12_fix/db=wiki_hotpotqa.db_hotpotqa_new_test_tfidf_k=50.pruning_l=100_tag_me=True.prune_after_agg=False.prune_in_article=False_use_link=True_start=0_end=5000.json \--output_dir graph_retriever/ \--max_para_num 10 \--neg_chunk 8 --train_batch_size 4 --gradient_accumulation_steps 4 \--learning_rate 3e-5 --num_train_epochs 3 \--max_select_num 3

–max_para_num:与问题相关的段落数量。如果–max_para_num是n,问题的基础真实段落数量是k(2),那么有n-2个段落作为训练的反例。此时反例数量为8。

–neg_chunk:为了控制GPU内存消耗,将负例拆分为小块。

–max_select_num:指定模型推理步骤的最大数量,如果问题的基础真实段落数量是k,这个值应该为k+1,1表示结束符号EOE,此时k+1=3。

全部评论 (0)

还没有任何评论哟~