Advertisement

pytorch交叉熵损失函数

阅读量:

nn.CrossEntropyLoss 是 PyTorch 中常用的损失函数,适用于分类任务。它结合了 nn.LogSoftmax 和 nn.NLLLoss 的功能,直接处理未经过 softmax 的 logits 输出,计算预测值与真实标签之间的交叉熵损失。该函数通过 weight 参数调整类别不均衡,ignore_index 参数忽略某个类别,reduction 参数指定损失操作。适用于多分类问题、类别不均衡场景,并支持调整权重和忽略类别。代码示例展示了基本用法、使用权重调整和忽略某个类别的情况。

nn.CrossEntropyLoss 是 PyTorch 中广泛应用的损失函数,尤其在分类任务中被广泛应用。它整合了 nn.LogSoftmaxnn.NLLLoss(负对数似然损失)的功能,能够直接接受未经过 softmax 处理的 logits 输出,从而有效计算预测值与真实标签之间的交叉熵损失。

1. 交叉熵损失的原理

交叉熵损失是一种衡量两个概率分布之间差异的指标。在分类任务中,模型输出的logits经softmax函数转换为概率分布,随后与真实标签的概率分布进行对比。通过最小化交叉熵损失,模型得以优化其输出概率分布,使其更贴近真实标签的概率分布。

对于一个类别标签 y,预测概率 p(y),交叉熵损失定义为:

在多分类任务中,当真实标签为 y,预测的对数几率为 z_i 时,交叉熵损失的计算式为:

相比传统的分类方法,该模型在识别任务中的性能表现更为出色,其核心原因在于其简洁的设计和高效的计算效率。

2. nn.CrossEntropyLoss 的参数

weight(可选):一个一维张量,用于为每个类别设置权重,以在类别分布不均衡时调整损失函数。
size_average(可选):若设为True,则损失函数取平均值;否则,对所有样本求和。此参数已弃用,建议使用reduction参数。
ignore_index(可选):指定不参与计算的标签索引。
reduction(可选):指定损失操作方式,可选值为none(无操作)、mean(取平均)或sum(求和)。

3. nn.CrossEntropyLoss 的使用场景

nn.CrossEntropyLoss 适用于以下场景:

在多分类问题中,模型通过输出 logits 进行分类。
在 Softmax 层之前进行损失计算,适用于输出未经过 Softmax 的 logits。
当类别分布不均衡时,可通过调整 weight 参数来平衡各类的损失权重。

4. 代码示例

基本用法
复制代码
 import torch

    
 import torch.nn as nn
    
  
    
 # 创建预测值 logits 和真实标签
    
 logits = torch.tensor([[2.0, 1.0, 0.1], [0.5, 2.5, 0.3]])  # 未经过 softmax 的分数
    
 labels = torch.tensor([0, 1])  # 类别标签
    
  
    
 # 定义交叉熵损失函数
    
 criterion = nn.CrossEntropyLoss()
    
  
    
 # 计算损失
    
 loss = criterion(logits, labels)
    
 print(loss.item())  # 输出损失值

解释:

logits 是模型输出的非归一化的分数。labels 是真实对应的类别标签。nn.CrossEntropyLoss() 直接接收 logits 作为输入,并在其内部执行 softmaxlog 操作,最终计算交叉熵损失。

使用权重调整类别不均衡
复制代码
 import torch

    
 import torch.nn as nn
    
  
    
 # 创建预测值 logits 和真实标签
    
 logits = torch.tensor([[2.0, 1.0, 0.1], [0.5, 2.5, 0.3]])
    
 labels = torch.tensor([0, 1])
    
  
    
 # 定义类别权重
    
 weight = torch.tensor([0.7, 0.2, 0.1])  # 每个类别的权重
    
  
    
 # 定义交叉熵损失函数
    
 criterion = nn.CrossEntropyLoss(weight=weight)
    
  
    
 # 计算损失
    
 loss = criterion(logits, labels)
    
 print(loss.item())  # 输出损失值

解释:

权重参数为各个类别分配了不同的权重,用于调整损失计算,从而使得某些类别在损失计算中更加突出。

忽略某个类别的损失
复制代码
 import torch

    
 import torch.nn as nn
    
  
    
 # 创建预测值 logits 和真实标签
    
 logits = torch.tensor([[2.0, 1.0, 0.1], [0.5, 2.5, 0.3]])
    
 labels = torch.tensor([0, -1])  # 第二个样本忽略计算
    
  
    
 # 定义交叉熵损失函数,忽略类别标签为 -1 的样本
    
 criterion = nn.CrossEntropyLoss(ignore_index=-1)
    
  
    
 # 计算损失
    
 loss = criterion(logits, labels)
    
 print(loss.item())  # 输出损失值

解释:

ignore_index参数支持跳过某个特定类别的样本,不影响其在损失计算中的贡献。

5. 小结

nn.CrossEntropyLoss经过深入研究,展现出了卓越的功能特性,特别适用于分类任务。该损失函数通过巧妙地结合 softmax 和对数计算,实现了高效的交叉熵计算,从而显著简化了代码结构,提升了代码的可读性和维护性。通过动态调整权重参数以及忽略某些类别,该损失函数能够灵活适应多种实际应用场景,展现出极强的适应性和灵活性。

参考

该损失函数在分类任务中表现出色,特别适用于处理离散标签的场景。该损失函数在分类任务中表现出色,特别适用于处理离散标签的场景。

全部评论 (0)

还没有任何评论哟~