
Chinese News Headline Classification with Paddle 2.0


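  • Chinese news headline classification: Paddle 2.0 baseline (unofficial)

      • Tuning suggestions
      • Dataset address
    • Task description

      • Data description
      • Submitting answers
    • Code approach

      • Unpacking the dataset
  • Data processing

  • Reading the data (dictionary and dataset)

  • Data initialization

  • Inspecting the data

  • Data padding

  • Wrapping the data

  • Network definition

  • Model training

  • Reading the inference data

  • Running inference

  • About the author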

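About this project: a walkthrough of an assignment from Prof. Hung-yi Lee's (李宏毅) course offered on the PaddlePaddle (飞桨) platform.

Course details: see the course link.

Project details: see the project link.

Dataset information: see the dataset link.

This project is for reference only. It offers ideas and approaches, not a standard answer. Please do not copy it blindly!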

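Chinese news headline classification: Paddle 2.0 baseline (unofficial)

Unofficial, made by 三岁! (It may be basic, but it is made with care.)

Tuning suggestions

The baseline score of this project is not high; you will need to tune it yourself to improve the results.
Suggestions:

* Improve the model. The baseline uses a simple linear model; a more sophisticated architecture that suits NLP tasks better may help (I am not very familiar with the details).

* Tune the learning rate to home in on the best result. Continuing to train the existing model can also improve the score noticeably.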
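
One common way to tune the learning rate is to start from the baseline's value of 0.001 and shrink it as training goes on. The sketch below is plain Python and independent of Paddle; the function name `exponential_decay` and the decay factor 0.5 are illustrative assumptions, not settings from this project.

```python
def exponential_decay(base_lr: float, gamma: float, epoch: int) -> float:
    """Return the learning rate for a given epoch: base_lr * gamma ** epoch."""
    if base_lr <= 0 or not 0 < gamma <= 1:
        raise ValueError("base_lr must be positive and gamma must be in (0, 1]")
    return base_lr * gamma ** epoch


# Learning rates for the first four epochs, starting from 0.001.
# The decay factor 0.5 is a made-up value for illustration.
schedule = [exponential_decay(0.001, 0.5, epoch) for epoch in range(4)]
print(schedule)
```

Whether a decaying rate actually helps here has to be checked against the validation accuracy; the values above are only a starting point for experiments.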

Dataset address

https://aistudio.baidu.com/aistudio/datasetdetail/75812

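Task description

This is a text classification task based on the THUCNews dataset, which was built by filtering historical data from the Sina News RSS feeds from 2005 to 2011 and contains 740,000 news documents. Participants must use an algorithm to determine the category of each news headline.

Data description

THUCNews contains 740,000 news documents (about 2.19 GB), all plain text in UTF-8. Starting from the original Sina News taxonomy, the categories were consolidated and re-divided into 14 candidate categories: 财经 (finance), 彩票 (lottery), 房产 (real estate), 股票 (stocks), 家居 (home), 教育 (education), 科技 (technology), 社会 (society), 时尚 (fashion), 时政 (politics), 体育 (sports), 星座 (horoscope), 游戏 (games) and 娱乐 (entertainment).

The training set has already been extracted in the format "label ID + \t + label + \t + original title", so classification can be done directly on the headlines. Can you come up with your own solution?

Training set format: label ID + \t + label + \t + original title. Test set format: original title.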
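
A minimal sketch of splitting one training line into its three tab-separated fields. The names `Sample` and `parse_train_line` are hypothetical, and the example line is assembled for illustration from the sample decoded later in this post; it is not copied from the raw file.

```python
from typing import NamedTuple


class Sample(NamedTuple):
    label_id: int
    label: str
    title: str


def parse_train_line(line: str) -> Sample:
    """Split one 'label ID<TAB>label<TAB>title' line into its three fields."""
    label_id, label, title = line.rstrip("\n").split("\t", 2)
    return Sample(int(label_id), label, title)


# Hypothetical example line in the training-set format described above.
example = parse_train_line("0\t财经\t上证50ETF净申购突增\n")
print(example.label_id, example.label, example.title)
```

Splitting with a maximum of two splits keeps the title intact even if it happens to contain a tab character.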

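Submitting answers

When submitting, provide the version of the model code project and the result file. The fields in the result file must strictly follow the required format:

1. Each predicted category line must correspond one-to-one with the lines of the raw test set, in the same order.

2. Check that the output has exactly 83599 lines; otherwise the score is invalid.

3. Name the output file result.txt, with one category per line, as in the sample below: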

···

游戏

财经

时政

股票

家居

科技

社会

房产

教育

星座

科技

股票

游戏

财经

时政

股票

家居

科技

社会

房产

教育

···
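
The submission rules above can be checked before uploading. The following is a sketch under stated assumptions: the helper name `check_result_lines` is hypothetical, and the category list is the one used by the code later in this post.

```python
CATEGORIES = ["财经", "彩票", "房产", "股票", "家居", "教育", "科技",
              "社会", "时尚", "时政", "体育", "星座", "游戏", "娱乐"]
EXPECTED_LINES = 83599  # required number of lines, per the rules above


def check_result_lines(lines, expected=EXPECTED_LINES):
    """Return a list of problems found in the prediction lines (empty if none)."""
    problems = []
    if len(lines) != expected:
        problems.append(f"expected {expected} lines, got {len(lines)}")
    valid = set(CATEGORIES)
    for number, line in enumerate(lines, start=1):
        if line.strip() not in valid:
            problems.append(f"line {number}: unknown category {line.strip()!r}")
    return problems


# Tiny demonstration with an expected length of 3 instead of 83599.
print(check_result_lines(["游戏", "财经", "时政"], expected=3))
print(check_result_lines(["游戏", "sports"], expected=3))
```

For a real submission, the lines would come from reading result.txt, with `expected` left at its default.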

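Code approach

From the task statement, this is a classic NLP task.

  • Preprocess the data and convert it into a vector representation
  • Build the model
  • Train the model on the training set
  • Load the trained weights and run inference on the test data to output predictions

Without further ado, let's begin!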

Unpacking the dataset

    # Notebook shell commands: install the pinned Paddle version and unpack the data.
    ! pip install -U paddlepaddle==2.0.1
    ! unzip -oq /home/aistudio/data/data75812/新闻文本标签分类.zip

    import os

    import matplotlib.pyplot as plt
    import numpy as np
    import paddle
    import paddle.nn as nn

    print(paddle.__version__)  # show the installed version

    # Choose CPU or GPU by passing the device name to paddle.set_device().
    # device = paddle.set_device('gpu')
Output:
    /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/__init__.py:107: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
      from collections import MutableMapping
    /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/rcsetup.py:20: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
      from collections import Iterable, Mapping
    /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/colors.py:53: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
      from collections import Sized
    2021-03-27 12:21:25,020 - INFO - font search path ['/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/ttf', '/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/afm', '/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/pdfcorefonts']
    2021-03-27 12:21:25,357 - INFO - generated new fontManager
    
    
    2.0.1

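Data processing

The main question in data processing is how to represent the text numerically. We take a simple route here: once the dictionary is available, it is used directly for all later processing. Each character is mapped to a unique integer ID through the dictionary, and the IDs can be mapped back to text to verify the encoding; any error found at this stage should be fixed before moving on. The data is then stored as numeric sequences in a uniform format, with a special token filling the positions that hold no real character, so that every sample has the same length and structure. Finally, the processed data is inspected to confirm that everything is as expected.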
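
The mapping and the round-trip check described above can be sketched in plain Python. This is a toy illustration: the project itself reads a ready-made dictionary from dict.txt, and the helper names `build_vocab`, `encode` and `decode` are hypothetical.

```python
def build_vocab(titles):
    """Assign each distinct character an integer ID; '<unk>' gets the last ID."""
    vocab = {}
    for title in titles:
        for char in title:
            vocab.setdefault(char, len(vocab))
    vocab["<unk>"] = len(vocab)
    return vocab


def encode(title, vocab):
    """Map characters to IDs, falling back to '<unk>' for unseen characters."""
    return [vocab.get(char, vocab["<unk>"]) for char in title]


def decode(ids, vocab):
    """Map IDs back to characters to check the encoding."""
    id_to_char = {index: char for char, index in vocab.items()}
    return "".join(id_to_char[index] for index in ids)


vocab = build_vocab(["上证50ETF净申购突增"])
ids = encode("上证50ETF", vocab)
print(ids, decode(ids, vocab))
```

Decoding the IDs and comparing the result with the original text is the same kind of check the post performs later with `ids_to_str`.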

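Reading the data (dictionary and dataset)

    import ast

    # Read the dictionary
    def get_dict_len(d_path):
        # The first line of dict.txt is expected to hold a Python dict literal;
        # ast.literal_eval parses it without executing arbitrary code.
        with open(d_path, 'r', encoding='utf-8') as f:
            line = ast.literal_eval(f.readlines()[0])
        return line

    word_dict = get_dict_len('新闻文本标签分类/dict.txt')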
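    # Read the training and validation sets
    samples = []  # module-level list shared by every call to dataset()

    def dataset(datapath):
        # Each line is "id,id,...<TAB>label"; split it at the last tab.
        with open(datapath) as f:
            for i in f.readlines():
                data = []
                ids = np.array(i[:i.rfind('\t')].split(','))  # character IDs of the title
                data.append(ids)
                label = np.array(i[i.rfind('\t')+1:-1])  # label ID
                data.append(label)
                samples.append(data)
        return samples

    # NOTE: both calls append to and return the same `samples` list, so the two
    # names below refer to one combined list (train + validation). This is why
    # the training and validation shapes printed further down are identical.
    train_dataset = dataset('新闻文本标签分类/Train_IDs.txt')
    val_dataset = dataset('新闻文本标签分类/Val_IDs.txt')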
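
The reader above collects every sample in one module-level list, so the training and validation variables end up pointing at the same combined data. A possible alternative, sketched here with a hypothetical helper `parse_id_lines`, builds a fresh list on each call so the two splits stay separate. The example lines are made up but follow the "id,id,...<TAB>label" layout that the reader above parses.

```python
def parse_id_lines(lines):
    """Parse 'id,id,...<TAB>label' lines into (ids, label) pairs in a new list."""
    samples = []
    for line in lines:
        line = line.rstrip("\n")
        if not line:
            continue  # skip blank lines
        text, label = line.rsplit("\t", 1)
        samples.append(([int(token) for token in text.split(",")], int(label)))
    return samples


# Hypothetical lines in the format the reader above expects.
train = parse_id_lines(["2976,385,2050\t0\n", "4041,4370\t3\n"])
val = parse_id_lines(["1440,3740\t6\n"])
print(len(train), len(val))
```

With real files, the function could be called as `parse_id_lines(open(path, encoding='utf-8'))`. Note that switching to separate lists changes the sample counts, so any hard-coded sizes downstream would have to be adjusted.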

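Data initialization

Define the values needed later.

    # Initial settings
    vocab_size = len(word_dict) + 1  # dictionary size plus 1
    print(vocab_size)
    emb_size = 256  # embedding dimension
    seq_len = 30  # every title is padded or truncated to this length
    batch_size = 32  # batch size
    epochs = 2  # number of training epochs
    pad_id = word_dict['<unk>']  # ID used to fill padded positions

    nu=["财经","彩票","房产","股票","家居","教育","科技","社会","时尚","时政","体育","星座","游戏","娱乐"]  # label names, indexed by label ID

    # Turn a sequence of IDs back into text
    def ids_to_str(ids):
        # print(ids)
        keys = list(word_dict)  # relies on the dictionary keys being stored in ID order
        words = []
        for k in ids:
            w = keys[int(k)]
            words.append(w if isinstance(w, str) else w.decode('ASCII'))
        return " ".join(words)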
Output:
    5308

Inspecting the data

Check that the data looks right and fix anything abnormal straight away.

    # Look at the first sample
    for i in train_dataset:
        sent = i[0]
        label = int(i[1])
        print('sentence list id is:', sent)  # raw ID sequence
        print('sentence label id is:', label)  # label ID
        print('--------------------------')  # separator
        print('sentence list is: ', ids_to_str(sent))  # decoded text
        print('sentence label is: ', nu[label])  # decoded label
        break
Output:
    sentence list id is: ['2976' '385' '2050' '3757' '1147' '3296' '1585' '688' '1180' '2608'
     '4280' '1887']
    sentence label id is: 0
    --------------------------
    sentence list is:  上 证 5 0 E T F 净 申 购 突 增
    sentence label is:  财经

Data padding

Pad every sample to the same length.

    # Pad the data and check the result
    def create_padded_dataset(dataset):
        padded_sents = []
        labels = []
        for batch_id, data in enumerate(dataset):  # iterate over the samples
            sent, label = data[0], data[1]  # split into IDs and label
            # Truncate to seq_len, then right-pad with pad_id up to seq_len.
            padded_sent = np.concatenate([sent[:seq_len], [pad_id] * (seq_len - len(sent))]).astype('int32')

            # print(padded_sent)
            padded_sents.append(padded_sent)
            labels.append(label)
        # print(padded_sents)
        return np.array(padded_sents), np.array(labels).astype('int64')  # return as arrays

    # Build the padded train and validation arrays
    train_sents, train_labels = create_padded_dataset(train_dataset)  # training set
    val_sents, val_labels = create_padded_dataset(val_dataset)  # validation set
    # Reshape the labels to [N, 1]; -1 infers N, so no sample count is hard-coded.
    train_labels = train_labels.reshape(-1, 1)
    val_labels = val_labels.reshape(-1, 1)
    # Check the array shapes
    print(train_sents.shape)
    print(train_labels.shape)
    print(val_sents.shape)
    print(val_labels.shape)
Output:
    (832475, 30)
    (832475, 1)
    (832475, 30)
    (832475, 1)
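
The padding rule used above (cut titles longer than `seq_len`, right-pad shorter ones with `pad_id`) can be shown in isolation. This is a plain-Python sketch with a hypothetical helper name; the pad value 5306 is the one visible in the batches printed later in this post.

```python
def pad_or_truncate(ids, seq_len, pad_id):
    """Return exactly seq_len IDs: cut long inputs, right-pad short ones."""
    clipped = list(ids[:seq_len])
    return clipped + [pad_id] * (seq_len - len(clipped))


print(pad_or_truncate([7, 8, 9], seq_len=5, pad_id=5306))
print(pad_or_truncate([1, 2, 3, 4, 5, 6], seq_len=5, pad_id=5306))
```

When the input is already longer than `seq_len`, the multiplier is zero after clipping, so nothing is appended.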

Wrapping the data

Subclass paddle.io.Dataset to package the arrays into a form that can be used for training.

    # Wrap the data by subclassing paddle.io.Dataset
    class IMDBDataset(paddle.io.Dataset):
        '''
        Wraps the padded sentences and their labels as a paddle.io.Dataset.
        '''
        def __init__(self, sents, labels):
            # store the arrays
            self.sents = sents
            self.labels = labels

        def __getitem__(self, index):
            # return one (sentence, label) pair
            data = self.sents[index]
            label = self.labels[index]

            return data, label

        def __len__(self):
            # number of samples
            return len(self.sents)

    # Instantiate the datasets
    train_dataset = IMDBDataset(train_sents, train_labels)
    val_dataset = IMDBDataset(val_sents, val_labels)

    # Wrap them in data loaders
    train_loader = paddle.io.DataLoader(train_dataset, return_list=True,
                                        shuffle=True, batch_size=batch_size, drop_last=True)
    val_loader = paddle.io.DataLoader(val_dataset, return_list=True,
                                      shuffle=True, batch_size=batch_size, drop_last=True)
    # Look at the content and shape of one batch from each loader
    for i in train_loader:
        print(i)
        break
    for j in val_loader:
        print(j)
        break
Output:
    [Tensor(shape=[32, 30], dtype=int32, place=CPUPlace, stop_gradient=True,
       [[4041, 4370, 3449, 3536, 103 , 2896, 4133, 312 , 1974, 3933, 2380, 805 , 3956, 4805, 3129, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [1440, 3740, 1169, 2663, 4401, 4591, 4874, 2734, 989 , 1980, 5016, 450 , 335 , 1562, 2543, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [580 , 3844, 3513, 1231, 4111, 1894, 737 , 1318, 3536, 4805, 3956, 4075, 141 , 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [2573, 536 , 1230, 3757, 610 , 2018, 1974, 39  , 1629, 121 , 4625, 294 , 450 , 1991, 3149, 4389, 1146, 1736, 588 , 3388, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [4829, 4419, 3415, 1230, 4910, 3814, 1876, 3509, 1592, 5059, 2207, 2139, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [1546, 1221, 1117, 4386, 3449, 1562, 2088, 4770, 1299, 4500, 41  , 2976, 725 , 1006, 2053, 897 , 2315, 3786, 2559, 828 , 3682, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [2185, 4673, 1546, 2991, 1120, 5025, 782 , 5025, 1674, 3717, 1006, 2099, 4807, 78  , 4749, 1932, 5283, 1375, 4725, 3185, 2358, 2100, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [2140, 4935, 3388, 278 , 3287, 4059, 775 , 1304, 4315, 698 , 3375, 3966, 3980, 1472, 1472, 2140, 4935, 3388, 5303, 939 , 2336, 2830, 4711, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [5072, 2886, 4647, 3957, 5276, 2139, 4646, 5053, 4073, 4954, 1006, 4038, 2896, 3886, 756 , 4289, 2700, 4242, 4954, 2018, 2336, 2412, 2764, 4711, 5306, 5306, 5306, 5306, 5306, 5306],
        [1546, 1231, 1230, 385 , 4774, 5269, 939 , 2845, 1147, 2358, 3947, 4774, 872 , 1592, 2896, 123 , 5059, 1177, 3947, 4191, 4841, 754 , 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [1180, 2646, 2155, 2776, 2886, 1257, 2302, 2748, 39  , 1230, 478 , 1006, 1425, 2263, 1278, 5078, 959 , 5102, 4578, 671 , 3430, 4954, 4910, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [2891, 5257, 4426, 4932, 189 , 1695, 1347, 1724, 4328, 3344, 1688, 3449, 5115, 379 , 1347, 2244, 5216, 3070, 5072, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [67  , 2788, 2873, 898 , 4207, 1347, 12  , 372 , 1737, 1006, 3468, 383 , 1836, 5115, 4608, 4790, 1620, 760 , 3313, 2244, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [2099, 4807, 3379, 200 , 3933, 472 , 4415, 312 , 2078, 3222, 44  , 3222, 3924, 2373, 3398, 643 , 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [2311, 3967, 720 , 2014, 2873, 311 , 4346, 2961, 4401, 725 , 1425, 1006, 1505, 3430, 4647, 926 , 4554, 4702, 4246, 2358, 3115, 5279, 123 , 1230, 679 , 5306, 5306, 5306, 5306, 5306],
        [1521, 2571, 1079, 4554, 1070, 534 , 2088, 2140, 5229, 1425, 3242, 846 , 3933, 3714, 99  , 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [2916, 123 , 1844, 5059, 123 , 1747, 3040, 1006, 5205, 1688, 1347, 601 , 3041, 3144, 3269, 4059, 2986, 4863, 1006, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [3888, 2153, 4813, 3053, 1741, 1648, 2757, 1177, 2033, 2991, 5283, 123 , 2779, 2651, 1053, 1522, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [4444, 5283, 1138, 3114, 3890, 3489, 1028, 3717, 936 , 389 , 2886, 2031, 316 , 3187, 2031, 2623, 643 , 4911, 3468, 1253, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [3740, 2925, 3023, 2851, 4389, 3092, 3576, 725 , 1736, 2300, 3114, 1006, 2122, 1076, 3973, 3092, 3951, 2664, 1059, 3440, 415 , 3099, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [1441, 312 , 134 , 4697, 1896, 1449, 3973, 4955, 3449, 1498, 1199, 2032, 2359, 4822, 1006, 4883, 4389, 4038, 4552, 4509, 2347, 690 , 1094, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [2777, 422 , 1902, 2428, 621 , 3313, 3973, 5014, 5140, 3086, 4822, 1006, 3809, 3305, 3343, 5161, 1230, 1995, 3684, 954 , 2336, 2830, 4711, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [813 , 903 , 4554, 3449, 1195, 3790, 4067, 1932, 2347, 3082, 4625, 2061, 3191, 992 , 1006, 1819, 3040, 4650, 1395, 729 , 5125, 5202, 2939, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [3952, 3493, 385 , 225 , 3449, 1613, 4822, 3534, 3191, 2896, 3927, 698 , 3375, 1006, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [1230, 831 , 1347, 2244, 1588, 3813, 2044, 3094, 1076, 4626, 1006, 1231, 1230, 3853, 4366, 2511, 2605, 3726, 5303, 939 , 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [5052, 2293, 3449, 3446, 1094, 2976, 4922, 2099, 1221, 4034, 1290, 3323, 3430, 3099, 4109, 4579, 1006, 1713, 3058, 4370, 1613, 4191, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [1359, 4922, 2748, 3933, 2099, 397 , 2858, 1006, 4438, 221 , 611 , 4159, 2642, 939 , 4784, 664 , 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [4554, 1667, 477 , 2891, 1819, 2354, 1819, 3040, 1006, 2873, 898 , 3740, 1408, 2176, 3371, 123 , 5151, 2886, 3040, 1275, 2336, 2830, 4711, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [4746, 3242, 5010, 3430, 2401, 4426, 4373, 1695, 2776, 775 , 1006, 1502, 3952, 2428, 1935, 3687, 809 , 416 , 1503, 4500, 1854, 2352, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [2351, 3287, 3813, 2032, 4554, 1519, 1655, 4038, 3951, 2958, 2886, 2140, 1006, 4246, 4536, 3449, 1476, 2572, 4207, 4401, 1505, 2953, 3468, 377 , 5306, 5306, 5306, 5306, 5306, 5306],
        [3712, 3583, 3973, 2312, 4426, 3305, 2979, 1897, 3513, 4059, 1695, 1006, 5293, 4382, 2199, 1076, 4412, 3559, 1215, 2640, 1343, 4785, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [2830, 2567, 1472, 134 , 3040, 1275, 3951, 377 , 420 , 1753, 1598, 690 , 3682, 4500, 1006, 3135, 3853, 4862, 3253, 377 , 2263, 5105, 3060, 5306, 5306, 5306, 5306, 5306, 5306, 5306]]), Tensor(shape=[32, 1], dtype=int64, place=CPUPlace, stop_gradient=True,
       [[3 ],
        [6 ],
        [3 ],
        [5 ],
        [6 ],
        [3 ],
        [6 ],
        [2 ],
        [10],
        [3 ],
        [10],
        [6 ],
        [13],
        [6 ],
        [10],
        [6 ],
        [6 ],
        [4 ],
        [9 ],
        [10],
        [10],
        [13],
        [10],
        [3 ],
        [0 ],
        [3 ],
        [3 ],
        [13],
        [13],
        [10],
        [13],
        [10]])]
    [Tensor(shape=[32, 30], dtype=int32, place=CPUPlace, stop_gradient=True,
       [[2607, 5278, 1979, 2932, 40  , 2813, 2361, 3114, 4111, 3099, 1221, 103 , 2079, 3951, 2050, 3757, 141 , 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [2050, 2751, 3403, 1214, 516 , 1006, 4059, 2125, 2380, 233 , 1521, 805 , 366 , 2336, 2176, 2830, 4711, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [2312, 487 , 2185, 4832, 4426, 2099, 1811, 1695, 1413, 4813, 3053, 3222, 4523, 3820, 2143, 1020, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [3740, 4334, 377 , 1299, 4062, 4442, 536 , 3487, 3398, 4863, 1850, 4480, 1006, 2896, 4673, 2776, 1230, 3114, 3786, 4442, 3507, 1902, 2428, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [1445, 4075, 1006, 610 , 805 , 3757, 3634, 2453, 1521, 736 , 1661, 1394, 4874, 3822, 1006, 421 , 3424, 3296, 610 , 610 , 3757, 316 , 4863, 3702, 2192, 5306, 5306, 5306, 5306, 5306],
        [2185, 2685, 4863, 5257, 3430, 2813, 2233, 684 , 846 , 892 , 1006, 3593, 3966, 3951, 4343, 2079, 892 , 4352, 4242, 3091, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [3131, 1809, 2052, 4359, 3449, 1199, 2401, 1441, 2768, 4073, 1724, 4191, 1301, 3956, 3757, 2050, 2751, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [5032, 410 , 4835, 3449, 2099, 44  , 989 , 4073, 1724, 4191, 1521, 1521, 3642, 2751, 1006, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [2050, 3757, 3880, 4945, 2515, 1112, 4224, 1282, 3379, 4477, 834 , 2013, 4874, 3823, 617 , 1090, 4060, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [3886, 2843, 2412, 1722, 1230, 3092, 4197, 1006, 699 , 1839, 380 , 1834, 1521, 3757, 1631, 4237, 518 , 3813, 2768, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [1440, 3740, 3520, 2832, 3888, 2886, 1993, 3952, 2427, 1215, 2550, 4248, 4328, 4099, 5103, 2337, 3468, 4456, 3191, 4062, 5072, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [802 , 97  , 1876, 2768, 4191, 4785, 1318, 3991, 1006, 3165, 4191, 3509, 1318, 4504, 736 , 3757, 3757, 3757, 3757, 3375, 5019, 4959, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [2233, 4370, 1347, 726 , 3886, 3142, 3259, 260 , 1445, 746 , 3238, 1025, 332 , 993 , 1006, 1301, 1661, 2845, 1836, 5115, 3738, 2199, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [5267, 4953, 1472, 1876, 2873, 3951, 377 , 4841, 754 , 125 , 3224, 1006, 3951, 3967, 2983, 2886, 4038, 5135, 684 , 123 , 1521, 1301, 2846, 389 , 4841, 5306, 5306, 5306, 5306, 5306],
        [1993, 1837, 5281, 3992, 1425, 3740, 224 , 804 , 3534, 3191, 2099, 4807, 3735, 5067, 1006, 4449, 2375, 2375, 4945, 2515, 2436, 1253, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [224 , 4382, 3379, 200 , 3449, 1230, 3996, 805 , 141 , 1006, 3379, 200 , 4576, 2680, 3430, 3042, 1081, 3537, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [2820, 682 , 2825, 2759, 1230, 294 , 4389, 3069, 3355, 2896, 1215, 2825, 4222, 3244, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [4274, 3740, 3413, 3449, 134 , 377 , 603 , 3886, 2873, 123 , 4289, 3020, 1230, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [1293, 3714, 40  , 1472, 1094, 1440, 1669, 3966, 2756, 4432, 1521, 4591, 1094, 4591, 2052, 1006, 4951, 1418, 4019, 1425, 3740, 4775, 1839, 3430, 738 , 5306, 5306, 5306, 5306, 5306],
        [1164, 2453, 1185, 4162, 3430, 1546, 3740, 3398, 2052, 3559, 1221, 2050, 3956, 3757, 805 , 141 , 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [2312, 536 , 1747, 3164, 2986, 542 , 3023, 3907, 1006, 4456, 4009, 3296, 3634, 1521, 3757, 5059, 736 , 736 , 3757, 3757, 2751, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [3142, 4886, 3430, 4954, 5177, 4242, 4382, 3952, 4931, 795 , 1006, 2099, 2886, 4651, 1562, 2986, 2155, 1521, 4591, 3966, 601 , 3041, 2151, 377 , 5306, 5306, 5306, 5306, 5306, 5306],
        [3400, 872 , 1893, 3016, 3933, 2263, 2781, 3114, 692 , 3222, 1620, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [2176, 2830, 3449, 3668, 3131, 3402, 2727, 224 , 264 , 4370, 4389, 1318, 1641, 2932, 1940, 4805, 2886, 4207, 4225, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [2549, 3101, 2099, 690 , 4111, 3682, 3537, 534 , 4167, 3137, 4954, 1006, 1785, 3869, 823 , 3924, 3473, 3881, 927 , 730 , 592 , 476 , 3207, 241 , 5306, 5306, 5306, 5306, 5306, 5306],
        [4227, 1562, 4027, 4954, 1521, 610 , 375 , 3889, 2896, 2239, 4370, 4141, 3000, 56  , 1006, 4697, 200 , 269 , 926 , 1413, 4540, 5238, 1017, 3468, 2014, 964 , 5306, 5306, 5306, 5306],
        [2358, 1025, 1708, 993 , 332 , 1006, 1862, 1006, 2358, 1025, 3956, 2079, 1709, 720 , 3676, 4050, 3357, 1472, 2941, 2254, 2412, 1029, 3222, 1725, 1028, 3165, 5306, 5306, 5306, 5306],
        [2820, 682 , 3191, 3440, 1146, 3174, 4328, 2982, 2825, 2759, 1117, 3069, 3355, 617 , 2813, 2742, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [2612, 3069, 1214, 3951, 1521, 3956, 4805, 3115, 1314, 2050, 3757, 3757, 366 , 1006, 1534, 2401, 5202, 1521, 4805, 366 , 3157, 2336, 2830, 4711, 5306, 5306, 5306, 5306, 5306, 5306],
        [4057, 62  , 1765, 531 , 1991, 3149, 5269, 736 , 3757, 1521, 736 , 1991, 3149, 4389, 2018, 4389, 2253, 1694, 4073, 1200, 5116, 4073, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [4746, 4283, 5295, 3449, 2099, 1794, 376 , 2040, 1663, 3564, 3187, 2986, 1006, 4111, 2896, 690 , 1117, 2776, 4500, 2078, 2040, 698 , 1214, 5306, 5306, 5306, 5306, 5306, 5306, 5306],
        [4289, 2166, 698 , 2100, 1006, 1343, 1681, 1094, 4863, 123 , 5162, 384 , 61  , 2380, 1645, 3388, 2336, 736 , 4711, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306, 5306]]), Tensor(shape=[32, 1], dtype=int64, place=CPUPlace, stop_gradient=True,
       [[6 ],
        [8 ],
        [4 ],
        [10],
        [6 ],
        [5 ],
        [3 ],
        [3 ],
        [6 ],
        [6 ],
        [13],
        [0 ],
        [12],
        [10],
        [13],
        [3 ],
        [7 ],
        [9 ],
        [10],
        [3 ],
        [6 ],
        [10],
        [4 ],
        [9 ],
        [10],
        [10],
        [3 ],
        [7 ],
        [7 ],
        [5 ],
        [10],
        [8 ]])]
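
The loaders above shuffle the samples and drop the last incomplete batch. The sketch below imitates those two options in plain Python to show what they do; it is not how paddle.io.DataLoader is implemented, and the function name `iterate_batches` is hypothetical.

```python
import random


def iterate_batches(samples, batch_size, shuffle=True, drop_last=True, seed=None):
    """Yield lists of samples, mimicking the shuffle/drop_last options used above."""
    order = list(range(len(samples)))
    if shuffle:
        random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        batch = [samples[index] for index in order[start:start + batch_size]]
        if drop_last and len(batch) < batch_size:
            break  # discard the final, smaller batch
        yield batch


batches = list(iterate_batches(list(range(10)), batch_size=4, seed=0))
print([len(batch) for batch in batches])
```

With 832475 samples and a batch size of 32, dropping the last batch leaves 832475 // 32 = 26014 full batches per epoch, which is consistent with the training log below ending at batch_id 26000 for the printed multiples of 500.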

Network definition

Define the network used for training. This part is one of the keys to a better score.

    # Define the network
    class MyNet(paddle.nn.Layer):
        def __init__(self):
            super(MyNet, self).__init__()
            self.emb = paddle.nn.Embedding(vocab_size, emb_size)  # embedding table of shape [vocab_size, emb_size]
            self.fc = paddle.nn.Linear(in_features=emb_size, out_features=96)  # hidden linear layer
            self.fc1 = paddle.nn.Linear(in_features=96, out_features=14)  # classifier over the 14 categories
            self.dropout = paddle.nn.Dropout(0.5)  # dropout for regularization

        def forward(self, x):
            x = self.emb(x)  # [batch, seq_len] -> [batch, seq_len, emb_size]
            x = paddle.mean(x, axis=1)  # average over the sequence -> [batch, emb_size]
            x = self.dropout(x)
            x = self.fc(x)
            x = self.dropout(x)
            x = self.fc1(x)
            return x
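
To make the forward pass above concrete, here is a plain-Python sketch of its two core operations: averaging the token embeddings over the sequence and applying a linear layer. The helper names and all numeric values are made up for illustration, and dropout is left out.

```python
def mean_pool(embedded):
    """Average a [seq_len][emb_size] nested list over the sequence axis."""
    seq_len = len(embedded)
    return [sum(column) / seq_len for column in zip(*embedded)]


def linear(vector, weights, bias):
    """Compute vector @ weights + bias for weights shaped [in][out]."""
    return [sum(v * w for v, w in zip(vector, column)) + b
            for column, b in zip(zip(*weights), bias)]


# Toy sizes: 3 tokens, embedding size 2, 2 output classes (values are made up).
embedded = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
pooled = mean_pool(embedded)
logits = linear(pooled, weights=[[1.0, 0.0], [0.0, 1.0]], bias=[0.5, -0.5])
print(pooled, logits)
```

Note that the network above averages over all 30 positions, so the embeddings of padded positions contribute to the mean as well. Masking them out is one possible refinement, not something the baseline does.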
    # Plotting helper
    def draw_process(title,color,iters,data,label):
        plt.title(title, fontsize=24)  # title
        plt.xlabel("iter", fontsize=20)  # x axis
        plt.ylabel(label, fontsize=20)  # y axis
        plt.plot(iters, data,color=color,label=label)   # draw the curve
        plt.legend()
        plt.grid()
        plt.show()

Model training

Training is the crucial step. You can adjust the learning rate, the optimizer and so on, sometimes with surprisingly good results.

    # Train the model
    def train(model):
        model.train()
        opt = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())  # optimizer and learning rate
        # Initial values
        steps = 0
        Iters, total_loss, total_acc = [], [], []

        for epoch in range(epochs):  # loop over epochs
            for batch_id, data in enumerate(train_loader):  # loop over batches
                steps += 1
                sent = data[0]  # input IDs
                label = data[1]  # labels

                logits = model(sent)  # forward pass
                loss = paddle.nn.functional.cross_entropy(logits, label)  # loss
                acc = paddle.metric.accuracy(logits, label)  # accuracy

                if batch_id % 500 == 0:  # log every 500 batches
                    Iters.append(steps)  # record the step count
                    total_loss.append(loss.numpy()[0])  # record the loss
                    total_acc.append(acc.numpy()[0])  # record the accuracy

                    print("epoch: {}, batch_id: {}, loss is: {}".format(epoch, batch_id, loss.numpy()))

                # Backpropagate and update the parameters
                loss.backward()
                opt.step()
                opt.clear_grad()

            # Evaluate once per epoch
            model.eval()
            accuracies = []
            losses = []

            for batch_id, data in enumerate(val_loader):  # loop over validation batches

                sent = data[0]  # input IDs
                label = data[1]  # labels

                logits = model(sent)  # forward pass
                loss = paddle.nn.functional.cross_entropy(logits, label)  # loss
                acc = paddle.metric.accuracy(logits, label)  # accuracy

                accuracies.append(acc.numpy())
                losses.append(loss.numpy())

            avg_acc, avg_loss = np.mean(accuracies), np.mean(losses)  # mean accuracy and loss
            print("[validation] accuracy: {}, loss: {}".format(avg_acc, avg_loss))

            model.train()

            paddle.save(model.state_dict(),str(epoch)+"_model_final.pdparams")  # save a checkpoint per epoch

        draw_process("training loss","red",Iters,total_loss,"training loss")  # plot the loss curve
        draw_process("training acc","green",Iters,total_acc,"training acc")  # plot the accuracy curve

    model = MyNet()  # instantiate the model
    train(model)  # start training
Output:
    epoch: 0, batch_id: 0, loss is: [2.6477456]
    epoch: 0, batch_id: 500, loss is: [1.8056118]
    epoch: 0, batch_id: 1000, loss is: [1.1092072]
    epoch: 0, batch_id: 1500, loss is: [1.0716103]
    epoch: 0, batch_id: 2000, loss is: [0.6794955]
    epoch: 0, batch_id: 2500, loss is: [0.54738545]
    epoch: 0, batch_id: 3000, loss is: [0.9065808]
    epoch: 0, batch_id: 3500, loss is: [0.63474274]
    epoch: 0, batch_id: 4000, loss is: [0.68158776]
    epoch: 0, batch_id: 4500, loss is: [1.0516238]
    epoch: 0, batch_id: 5000, loss is: [0.9118046]
    epoch: 0, batch_id: 5500, loss is: [0.65075576]
    epoch: 0, batch_id: 6000, loss is: [0.5605841]
    epoch: 0, batch_id: 6500, loss is: [0.56175774]
    epoch: 0, batch_id: 7000, loss is: [0.95122683]
    epoch: 0, batch_id: 7500, loss is: [0.38649452]
    epoch: 0, batch_id: 8000, loss is: [0.2205698]
    epoch: 0, batch_id: 8500, loss is: [0.40474647]
    epoch: 0, batch_id: 9000, loss is: [0.5931748]
    epoch: 0, batch_id: 9500, loss is: [0.3922717]
    epoch: 0, batch_id: 10000, loss is: [0.6130478]
    epoch: 0, batch_id: 10500, loss is: [0.5300909]
    epoch: 0, batch_id: 11000, loss is: [0.6114788]
    epoch: 0, batch_id: 11500, loss is: [0.24966809]
    epoch: 0, batch_id: 12000, loss is: [0.45669073]
    epoch: 0, batch_id: 12500, loss is: [0.29746443]
    epoch: 0, batch_id: 13000, loss is: [0.6775298]
    epoch: 0, batch_id: 13500, loss is: [0.8836371]
    epoch: 0, batch_id: 14000, loss is: [0.27501673]
    epoch: 0, batch_id: 14500, loss is: [0.46843478]
    epoch: 0, batch_id: 15000, loss is: [0.49367175]
    epoch: 0, batch_id: 15500, loss is: [0.500063]
    epoch: 0, batch_id: 16000, loss is: [0.31290954]
    epoch: 0, batch_id: 16500, loss is: [0.30774388]
    epoch: 0, batch_id: 17000, loss is: [0.21738727]
    epoch: 0, batch_id: 17500, loss is: [0.2860858]
    epoch: 0, batch_id: 18000, loss is: [0.2766972]
    epoch: 0, batch_id: 18500, loss is: [0.36017033]
    epoch: 0, batch_id: 19000, loss is: [0.43986273]
    epoch: 0, batch_id: 19500, loss is: [0.4210134]
    epoch: 0, batch_id: 20000, loss is: [0.579644]
    epoch: 0, batch_id: 20500, loss is: [0.23016676]
    epoch: 0, batch_id: 21000, loss is: [0.21913218]
    epoch: 0, batch_id: 21500, loss is: [0.18669227]
    epoch: 0, batch_id: 22000, loss is: [0.31480896]
    epoch: 0, batch_id: 22500, loss is: [0.37621552]
    epoch: 0, batch_id: 23000, loss is: [0.54980826]
    epoch: 0, batch_id: 23500, loss is: [0.6016808]
    epoch: 0, batch_id: 24000, loss is: [0.25056183]
    epoch: 0, batch_id: 24500, loss is: [0.2916811]
    epoch: 0, batch_id: 25000, loss is: [0.33430776]
    epoch: 0, batch_id: 25500, loss is: [0.74600095]
    epoch: 0, batch_id: 26000, loss is: [0.35165167]
    [validation] accuracy: 0.884321928024292, loss: 0.3713749647140503
    epoch: 1, batch_id: 0, loss is: [0.47405708]
    epoch: 1, batch_id: 500, loss is: [0.4443894]
    epoch: 1, batch_id: 1000, loss is: [0.35416052]
    epoch: 1, batch_id: 1500, loss is: [0.3004715]
    epoch: 1, batch_id: 2000, loss is: [0.59477925]
    epoch: 1, batch_id: 2500, loss is: [0.5639044]
    epoch: 1, batch_id: 3000, loss is: [0.40286714]
    epoch: 1, batch_id: 3500, loss is: [0.5387965]
    epoch: 1, batch_id: 4000, loss is: [0.11766122]
    epoch: 1, batch_id: 4500, loss is: [0.68849707]
    epoch: 1, batch_id: 5000, loss is: [0.83928466]
    epoch: 1, batch_id: 5500, loss is: [0.2867105]
    epoch: 1, batch_id: 6000, loss is: [0.20924558]
    epoch: 1, batch_id: 6500, loss is: [0.5582311]
    epoch: 1, batch_id: 7000, loss is: [0.63174886]
    epoch: 1, batch_id: 7500, loss is: [0.318484]
    epoch: 1, batch_id: 8000, loss is: [0.5406461]
    epoch: 1, batch_id: 8500, loss is: [0.4790561]
    epoch: 1, batch_id: 9000, loss is: [0.52266514]
    epoch: 1, batch_id: 9500, loss is: [0.51126254]
    epoch: 1, batch_id: 10000, loss is: [0.27308795]
    epoch: 1, batch_id: 10500, loss is: [0.22041513]
    epoch: 1, batch_id: 11000, loss is: [0.32234907]
    epoch: 1, batch_id: 11500, loss is: [0.6857507]
    epoch: 1, batch_id: 12000, loss is: [0.40997463]
    epoch: 1, batch_id: 12500, loss is: [0.53966033]
    epoch: 1, batch_id: 13000, loss is: [0.2620927]
    epoch: 1, batch_id: 13500, loss is: [0.21417136]
    epoch: 1, batch_id: 14000, loss is: [0.5232475]
    epoch: 1, batch_id: 14500, loss is: [0.37579858]
    epoch: 1, batch_id: 15000, loss is: [0.3611152]
    epoch: 1, batch_id: 15500, loss is: [0.336707]
    epoch: 1, batch_id: 16000, loss is: [0.2795578]
    epoch: 1, batch_id: 16500, loss is: [0.54298353]
    epoch: 1, batch_id: 17000, loss is: [0.26425135]
    epoch: 1, batch_id: 17500, loss is: [0.52595145]
    epoch: 1, batch_id: 18000, loss is: [0.24938256]
    epoch: 1, batch_id: 18500, loss is: [0.30653632]
    epoch: 1, batch_id: 19000, loss is: [0.58400965]
    epoch: 1, batch_id: 19500, loss is: [0.18243803]
    epoch: 1, batch_id: 20000, loss is: [0.28917578]
    epoch: 1, batch_id: 20500, loss is: [1.0765818]
    epoch: 1, batch_id: 21000, loss is: [0.32550114]
    epoch: 1, batch_id: 21500, loss is: [0.16792971]
    epoch: 1, batch_id: 22000, loss is: [0.65214527]
    epoch: 1, batch_id: 22500, loss is: [0.58119446]
    epoch: 1, batch_id: 23000, loss is: [0.43643892]
    epoch: 1, batch_id: 23500, loss is: [0.47376677]
    epoch: 1, batch_id: 24000, loss is: [0.3279624]
    epoch: 1, batch_id: 24500, loss is: [0.50899947]
    epoch: 1, batch_id: 25000, loss is: [0.61989105]
    epoch: 1, batch_id: 25500, loss is: [0.42433214]
    epoch: 1, batch_id: 26000, loss is: [0.26673254]
    [validation] accuracy: 0.8882260322570801, loss: 0.35311153531074524
    
    
    /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2349: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
      if isinstance(obj, collections.Iterator):
    /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2366: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
      return list(data) if isinstance(data, collections.MappingView) else data

[Figure: training loss curve (output_26_2.png)]

[Figure: training accuracy curve (output_26_3.png)]
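
The training loop above reports a softmax cross-entropy loss and an accuracy for each logged batch. The plain-Python sketch below shows what those two quantities compute for individual samples; the function names mirror the metrics but are otherwise hypothetical, and this is not Paddle's implementation.

```python
import math


def cross_entropy(logits, label):
    """Softmax cross-entropy for one sample: -log(softmax(logits)[label])."""
    peak = max(logits)  # subtract the max for numerical stability
    log_sum = peak + math.log(sum(math.exp(value - peak) for value in logits))
    return log_sum - logits[label]


def accuracy(batch_logits, labels):
    """Fraction of samples whose highest-scoring class equals the label."""
    hits = sum(max(range(len(row)), key=row.__getitem__) == label
               for row, label in zip(batch_logits, labels))
    return hits / len(labels)


# A model that scores all 14 classes equally has a loss of log(14).
uniform_loss = cross_entropy([0.0] * 14, label=3)
print(round(uniform_loss, 4))
print(accuracy([[0.1, 2.0], [1.5, 0.2]], [1, 1]))
```

The loss of an untrained, near-uniform classifier over 14 classes is about log(14) ≈ 2.64, which is close to the first value in the training log above (about 2.65).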

Reading the inference data

    # Read the competition (test) data
    test_samples = []
    def dataset(datapath):
        with open(datapath) as f:  # open the file
            for i in f.readlines():  # read line by line
                ids = np.array(i.strip().split(','))  # split the line into IDs
                test_samples.append(ids)  # store the sample
        return test_samples

    # Pad the competition data
    def create_padded_dataset(dataset):
        padded_sents = []
        for batch_id, data in enumerate(dataset):  # iterate over the samples
            # print(data)
            sent = data
            # Truncate to seq_len, then right-pad with pad_id up to seq_len.
            padded_sent = np.concatenate([sent[:seq_len], [pad_id] * (seq_len - len(sent))]).astype('int32')

            # print(padded_sent)
            padded_sents.append(padded_sent)

        # print(padded_sents)
        return np.array(padded_sents)  # return as an array

    test_data = dataset('新闻文本标签分类/Test_IDs.txt')  # read the test data
    # Pad the test data
    test_data = create_padded_dataset(test_data)

    # Look at the padded data
    print(test_data)
Output:
    [[4057 1902 1475 ... 5306 5306 5306]
     [2805 5242 3593 ... 5306 5306 5306]
     [1836 3222 4641 ... 5306 5306 5306]
     ...
     [4838 1202 1490 ... 5306 5306 5306]
     [ 805 3757 3757 ... 5306 5306 5306]
     [2805 5242 3593 ... 5306 5306 5306]]

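Running inference

Here you can pick whichever saved model performs best and use it for prediction.

    nu=["财经","彩票","房产","股票","家居","教育","科技","社会","时尚","时政","体育","星座","游戏","娱乐"]  # label names

    # Load the model
    model_state_dict = paddle.load('0_model_final.pdparams')  # weights saved after the first epoch
    model = MyNet()  # build the network
    model.set_state_dict(model_state_dict)
    model.eval()
    # print(type(test_data[0]))
    count = 0
    with open('./result.txt', 'w', encoding='utf-8') as f_train:  # create the result file
        for batch_id, data in enumerate(test_data):  # one title at a time
            # Shape [1, seq_len]: a batch holding a single title, matching the
            # [batch, seq_len] layout used in training. A [seq_len, 1] input
            # would instead be treated as 30 separate one-character titles.
            results = model(paddle.to_tensor(data.reshape(1, seq_len)))

            idx = int(np.argmax(results.numpy()[0]))  # index of the highest logit
            labels = nu[idx]  # map the index to its label
            f_train.write(labels+"\n")  # one label per line
            count += 1

            if count%500==0:  # progress report
                print(count)

    print(count)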
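
The last step of inference maps the highest-scoring output to a category name. A minimal plain-Python sketch of that mapping, with a hypothetical helper `predict_label` and made-up scores:

```python
CATEGORIES = ["财经", "彩票", "房产", "股票", "家居", "教育", "科技",
              "社会", "时尚", "时政", "体育", "星座", "游戏", "娱乐"]


def predict_label(logits, categories=CATEGORIES):
    """Return the category whose logit is largest."""
    if len(logits) != len(categories):
        raise ValueError("one logit per category is required")
    best = max(range(len(logits)), key=logits.__getitem__)
    return categories[best]


# Made-up scores in which index 6 is the largest.
scores = [0.0] * 14
scores[6] = 2.5
print(predict_label(scores))
```

Because only the position of the largest value matters, the raw logits can be used directly; applying softmax first would not change the predicted label.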

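The results may not be outstanding, but this baseline works as a starting point. If you have questions or requests, feel free to leave a comment or reach out in the group chat. Thank you again for your support.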

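About the author

Author: 三岁

Experience: taught himself Python and is now active in the Paddle community, hoping to help beginners get started step by step.

Link:

I have reached Diamond level on AI Studio and earned 7 badges. Come and follow me! https://aistudio.baidu.com/aistudio/personalcenter/thirdview/284366

Let's move forward together with the top coders of the PaddlePaddle (飞桨) community!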
Note: all work is expected to meet the "三岁 standard" (said in jest).
