
Multi-label classification of Chinese text with Python + TensorFlow


The overall workflow breaks down into the following steps:

1: Load the dataset

2: Preprocess the data

3: Build the neural network model

4: Train the model

1: Load the dataset

Read the Chinese texts and their corresponding labels from a database or from an Excel/CSV table.
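As a minimal sketch of this step (the file name programs.csv and the title/tags column names are assumptions, not from the original project), the data could be loaded with pandas into the title_tags_dict structure that the preprocessing code below expects:

```python
import pandas as pd

# Hypothetical input: a CSV with a "title" column holding the Chinese text and
# a "tags" column holding comma-separated labels such as "体育,新闻"
df = pd.read_csv("programs.csv").dropna(subset=["title", "tags"])

# Build the dict mapping each text to its list of labels (title_tags_dict),
# which is the structure the preprocessing code below iterates over
title_tags_dict = {
    row["title"]: [tag for tag in str(row["tags"]).split(",") if tag]
    for _, row in df.iterrows()
}
```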

2: Data preprocessing

1. Clean up the text data: strip unnecessary special symbols, drop missing values, and so on.

```python
import re

def get_short_title(title):
    """
    Return a short title with the episode numbering removed.
    :param title: raw program title
    :return: cleaned short title
    """
    # Normalize full-width parentheses, then drop empty parentheses and spaces
    title = title.replace('(', '(').replace(')', ')').replace('()', '').replace(' ', '')
    # Strip trailing episode markers such as "(12)", "第3集", "(大结局)" or "hd"
    r = re.findall(r"^(.*?)(?:(?:\([0-9ⅠⅡⅢⅣⅤⅥⅦⅧⅨⅩ一二三四五六七八九十上中下,,+、/\-]+\))|(?:[0-9ⅠⅡⅢⅣⅤⅥⅦⅧⅨⅩ,,、/\-]+)|(?:第*[0-9ⅠⅡⅢⅣⅤⅥⅦⅧⅨⅩ一二三四五六七八九十,,+、/\-]+(?:期|集))|(?:\(大结局\))|(?:\(直播\))|(?:\(首播\))|(?:\(重播\))|(?:\(重1\))|(?:hd))*$", title)
    if len(r) == 1:
        result = r[0].strip()
    else:
        result = title.strip()
    # Remove a trailing broadcast-time annotation such as "(21:30重播)"
    m1 = re.search(r"^.*?\(*[0-9]*\)*\([0-9]{1,2}:[0-9]{2}.*?\)$", result)
    if m1 is not None:
        result = re.sub(r"(^.*?)(\(*[0-9]*\)*)(\([0-9]{1,2}:[0-9]{2})(.*?\)$)", r"\1", result)
    # Drop a bare parenthesized number before a colon, e.g. "栏目(3):子题" -> "栏目:子题"
    m2 = re.search(r"^.*?\([0-9]*\):.*?$", result)
    if m2 is not None:
        result = re.sub(r"(^.*?)(\([0-9]*\))(:.*?)$", r"\1\3", result)
    return result
```
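A quick sanity check on a couple of hypothetical titles shows the episode markers being stripped:

```python
print(get_short_title("快乐大本营(20)"))   # -> 快乐大本营
print(get_short_title("新闻联播(重播)"))   # -> 新闻联播
```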

2. Collect every label type that appears in the texts and assign each label an index.

```python
tag_index_dict = {}   # dict: label -> index
tag_index_list = []   # list of all distinct labels
index = 0
for title in title_tags_dict:   # title_tags_dict: dict mapping text -> its label list
    if title_tags_dict[title] == []:
        continue
    for tags in title_tags_dict[title]:
        if tags in tag_index_list:
            continue
        tag_index_dict[tags] = index
        tag_index_list.append(tags)
        index += 1
```
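Run on a toy title_tags_dict (hypothetical data), the loop produces a stable label-to-index mapping:

```python
title_tags_dict = {
    "世锦赛八强确定两席": ["体育", "排球", "新闻"],
    "中超联赛战报": ["体育", "新闻", "足球"],
}
# After the loop above:
# tag_index_dict == {"体育": 0, "排球": 1, "新闻": 2, "足球": 3}
# tag_index_list == ["体育", "排球", "新闻", "足球"]
```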

3. One-hot encode each text's labels, splitting the data into a train.txt training file (80%) and a validate.txt validation file (20%).

```python
import json

# ABSOLUTE_FILE_PATH is the project's base directory, defined elsewhere in the module
with open(ABSOLUTE_FILE_PATH + "/train.txt", mode="w", encoding="utf-8", newline="") as train_txt_file, \
        open(ABSOLUTE_FILE_PATH + "/validate.txt", mode="w", encoding="utf-8", newline="") as validate_txt_file:
    count = 1
    for title in title_tags_dict:
        label = [0] * len(tag_index_list)
        tags_name = []
        tags_list = title_tags_dict[title]
        if tags_list == []:
            continue
        for tag in tags_list:
            tags_name.append(tag)
            label[tag_index_dict.get(tag)] = 1
        # Every 5th sample goes to the validation file (~20%), the rest to training
        if count % 5 == 0:
            txt_file = validate_txt_file
        else:
            txt_file = train_txt_file
        # Line format: title|JSON list of tag names|JSON one-hot label vector
        txt_file.write(title.replace("|", "") + "|" + json.dumps(tags_name, ensure_ascii=False)
                       + "|" + json.dumps(label) + "\n")
        count += 1
```

The resulting txt files look like this:

```text
世锦赛八强确定两席 中国女排恐需决战比利时|["体育", "排球", "新闻"]|[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
中国体育非凡十年:新一代运动员展现中国风采|["体育", "新闻", "纪录片"]|[1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
中超联赛:武汉长江遭遇五连败|["中超", "体育", "新闻", "足球"]|[1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
女排世锦赛:中国女排轻取波多黎各|["体育", "排球", "新闻"]|[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
```
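Each line is title|JSON tag names|JSON one-hot vector. A minimal sketch of parsing such a file back into memory (assuming exactly the pipe-delimited format above):

```python
import json

def load_samples(path):
    # Each line: title|JSON list of tag names|JSON one-hot label vector
    samples = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            title, tags_json, label_json = line.rstrip("\n").split("|")
            samples.append((title, json.loads(tags_json), json.loads(label_json)))
    return samples
```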

3: Build the neural network model

```python
import numpy as np
import tensorflow as tf

# TextCNN, word2vec, max_words_length and word_2_vec_length are defined
# elsewhere in the module

def build_graph(class_size, is_train):
    """
    Build the model computation graph.
    :param class_size: number of label classes
    :param is_train: whether the graph is built for training
    :return: dict exposing the graph's tensors
    """
    if is_train:
        text_cnn = TextCNN(W=word2vec.get_W(), sequence_length=max_words_length, num_classes=class_size,
                           embedding_size=50, filter_sizes=[2, 3, 4], num_filters=128,
                           other_features_length=4, l2_reg_lambda=0)
    else:
        # At inference time the embeddings are restored from the checkpoint,
        # so a zero matrix of the right shape is enough here
        text_cnn = TextCNN(W=np.zeros(shape=[word_2_vec_length, 50]), sequence_length=max_words_length,
                           num_classes=class_size, embedding_size=50, filter_sizes=[2, 3, 4],
                           num_filters=128, other_features_length=4, l2_reg_lambda=0)
    # Multi-label: each label is predicted independently with a 0.5 threshold
    predictions = tf.compat.v1.cast(tf.compat.v1.greater(text_cnn.prob, 0.5), tf.compat.v1.float32)
    correct_predictions = tf.compat.v1.cast(
        tf.compat.v1.equal(tf.compat.v1.cast(text_cnn.input_y, tf.compat.v1.int32),
                           tf.compat.v1.cast(predictions, tf.compat.v1.int32)),
        tf.compat.v1.float32)
    # Only count label positions selected by the mask
    correct_predictions_mask = correct_predictions * text_cnn.y_mask
    accuracy_batch = tf.compat.v1.reduce_sum(correct_predictions_mask, 1) / tf.compat.v1.reduce_sum(text_cnn.y_mask, 1)
    accuracy = tf.compat.v1.reduce_mean(accuracy_batch)

    graph = {
        "prob": text_cnn.prob,
        "loss": text_cnn.loss,
        "input_x": text_cnn.input_x,
        "input_y": text_cnn.input_y,
        "y_mask": text_cnn.y_mask,
        "other_features": text_cnn.other_features,
        "dropout_keep_prob": text_cnn.dropout_keep_prob,
        "accuracy": accuracy,
    }
    return graph
```
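The TextCNN class itself is not shown in this post. The key design point for multi-label classification is that the output layer applies an element-wise sigmoid rather than a softmax, because one title can carry several labels at once; that is also why build_graph thresholds text_cnn.prob at 0.5. A minimal sketch of what the probability and loss presumably look like inside such a class (scores standing in for the logits of the final dense layer, so these exact names are assumptions):

```python
# Assumed shapes: scores and input_y are [batch_size, num_classes].
# Each label gets an independent probability, so several can exceed 0.5 at once.
prob = tf.compat.v1.sigmoid(scores)
# Element-wise sigmoid cross-entropy, summed over labels, averaged over the batch
losses = tf.compat.v1.nn.sigmoid_cross_entropy_with_logits(labels=input_y, logits=scores)
loss = tf.compat.v1.reduce_mean(tf.compat.v1.reduce_sum(losses, axis=1))
```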

4: Train the model

```python
import logging

import numpy as np
import sklearn.metrics
import tensorflow as tf

# batch_size and ABSOLUTE_FILE_PATH are defined elsewhere in the module

def train(graph, model_dir, dh, is_train, new_model):
    loss = graph["loss"]
    prob = graph["prob"]
    input_x = graph["input_x"]
    input_y = graph["input_y"]
    y_mask = graph["y_mask"]
    accuracy = graph["accuracy"]
    other_features = graph["other_features"]
    dropout_keep_prob = graph["dropout_keep_prob"]

    train_loss_summary = tf.compat.v1.summary.scalar('train_loss', loss)
    validate_loss_summary = tf.compat.v1.summary.scalar('validate_loss', loss)
    train_accuracy_summary = tf.compat.v1.summary.scalar('train_accuracy', accuracy)
    validate_accuracy_summary = tf.compat.v1.summary.scalar('validate_accuracy', accuracy)

    optimizer = tf.compat.v1.train.AdamOptimizer()
    train_op = optimizer.minimize(loss)
    init = tf.compat.v1.global_variables_initializer()
    saver = tf.compat.v1.train.Saver()
    sess = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(
        inter_op_parallelism_threads=2,
        intra_op_parallelism_threads=2,
    ))
    train_writer = tf.compat.v1.summary.FileWriter('./summary', sess.graph)

    sess.run(init)
    if not (is_train and new_model):
        # Continue from an existing checkpoint instead of starting fresh
        saver.restore(sess, ABSOLUTE_FILE_PATH + "/data/models/model.ckpt")

    if is_train:
        for i in range(10000):
            train_title_batch, train_title_word_index_batch, train_other_feature_batch, \
                train_label_batch, train_label_mask_batch = dh.get_batch(
                    is_train=True, batch_size=batch_size)
            _, loss_val, train_loss_summary_val, train_accuracy_summary_val = sess.run(
                [train_op, loss, train_loss_summary, train_accuracy_summary], feed_dict={
                    input_x: train_title_word_index_batch,
                    input_y: train_label_batch,
                    y_mask: train_label_mask_batch,
                    other_features: train_other_feature_batch,
                    dropout_keep_prob: 0.5
                })
            logging.info("{} loss: {}".format(i, loss_val))
            train_writer.add_summary(train_loss_summary_val, i)
            train_writer.add_summary(train_accuracy_summary_val, i)
            # Validate (and checkpoint) every 100 training steps
            if i % 100 == 0:
                validate_title_batch, validate_title_word_index_batch, validate_other_feature_batch, \
                    validate_label_batch, validate_label_mask_batch = dh.get_batch(
                        is_train=False, batch_size=batch_size)
                loss_val, validate_loss_summary_val, validate_accuracy_summary_val = sess.run(
                    [loss, validate_loss_summary, validate_accuracy_summary], feed_dict={
                        input_x: validate_title_word_index_batch,
                        input_y: validate_label_batch,
                        y_mask: validate_label_mask_batch,
                        other_features: validate_other_feature_batch,
                        dropout_keep_prob: 1.0
                    })
                logging.info("--------------------------- validate loss: {} ---------------------------".format(loss_val))
                train_writer.add_summary(validate_loss_summary_val, i)
                train_writer.add_summary(validate_accuracy_summary_val, i)
                saver.save(sess, model_dir)

    # After training, evaluate the results on a validation batch
    validate_title_batch, validate_title_word_index_batch, validate_other_feature_batch, \
        validate_label_batch, validate_label_mask_batch = dh.get_batch(
            is_train=False, batch_size=batch_size)
    loss_val, prob_val = sess.run([loss, prob], feed_dict={
        input_x: validate_title_word_index_batch,
        input_y: validate_label_batch,
        y_mask: validate_label_mask_batch,
        other_features: validate_other_feature_batch,
        dropout_keep_prob: 1.0
    })
    aps = []
    for k in range(prob_val.shape[0]):
        # Per-sample metrics: threshold probabilities at 0.5 and compare to truth
        pred_label = np.array([1 if item >= 0.5 else 0 for item in prob_val[k]])
        # Human-readable names of the true labels (handy when eyeballing the log)
        truth_labels = [dh.data["index_tag_list"][i] for i in range(len(validate_label_batch[k]))
                        if validate_label_batch[k][i] != 0]
        ap = sklearn.metrics.average_precision_score(validate_label_batch[k], prob_val[k])
        precision = sklearn.metrics.precision_score(validate_label_batch[k], pred_label, average='binary')
        recall = sklearn.metrics.recall_score(validate_label_batch[k], pred_label, average='binary')
        logging.info("precision: {}".format(precision))
        logging.info("recall: {}".format(recall))
        logging.info("ap: {}".format(ap))
        aps.append(ap)

    logging.info("loss_val: {}".format(loss_val))
    logging.info("mean AP: {}".format(np.mean(np.asarray(aps))))
```
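For completeness, a hypothetical entry point tying the two functions together (batch_size, the dh data-helper object, and the checkpoint path are assumptions inferred from the code above, not part of the original post):

```python
# a sketch, assuming dh exposes data["index_tag_list"] and the get_batch()
# method that train() calls above
batch_size = 64
graph = build_graph(class_size=len(dh.data["index_tag_list"]), is_train=True)
train(graph,
      model_dir=ABSOLUTE_FILE_PATH + "/data/models/model.ckpt",
      dh=dh, is_train=True, new_model=True)
```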

That's all for now; I don't yet fully understand the underlying algorithms, so this is only an outline of the overall framework and workflow.
