
Implementing BiLSTM+CRF in TensorFlow 1.x


This post gives a complete walkthrough of a Chinese named entity recognition (NER) task built on a BiLSTM-CRF model, covering:

  • The role of the CRF layer: why it matters for sequence labeling, and how decoding is done in TensorFlow via tf.contrib.crf.crf_decode.
  • The BiLSTM-CRF model structure and its components: the embedding layer (LookupTable, vocabulary lookup), the bidirectional LSTM (BiLstm), the fully connected layer (Linear), and the CRF transition parameters (crf_param).
  • Data preprocessing: building the vocabulary, converting text into id sequences stored as a pkl file, and splitting and formatting the data.
  • Dataset construction: a DataBuilder class that creates and manages the datasets.
  • Model training and evaluation with the TensorFlow Estimator API: defining input functions, optimizing with Adam, computing the loss and accuracy metrics, and producing predictions.
  • Export and deployment: exporting the trained model in pb and ckpt formats.
  • A main-script example and instructions for running the project.

In short, this is a complete BiLSTM-CRF NER project, documented and implemented in TensorFlow.

In the earlier posts in this series, we gave an overview of what the CRF layer does and walked through the design and implementation of its loss function:

The CRF Layer in BiLSTM (1): Introduction

The CRF Layer in BiLSTM (2): The CRF Layer

The CRF Layer in BiLSTM (3): The CRF Loss Function

The model is trained and tuned on the challenge dataset for entity recognition in traditional Chinese medicine (TCM) drug instructions from the "Wanchuang Cup" TCM big-data competition, and is used to carry out the TCM named entity recognition (NER) task.

1. The BiLSTM+CRF model

This file (bilstm_crf.py) defines the model components: the embedding layer, the BiLSTM layer, the fully connected layer, and the CRF layer.

```python
# -*- coding: utf-8 -*-
# @Time    : 2020-10-09 21:15
# @Author  : xudong
# @email   : dongxu222mk@163.com
# @Site    :
# @File    : bilstm_crf.py
# @Software: PyCharm

import tensorflow as tf
from tensorflow.contrib.rnn import LSTMCell
from tensorflow.contrib.rnn import MultiRNNCell


class Linear:
    """
    Fully connected layer
    """
    def __init__(self, scope_name, input_size, output_size,
                 drop_out=0., trainable=True):
        with tf.variable_scope(scope_name):
            self.W = tf.get_variable('W', [input_size, output_size],
                                     initializer=tf.random_uniform_initializer(-0.25, 0.25),
                                     trainable=trainable)
            self.b = tf.get_variable('b', [output_size],
                                     initializer=tf.zeros_initializer(),
                                     trainable=trainable)

        self.drop_out = tf.layers.Dropout(drop_out)
        self.output_size = output_size

    def __call__(self, inputs, training):
        # flatten to 2-D, apply W*x + b, then restore the time dimension
        size = tf.shape(inputs)
        input_trans = tf.reshape(inputs, [-1, size[-1]])
        input_trans = tf.nn.xw_plus_b(input_trans, self.W, self.b)
        input_trans = self.drop_out(input_trans, training=training)

        input_trans = tf.reshape(input_trans, [-1, size[1], self.output_size])

        return input_trans


class LookupTable:
    """
    Embedding layer
    """
    def __init__(self, scope_name, vocab_size, embed_size, reuse=False, trainable=True):
        self.vocab_size = vocab_size
        self.embed_size = embed_size

        with tf.variable_scope(scope_name, reuse=bool(reuse)):
            self.embedding = tf.get_variable('embedding', [vocab_size, embed_size],
                                             initializer=tf.random_uniform_initializer(-0.25, 0.25),
                                             trainable=trainable)

    def __call__(self, input):
        # map any id outside the vocabulary to 1 (UNK)
        input = tf.where(tf.less(input, self.vocab_size), input, tf.ones_like(input))
        return tf.nn.embedding_lookup(self.embedding, input)


class LstmBase:
    """
    Build an rnn cell
    """
    def build_rnn(self, hidden_size, num_layers):
        cells = []
        for i in range(num_layers):
            cell = LSTMCell(num_units=hidden_size,
                            state_is_tuple=True,
                            initializer=tf.random_uniform_initializer(-0.25, 0.25))
            cells.append(cell)
        cells = MultiRNNCell(cells, state_is_tuple=True)

        return cells


class BiLstm(LstmBase):
    """
    Bidirectional LSTM
    """
    def __init__(self, scope_name, hidden_size, num_layers):
        super(BiLstm, self).__init__()
        assert hidden_size % 2 == 0
        hidden_size //= 2  # each direction gets half of the hidden size

        self.fw_rnns = []
        self.bw_rnns = []
        for i in range(num_layers):
            self.fw_rnns.append(self.build_rnn(hidden_size, 1))
            self.bw_rnns.append(self.build_rnn(hidden_size, 1))

        self.scope_name = scope_name

    def __call__(self, input, input_len):
        for idx, (fw_rnn, bw_rnn) in enumerate(zip(self.fw_rnns, self.bw_rnns)):
            scope_name = '{}_{}'.format(self.scope_name, idx)
            ctx, _ = tf.nn.bidirectional_dynamic_rnn(
                fw_rnn, bw_rnn, input, sequence_length=input_len,
                dtype=tf.float32, time_major=False,
                scope=scope_name
            )
            # concatenate forward and backward outputs on the feature axis
            input = tf.concat(ctx, -1)
        ctx = input
        return ctx


class BiLstm_Crf:
    def __init__(self, args, vocab_size, emb_size):
        # embedding
        scope_name = 'look_up'
        self.lookuptables = LookupTable(scope_name, vocab_size, emb_size)

        # rnn
        scope_name = 'bi_lstm'
        self.rnn = BiLstm(scope_name, args.hidden_dim, 1)

        # linear
        scope_name = 'linear'
        self.linear = Linear(scope_name, args.hidden_dim, args.num_tags,
                             drop_out=args.drop_out)

        # crf
        scope_name = 'crf_param'
        self.crf_param = tf.get_variable(scope_name, [args.num_tags, args.num_tags],
                                         dtype=tf.float32)

    def __call__(self, inputs, training):
        # id 0 is PAD, so the sign of the ids yields a 0/1 mask
        masks = tf.sign(inputs)
        sent_len = tf.reduce_sum(masks, axis=1)

        embedding = self.lookuptables(inputs)

        rnn_out = self.rnn(embedding, sent_len)

        logits = self.linear(rnn_out, training)

        # Viterbi decoding with the learned transition matrix
        pred_ids, _ = tf.contrib.crf.crf_decode(logits, self.crf_param, sent_len)

        return logits, pred_ids, self.crf_param
```
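To make the pieces above concrete, here is a minimal usage sketch (not part of the original project; the SimpleNamespace is a stand-in for the parsed command-line arguments that ner_main.py passes in later):

```python
from types import SimpleNamespace
import tensorflow as tf

# hypothetical hyper-parameters, mirroring the defaults used in ner_main.py
args = SimpleNamespace(hidden_dim=300, num_tags=27, drop_out=0.1)

model = BiLstm_Crf(args, vocab_size=2500, emb_size=300)
word_ids = tf.placeholder(tf.int64, [None, None])  # [batch, max_len], 0 = PAD

# logits feed the CRF loss at training time; pred_ids are the decoded tags
logits, pred_ids, crf_param = model(word_ids, training=False)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    tags = sess.run(pred_ids, feed_dict={word_ids: [[5, 17, 4, 0, 0]]})
    print(tags.shape)  # (1, 5): one decoded tag id per input position
```

Note how the model derives the true sequence lengths itself: id 0 is reserved for PAD, so tf.sign of the inputs gives a 0/1 mask whose row sums are the sentence lengths.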

2. Data preprocessing

This file handles data preprocessing: it maps every token in the raw text to a unique integer id and stores the result in pkl format.
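The corpus files themselves are not shown in this post, but the code below implies their format: each line of train.txt (and test.txt) holds one sample, with the text column and the tag column separated by a tab and the items within each column joined by the <#> separator. A hypothetical line could look like this (the entities are made up for illustration):

```
丹<#>参<#>。<#>活<#>血	B-DRUG<#>I-DRUG<#>O<#>B-DRUG_EFFICACY<#>I-DRUG_EFFICACY
```

make_data additionally splits each text on the <#>。<#> sentence boundary, emitting one id sequence per sentence and skipping the tag of the 。 token itself.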

```python
# -*- coding: utf-8 -*-
# @Time    : 2020-10-11 18:52
# @Author  : xudong
# @email   : dongxu222mk@163.com
# @Site    :
# @File    : preprocess.py
# @Software: PyCharm
import os
import _pickle as cPickle
import pandas as pd
import random

"""
Data preprocessing:
convert the text to ids and package everything into a pkl file
"""

tag_list = ['DRUG', 'DRUG_INGREDIENT',
            'DISEASE', 'SYMPTOM',
            'SYNDROME', 'DISEASE_GROUP',
            'FOOD_GROUP', 'FOOD',
            'PERSON_GROUP', 'DRUG_GROUP',
            'DRUG_DOSAGE', 'DRUG_TASTE',
            'DRUG_EFFICACY']
tag_dict = {'O': 0}

# build the BIO tag dictionary: a B- and an I- tag for each entity type
for tag in tag_list:
    tag_B = 'B-' + tag
    tag_I = 'I-' + tag
    tag_dict[tag_B] = len(tag_dict)
    tag_dict[tag_I] = len(tag_dict)

print(tag_dict)


def make_vocab(file_path):
    """
    Build the vocabulary
    :param file_path:
    :return:
    """
    data = pd.read_csv(file_path, sep='\t', header=None)
    data.columns = ['text', 'tag']
    vocab = {'PAD': 0, 'UNK': 1}
    words_list = []
    for index, row in data.iterrows():
        text = row['text']
        words = text.split('<#>')
        for word in words:
            words_list.append(word)

    # assign ids in shuffled rather than corpus order
    random.shuffle(words_list)
    for word in words_list:
        if word not in vocab:
            vocab[word] = len(vocab)
    return vocab


def make_data(file_path, vocab):
    """
    Build the id data
    :param file_path:
    :param vocab:
    :return:
    """
    data = pd.read_csv(file_path, sep='\t', header=None)
    data.columns = ['text', 'tag']
    word_ids = []
    tag_ids = []
    for index, row in data.iterrows():
        text = row['text']
        tag_str = row['tag']

        tags = tag_str.split('<#>')
        # TODO: should be split further on commas as well
        words_sep = text.split('<#>。<#>')

        cnt = 0
        for word_text in words_sep:
            words = word_text.split('<#>')
            word_id = [vocab.get(word, 1) for word in words]  # 1 = UNK
            tag_id = [tag_dict.get(tag) for tag in tags[cnt:cnt + len(words)]]

            word_ids.append(word_id)
            tag_ids.append(tag_id)
            cnt = cnt + len(words) + 1  # +1 skips the '。' separator token

    return {'words': word_ids, 'tags': tag_ids}


def save_vocab(vocab, output):
    """
    Save the vocab dict
    :param vocab:
    :param output:
    :return:
    """
    with open(output, 'w', encoding='utf-8') as fr:
        for word in vocab:
            fr.write(word + '\t' + str(vocab.get(word)) + '\n')
    print('save vocab is ok.')


def main(output_path):
    """
    main method
    :param output_path:
    :return:
    """
    data = {}
    train_path = './data_path/train.txt'
    test_path = './data_path/test.txt'
    vocab = make_vocab(train_path)
    train_data = make_data(train_path, vocab)
    test_data = make_data(test_path, vocab)

    data['train'] = train_data
    data['test'] = test_data

    data_path = os.path.join(output_path, 'ner_data.pkl')
    cPickle.dump(data, open(data_path, 'wb'), protocol=2)
    print('save data to pkl ok.')

    vocab_path = os.path.join(output_path, 'ner_vocab.txt')
    save_vocab(vocab, vocab_path)


if __name__ == '__main__':
    output = './data_path/'
    main(output)
    data = cPickle.load(open('./data_path/ner_data.pkl', 'rb'))

    print(data['train']['words'][0])
    print(data['train']['tags'][0])
```
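One detail worth verifying: tag_dict ends up with 13 entity types × 2 (B-/I-) plus the single O tag, i.e. 27 labels in total, which matches the default of the --num_tags flag used during training below. A one-line sanity check that could be appended to preprocess.py:

```python
# 13 entity types, each with a B- and an I- variant, plus 'O'
assert len(tag_dict) == 2 * len(tag_list) + 1 == 27
```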

3. Dataset builder class

This class packages the preprocessed data into the tf.data format the model expects during training and evaluation.

```python
# -*- coding: utf-8 -*-
# @Time    : 2020-10-09 21:18
# @Author  : xudong
# @email   : dongxu222mk@163.com
# @Site    :
# @File    : datasets.py
# @Software: PyCharm
import numpy as np
import tensorflow as tf

"""
Dataset builder class
"""


class DataBuilder:
    def __init__(self, data):
        self.words = np.asarray(data['words'])
        self.tags = np.asarray(data['tags'])

    @property
    def size(self):
        return len(self.words)

    def build_generator(self):
        """
        build data generator for the model
        :return:
        """
        for word, tag in zip(self.words, self.tags):
            yield (word, len(word)), tag

    def build_dataset(self):
        """
        build dataset from the generator
        :return:
        """
        dataset = tf.data.Dataset.from_generator(
            self.build_generator,
            ((tf.int64, tf.int64), tf.int64),
            ((tf.TensorShape([None]), tf.TensorShape([])), tf.TensorShape([None]))
        )
        return dataset

    def get_train_batch(self, dataset, batch_size, epoch):
        """
        get one batch of training data
        :param dataset:
        :param batch_size:
        :param epoch:
        :return:
        """
        dataset = dataset.cache()\
            .shuffle(buffer_size=10000)\
            .padded_batch(batch_size, padded_shapes=(([None], []), [None]))\
            .repeat(epoch)
        return dataset.make_one_shot_iterator().get_next()

    def get_test_batch(self, dataset, batch_size):
        """
        get one batch of test data
        :param dataset:
        :param batch_size:
        :return:
        """
        dataset = dataset.padded_batch(batch_size,
                                       padded_shapes=(([None], []), [None]))
        return dataset.make_one_shot_iterator().get_next()
```
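A small sketch of how the generator and the padded batches fit together (the toy dict stands in for the real pkl contents; the ids are made up):

```python
import numpy as np
import tensorflow as tf

toy = {'words': [[5, 17, 4], [9, 2]], 'tags': [[1, 2, 0], [0, 0]]}
builder = DataBuilder(toy)

batch = builder.get_test_batch(builder.build_dataset(), batch_size=2)
with tf.Session() as sess:
    (words, lens), tags = sess.run(batch)
    print(words)  # [[ 5 17  4] [ 9  2  0]] -- padded with 0, the PAD id
    print(lens)   # [3 2]
    print(tags)   # [[1 2 0] [0 0 0]] -- tag padding 0 coincides with 'O'
```

padded_batch pads with zeros by default, which lines up neatly with PAD = 0 in the vocab and O = 0 in tag_dict; the sequence_mask built from the lengths in model_fn keeps the padded positions out of the accuracy metric, and crf_log_likelihood receives the true lengths.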

4. Model training

This script first defines the model parameters and the training flow, then trains and evaluates the model and saves it completely, both as ckpt checkpoints and as an exported pb (SavedModel) for deployment.

```python
# -*- coding: utf-8 -*-
# @Time    : 2020-10-09 23:07
# @Author  : xudong
# @email   : dongxu222mk@163.com
# @Site    :
# @File    : ner_main.py
# @Software: PyCharm
import sys
import os
import time
import tensorflow as tf
from data_utils import datasets

import _pickle as cPickle

from argparse import ArgumentParser
from models.bilstm_crf import BiLstm_Crf

parser = ArgumentParser()

parser.add_argument("--vocab_size", type=int, default=2500, help='vocab size')
parser.add_argument("--emb_size", type=int, default=300, help='emb size')
parser.add_argument("--train_path", type=str, default='./data_path/ner_data.pkl')
parser.add_argument("--test_path", type=str, default='./data_path/ner_data.pkl')
parser.add_argument("--model_dir", type=str, default='./model_ckpt/')
parser.add_argument("--model_export", type=str, default='./model_pb')
parser.add_argument("--hidden_dim", type=int, default=300)
parser.add_argument("--num_tags", type=int, default=27)
parser.add_argument("--drop_out", type=float, default=0.1)
parser.add_argument("--batch_size", type=int, default=16)
parser.add_argument("--epoch", type=int, default=1)
parser.add_argument("--type", type=str, default='lstm', help='[lstm,textcnn...]')


tf.logging.set_verbosity(tf.logging.INFO)
ARGS, unparsed = parser.parse_known_args()
print(ARGS)

sys.stdout.flush()


def init_data(file_name, type=None):
    """
    init data
    :param file_name:
    :param type: 'train' or 'test'
    :return:
    """
    data = cPickle.load(open(file_name, 'rb'))[type]

    data_builder = datasets.DataBuilder(data)
    dataset = data_builder.build_dataset()

    def train_input():
        return data_builder.get_train_batch(dataset, ARGS.batch_size, ARGS.epoch)

    def test_input():
        return data_builder.get_test_batch(dataset, ARGS.batch_size)

    return train_input if type == 'train' else test_input


def make_model():
    """
    build model
    :return:
    """
    vocab_size = ARGS.vocab_size
    emb_size = ARGS.emb_size

    if ARGS.type == 'lstm':
        model = BiLstm_Crf(ARGS, vocab_size, emb_size)
    else:
        raise ValueError('unsupported model type: %s' % ARGS.type)

    return model


def model_fn(features, labels, mode, params):
    """
    build model fn
    :return:
    """
    model = make_model()

    if isinstance(features, dict):  # serving inputs arrive as a dict
        features = features['words'], features['words_len']

    words, words_len = features

    if mode == tf.estimator.ModeKeys.PREDICT:
        _, pred_ids, _ = model(words, training=False)

        prediction = {'tag_ids': tf.identity(pred_ids, name='tag_ids')}

        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=prediction,
            export_outputs={'classify': tf.estimator.export.PredictOutput(prediction)}
        )
    else:
        tags = labels
        weights = tf.sequence_mask(words_len)  # keep padded positions out of the metrics
        if mode == tf.estimator.ModeKeys.TRAIN:
            logits, pred_ids, crf_params = model(words, training=True)

            log_likelihood, _ = tf.contrib.crf.crf_log_likelihood(
                logits, tags, words_len, crf_params
            )
            loss = -tf.reduce_mean(log_likelihood)  # negative log-likelihood
            accuracy = tf.metrics.accuracy(tags, pred_ids, weights)

            tf.identity(accuracy[1], name='train_accuracy')
            tf.summary.scalar('train_accuracy', accuracy[1])
            optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
            return tf.estimator.EstimatorSpec(
                mode=mode,
                loss=loss,
                train_op=optimizer.minimize(loss, tf.train.get_or_create_global_step())
            )
        else:
            _, pred_ids, _ = model(words, training=False)
            accuracy = tf.metrics.accuracy(tags, pred_ids, weights)
            metrics = {
                'accuracy': accuracy
            }
            return tf.estimator.EstimatorSpec(
                mode=mode,
                # dummy float loss; evaluation only reports the accuracy metric
                loss=tf.constant(0, dtype=tf.float32),
                eval_metric_ops=metrics
            )


def main_es(unparsed):
    """
    main method
    :param unparsed:
    :return:
    """
    cur_time = time.time()
    model_dir = ARGS.model_dir + str(int(cur_time))

    classifier = tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=model_dir,
        params={}
    )

    # train
    train_input = init_data(ARGS.train_path, 'train')
    tensors_to_log = {'train_accuracy': 'train_accuracy'}
    logging_hook = tf.train.LoggingTensorHook(tensors=tensors_to_log, every_n_iter=100)
    classifier.train(input_fn=train_input, hooks=[logging_hook])

    # eval
    test_input = init_data(ARGS.test_path, 'test')
    eval_res = classifier.evaluate(input_fn=test_input)
    print(f'Evaluation res is : \n\t{eval_res}')

    # export a SavedModel (pb) for serving; the ckpt files live in model_dir
    if ARGS.model_export:
        words = tf.placeholder(tf.int64, [None, None], name='input_words')
        words_len = tf.placeholder(tf.int64, [None], name='input_len')
        input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
            'words': words,
            'words_len': words_len
        })
        path = os.path.join(ARGS.model_export, str(int(cur_time)))
        classifier.export_savedmodel(path, input_fn)


if __name__ == '__main__':
    tf.app.run(main=main_es, argv=[sys.argv[0]] + unparsed)
```
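Once exported, the SavedModel can be loaded back for inference. A minimal sketch using tf.contrib.predictor (the timestamped directory name is hypothetical; export_savedmodel prints the actual path it creates):

```python
import tensorflow as tf
from tensorflow.contrib import predictor

# hypothetical export path: ./model_pb/<timestamp>, created by export_savedmodel
predict_fn = predictor.from_saved_model('./model_pb/1602255000')

# ids produced with the vocab from preprocess.py; 0 = PAD
out = predict_fn({'words': [[5, 17, 4, 0, 0]], 'words_len': [3]})
print(out['tag_ids'])  # the Viterbi-decoded tag ids, shape (1, 5)
```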

The full code and data have been packaged and uploaded to the platform; you can get them here: 《基于BiLSTM+CRF的命名实体识别代码》 (BiLSTM+CRF named entity recognition code).

When I find the time I will also push it to GitHub; alternatively, leave your email address and I will send it to you directly~
