tensorflow交叉熵损失函数
发布时间
阅读量:
阅读量
简介
tensorflow 图像分类问题中,常用的交叉熵损失计算有三个函数:
tf.nn.softmax_cross_entropy_with_logits: 该功能模块用于计算交叉熵损失。tf.nn.softmax_cross_entropy_with_logits_v2: 此功能模块执行交叉熵损失计算,并支持版本号为2的功能增强。tf.nn.sparse_softmax_cross_entropy_with_logits: 该组件专门用于处理稀疏标签的交叉熵损失计算。
简单记录一下这三个函数的用法,免得自己后面再忘记了。
1. softmax_cross_entropy_with_logits
该函数的声明如下:
def softmax_cross_entropy_with_logits(
_sentinel=None, # pylint: disable=invalid-name
labels=None,
logits=None,
dim=-1,
name=None):
需要注意的几点:
- 该函数已不再适用,请使用
tf.nn.sparse_softmax_cross_entropy_with_logits_v2版本(V2)。其改进之处在于,在反向传播过程中不仅更新logits相关的参数,同时也更新labels相关的参数(V2版本会在计算过程中同步考虑这两者)。这一设计源于labels往往来源于机器学习模型的预测结果而非人工标注(尤其是在数据量巨大的情况下),因此手动标注会耗费大量的人力资源。 _sentinel的存在主要目的是确保在调用此函数时明确区分并指明哪些是labels参数、哪些是logits参数(例如调用方式应为:tf.nn.softmax_cross_entropy_with_logits(labels=..., logits=...))logits参数代表的是网络输出前未经Softmax处理的结果(因为此函数内部会自动执行一次Softmax操作)。labels与"logits"两个张量必须具有相同的形状,在一般情况下均为[batchsize, numberClass]labels张量中的每一行都需要满足概率分布的要求(即各行元素之和等于1),而"logits"则无需满足此条件。dim参数主要应用于高维场景下指定哪一维代表分类结果维度(通常情况下,“labels”与"logits"张量的维度均为[batchsize, numberClass], 因此该参数一般无需特别关注
2. softmax_cross_entropy_with_logits_v2
显然,这个函数时上个函数的改进版,其函数声明如下:
def softmax_cross_entropy_with_logits_v2(labels, logits, axis=-1, name=None):
注意的几点:
- 参数列表中 采用了 axis 而非 dim。
- 省略了 _sentinel ,因此无需明确指出哪一个是 labels ,哪一个是 logits。
- 其余的基本配置与 softmax_cross_entropy_with_logits 相同。
3. sparse_softmax_cross_entropy_with_logits
在这个函数中, sparse 代表 稀疏属性. 这个函数特别适用于处理仅包含一个类别标签的图片, 即 labels 的向量中通常仅有一个元素为1而其余元素均为0. 在MNIST、CIFAR-10等常见的图像分类任务中通常采用了此方法.
def sparse_softmax_cross_entropy_with_logits(
_sentinel=None, # pylint: disable=invalid-name
labels=None,
logits=None,
name=None):
需要注意的一点:
这个函数的 labels 与其前两个函数存在显著差异。其 labels 值为索引型数据,具体指示图片所属类别而不提供分类概率。
The labels vector must require a unique identifier for the target class in each row of the logits tensor.
通常为 [batchsize] 维度。而 logits 类似地具有 [batchsize, numberClass] 维度。
举个例子说明一下:
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=tf.math.argmax(y_, 1),
logits=y
)
# y_ 是 one_hot 形式的标签
# y 是网络预测结果(不经过 softmax 处理)
小结
当 label 数据稀疏时采用 tf.nn.sparse_softmax_cross_entropy_with_logits 并特别提醒:在使用该函数时,请注意 labels 参数的维度设置。对于其他较为复杂的场景,则建议采用 tf.nn.softmax_cross_entropy_with_logits_v2 进行计算
全部评论 (0)
还没有任何评论哟~
