Advertisement

tensorflow交叉熵损失函数

阅读量:

简介

tensorflow 图像分类问题中,常用的交叉熵损失计算有三个函数:

  1. tf.nn.softmax_cross_entropy_with_logits: 该功能模块用于计算交叉熵损失。
  2. tf.nn.softmax_cross_entropy_with_logits_v2: 此功能模块执行交叉熵损失计算,并支持版本号为2的功能增强。
  3. tf.nn.sparse_softmax_cross_entropy_with_logits: 该组件专门用于处理稀疏标签的交叉熵损失计算。

简单记录一下这三个函数的用法,免得自己后面再忘记了。

1. softmax_cross_entropy_with_logits

该函数的声明如下:

复制代码
    def softmax_cross_entropy_with_logits(
    _sentinel=None,  # pylint: disable=invalid-name
    labels=None,
    logits=None,
    dim=-1,
    name=None):

需要注意的几点:

  1. 该函数已不再适用,请使用tf.nn.sparse_softmax_cross_entropy_with_logits_v2版本(V2)。其改进之处在于,在反向传播过程中不仅更新logits相关的参数,同时也更新labels相关的参数(V2版本会在计算过程中同步考虑这两者)。这一设计源于labels往往来源于机器学习模型的预测结果而非人工标注(尤其是在数据量巨大的情况下),因此手动标注会耗费大量的人力资源。
  2. _sentinel的存在主要目的是确保在调用此函数时明确区分并指明哪些是labels参数、哪些是logits参数(例如调用方式应为:tf.nn.softmax_cross_entropy_with_logits(labels=..., logits=...))
  3. logits参数代表的是网络输出前未经Softmax处理的结果(因为此函数内部会自动执行一次Softmax操作)。
  4. labels与"logits"两个张量必须具有相同的形状,在一般情况下均为 [batchsize, numberClass]
  5. labels张量中的每一行都需要满足概率分布的要求(即各行元素之和等于1),而"logits"则无需满足此条件。
  6. dim参数主要应用于高维场景下指定哪一维代表分类结果维度(通常情况下,“labels”与"logits"张量的维度均为 [batchsize, numberClass], 因此该参数一般无需特别关注

2. softmax_cross_entropy_with_logits_v2

显然,这个函数时上个函数的改进版,其函数声明如下:

复制代码
    def softmax_cross_entropy_with_logits_v2(labels, logits, axis=-1, name=None):

注意的几点:

  1. 参数列表中 采用了 axis 而非 dim。
  2. 省略了 _sentinel ,因此无需明确指出哪一个是 labels ,哪一个是 logits。
  3. 其余的基本配置与 softmax_cross_entropy_with_logits 相同。

3. sparse_softmax_cross_entropy_with_logits

在这个函数中, sparse 代表 稀疏属性. 这个函数特别适用于处理仅包含一个类别标签的图片, 即 labels 的向量中通常仅有一个元素为1而其余元素均为0. 在MNIST、CIFAR-10等常见的图像分类任务中通常采用了此方法.

复制代码
    def sparse_softmax_cross_entropy_with_logits(
    _sentinel=None,  # pylint: disable=invalid-name
    labels=None,
    logits=None,
    name=None):

需要注意的一点:

这个函数的 labels 与其前两个函数存在显著差异。其 labels 值为索引型数据,具体指示图片所属类别而不提供分类概率。

The labels vector must require a unique identifier for the target class in each row of the logits tensor.

通常为 [batchsize] 维度。而 logits 类似地具有 [batchsize, numberClass] 维度。

举个例子说明一下:

复制代码
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=tf.math.argmax(y_, 1),
        logits=y
    )
    
    # y_ 是 one_hot 形式的标签
    # y  是网络预测结果(不经过 softmax 处理)

小结

当 label 数据稀疏时采用 tf.nn.sparse_softmax_cross_entropy_with_logits 并特别提醒:在使用该函数时,请注意 labels 参数的维度设置。对于其他较为复杂的场景,则建议采用 tf.nn.softmax_cross_entropy_with_logits_v2 进行计算

全部评论 (0)

还没有任何评论哟~