Advertisement

使用PyTorch处理AG_NEWS新闻分类数据集

阅读量:

文章目录

如何利用PyTorch对AG_NEWS新闻分类数据集进行处理?本指南将详细介绍以下步骤:包括数据加载操作、文本分词过程、词汇表构建步骤以及预处理流水线的建立。

1. 数据加载及查看

利用PyTorch框架实现对AG-NEWS新闻分类数据集进行处理的主要步骤包括:1)数据加载步骤;2)文本分词流程;3)词汇表构建过程;4)预处理流水线设计。这些环节共同构成了完整的模型训练体系。

1. 数据加载与查看

复制代码
    from torchtext.datasets import AG_NEWS
    train_iter = AG_NEWS(root='../datasets', split='train')
    print("连续三个next(train_iter)得到的结果:")
    print(next(train_iter))
    print(next(train_iter))
    print(next(train_iter))
  • 功能模块 :导入AG_NEWS训练数据集,并输出前三个样本实例。
    • 输出示例 :每个样本以元组形式呈现,请看以下具体实例:
      复制代码
      (3, "Wall St. Bears Claw Back Into the Black...")
      (1, "Raging Storms Over The Pacific...")
      (2, "Baseball World Series 2023 Highlights...")
    • 注意说明 :AG_NEWS分类标记范围在1至4之间,请具体分为以下几大类:
      • 1: World(全球)
      • 2: Sports(体育)
      • 3: Business(商业)
      • 4: Sci/Tec(科技与科学)

2. 分词器与词汇表构建

复制代码
    tokenizer = get_tokenizer('basic_english')  # 基础英文分词器(小写+按空格分割)
    train_iter = AG_NEWS(root='../datasets', split='train')  # 重新加载迭代器
    
    def yield_tokens(data_iter):
    for _, text in data_iter:
        yield tokenizer(text)  # 生成分词后的列表
    
    vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
    vocab.set_default_index(vocab["<unk>"])  # 未登录词映射到<unk>
  • 分词器 :对输入文本执行小写处理并进行空格分割(例如 "Hello World" 将被转换为 ["hello", "world"])。
    • 词汇表 :基于训练数据生成的所有分词结果构建,并在遇到未知词汇时通过 <unk> 标识。
    • 注意 :在处理过程中重新加载 train_iter 可避免之前打印样本时的数据消耗。

3. 词汇表测试

复制代码
    print("vocab('Mary Had a Little Lamb'.lower().split())")
    print(vocab(['mary', 'had', 'a', 'little', 'lamb']))  # 手动分词结果
    
    print("vocab(tokenizer('Mary Had a Little Lamb'.lower()))")
    print(vocab(tokenizer('mary had a little lamb')))     # 分词器处理后的结果
  • 输出 :请输出每个词在词汇表中对应的索引列表。
    • 作用 :用于验证分词和词汇表的工作是否正常进行。

4. 预处理流水线

复制代码
    def text_pipeline(x):
    return vocab(tokenizer(x))  # 文本→分词→索引列表
    
    def label_pipeline(c):
    return int(c) - 1  # 标签1~4 → 0~3(适应模型输出)
  • 流水线任务:将原始文本转换为模型能够识别和处理的具体索引序列。
    • 标签处理任务:将标签调整为从零开始编码(PyTorch模型通常要求类别标签采用0到N-1的连续整数编码)。

5. 预处理测试

复制代码
    print("text_pipeline('Mary Had a Little Lamb'.lower())")
    print(text_pipeline('mary had a little lamb'))  # 输出索引列表
    
    print("label_pipeline('4')")
    print(label_pipeline('4'))  # 输出3(对应Sci/Tec)
  • 验证结果 :确保文本转换和标签调整符合预期。

潜在问题与改进

  1. 迭代器重启:在连续调用 train_iter 时,请确保始终重新加载数据集以避免信息丢失(代码已正确执行该操作)。
  2. 类别标记形式:假设数据集中类别标记以字符串形式表示(如 '3' 或 'three'),则需将其转换为整数形式;若实际标记已经是整数,则应调整 label_pipeline 进行相应的转换处理。
  3. 分词性能提升:目前项目中采用的是 basic_english 分词器这一较为简单的模型,在后续开发中建议采用BERT等先进的分词模型以提高处理效果。
  4. 文本序列长度一致性问题:当前开发过程中未涉及相关处理步骤(建议在训练前补充代码实现),但实际训练时需要注意这一问题并采取适当措施进行解决。

全部评论 (0)

还没有任何评论哟~