使用PyTorch处理AG_NEWS新闻分类数据集
发布时间
阅读量:
阅读量
文章目录
如何利用PyTorch对AG_NEWS新闻分类数据集进行处理?本指南将详细介绍以下步骤:包括数据加载操作、文本分词过程、词汇表构建步骤以及预处理流水线的建立。
1. 数据加载及查看
利用PyTorch框架实现对AG-NEWS新闻分类数据集进行处理的主要步骤包括:1)数据加载步骤;2)文本分词流程;3)词汇表构建过程;4)预处理流水线设计。这些环节共同构成了完整的模型训练体系。
1. 数据加载与查看
from torchtext.datasets import AG_NEWS
train_iter = AG_NEWS(root='../datasets', split='train')
print("连续三个next(train_iter)得到的结果:")
print(next(train_iter))
print(next(train_iter))
print(next(train_iter))
- 功能模块 :导入AG_NEWS训练数据集,并输出前三个样本实例。
- 输出示例 :每个样本以元组形式呈现,请看以下具体实例:
(3, "Wall St. Bears Claw Back Into the Black...") (1, "Raging Storms Over The Pacific...") (2, "Baseball World Series 2023 Highlights...") - 注意说明 :AG_NEWS分类标记范围在1至4之间,请具体分为以下几大类:
- 1: World(全球)
- 2: Sports(体育)
- 3: Business(商业)
- 4: Sci/Tec(科技与科学)
- 输出示例 :每个样本以元组形式呈现,请看以下具体实例:
2. 分词器与词汇表构建
tokenizer = get_tokenizer('basic_english') # 基础英文分词器(小写+按空格分割)
train_iter = AG_NEWS(root='../datasets', split='train') # 重新加载迭代器
def yield_tokens(data_iter):
for _, text in data_iter:
yield tokenizer(text) # 生成分词后的列表
vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"]) # 未登录词映射到<unk>
- 分词器 :对输入文本执行小写处理并进行空格分割(例如
"Hello World"将被转换为["hello", "world"])。- 词汇表 :基于训练数据生成的所有分词结果构建,并在遇到未知词汇时通过
<unk>标识。 - 注意 :在处理过程中重新加载
train_iter可避免之前打印样本时的数据消耗。
- 词汇表 :基于训练数据生成的所有分词结果构建,并在遇到未知词汇时通过
3. 词汇表测试
print("vocab('Mary Had a Little Lamb'.lower().split())")
print(vocab(['mary', 'had', 'a', 'little', 'lamb'])) # 手动分词结果
print("vocab(tokenizer('Mary Had a Little Lamb'.lower()))")
print(vocab(tokenizer('mary had a little lamb'))) # 分词器处理后的结果
- 输出 :请输出每个词在词汇表中对应的索引列表。
- 作用 :用于验证分词和词汇表的工作是否正常进行。
4. 预处理流水线
def text_pipeline(x):
return vocab(tokenizer(x)) # 文本→分词→索引列表
def label_pipeline(c):
return int(c) - 1 # 标签1~4 → 0~3(适应模型输出)
- 流水线任务:将原始文本转换为模型能够识别和处理的具体索引序列。
- 标签处理任务:将标签调整为从零开始编码(PyTorch模型通常要求类别标签采用0到N-1的连续整数编码)。
5. 预处理测试
print("text_pipeline('Mary Had a Little Lamb'.lower())")
print(text_pipeline('mary had a little lamb')) # 输出索引列表
print("label_pipeline('4')")
print(label_pipeline('4')) # 输出3(对应Sci/Tec)
- 验证结果 :确保文本转换和标签调整符合预期。
潜在问题与改进
- 迭代器重启:在连续调用
train_iter时,请确保始终重新加载数据集以避免信息丢失(代码已正确执行该操作)。 - 类别标记形式:假设数据集中类别标记以字符串形式表示(如 '3' 或 'three'),则需将其转换为整数形式;若实际标记已经是整数,则应调整
label_pipeline进行相应的转换处理。 - 分词性能提升:目前项目中采用的是
basic_english分词器这一较为简单的模型,在后续开发中建议采用BERT等先进的分词模型以提高处理效果。 - 文本序列长度一致性问题:当前开发过程中未涉及相关处理步骤(建议在训练前补充代码实现),但实际训练时需要注意这一问题并采取适当措施进行解决。
全部评论 (0)
还没有任何评论哟~
