《Query2Label: A Simple Transformer Way to Multi-Label Classification》

论文链接:https://arxiv.53yu.com/pdf/2107.10834.pdf?ref=https://githubhelp.com
代码链接:https://github.com/SlongLiu/query2labels
1. 动机
多标签分类需要特别关注两个问题:1)如何处理标签不平衡问题,2)如何从感兴趣区域提取特征。为了解决第一个问题,很多研究者设计了各种损失函数;而相对于第一个问题相比,第二个问题的解决方案相对不成熟,需要特殊设计的网络架构或额外依赖于标签相关性。
2. 方法

在本文中,作者提出了一个简单而有效的解决方案,使用Transformer解码器查询类标签的存在性。作者表明,在没有附加功能的情况下,所提出的解决方案会产生新的SOTA结果,并为其简单的实现和卓越的性能建立强大的基线。该解决方案命名为Query2Label ,如上图所示1,本文使用可学习的标签嵌入作为查询,通过Transformer编码器中的交叉注意模块探测和汇集类相关的特性。合并后的特征具有自适应和更强的鉴别性,从而提高了多标签分类的性能。
-
Query2Label框架

Query2Label是一个两阶段框架,主要由第一阶段的特征提取模块 ,第二阶段的Transformer解码块 (用于query更新)和特征映射 组成。 -
特征提取模块
给定一幅图像x \in R^{H_0 \times W_0 \times 3}作为输入,利用主干提取其空间特征\mathcal{F}_0 \in R^{}H \times W \times d_0。然后添加一个线性映射层,将特征从维数d_0投影到d,与第二阶段所需的query维数匹配,并将映射的特征reshape为\mathcal{F} \in R^{H \times W \times d}。 -
query更新
在第一阶段获得输入图像的空间特征后,使用标签嵌入作为查询\mathcal{Q}_0 \in R^{K \times d},并使用多层Transformer解码器的cross-attention来聚合空间特征中类别相关特征,其中K为类别数。并使用标准的Transformer体系结构,它有一个self-attention模块、一个cross-attention模块和一个位置前馈网络(FFN)。Transformer解码器每层(i)从其上一层的输出更新query \mathcal{Q}_{i-1},即:

其中波浪线表示通过添加位置编码修改的原始矢量。由于我们不需要执行自回归预测,所以我们没有使用attention masks。这样,每一层都可以并行地对M类进行解码。 -
特征映射
假设一共有L个层,对于最后一层的K个类别,我们将得到被查询到的特征向量\mathcal{Q}_L \in R^{K \times d}。为了进行多标签分类,将每个标签预测视为一个二元分类任务,并使用一个线性映射层和一个sigmoid函数将每个类\mathcal{Q}_{L, k} \in R^d的特征映射成一个logit值,即

其中W_k \in R^d,W = [W_1, \cdots, W_K]^T \in R^{L \times d},且b_k \in R,\ b=[b_1, \cdots, b_K]^T \in R^K是线性层的参数。p=[p_1, \cdots, p_K]^T \in R^K是每个类的预测概率。 -
损失函数
为了更有效地解决样品不平衡问题,本文采用了一个简化的非对称损耗,它是一个对于正值和负值有不同\gamma值的焦点损失。在本文的实验中发现它的效果最好。给定一个输入图像x,可以使用本文框架预测它的分类概率p=[p_1, \cdots, p_K]^T \in R^K。然后利用下面的非对称焦点损失来计算每个训练样本x的损失:

其中y_k为二值标签,表示图像x是否有标签k。总损失的计算方法是对训练数据集\mathcal{D}中的所有样本的损失的平均值,并使用随机梯度下降法进行优化。默认情况下,在实验中设置\gamma += 0和\gamma = 1。
3. 部分实验结果
实验主要在四个库上进行,分别为:MS-COCO, PASCAL VOC, NUS-WIDE, 和Visual Genome。
-
与最新方法对比




-
消融实验
不同大小的目标的结果

-
attention map的可视化



4. 结论
本文提出了一个简单而有效的多标签分类框架Query2Label (Q2L),它是基于在图像分类主干上的Transformer解码器开发的。Transformer解码器体系结构中的内置cross-attention模块提供了一种有效的方法,可以使用标签嵌入来查询类标签和池类相关特性的存在性。在MS-COCO、PASCAL VOC、NUSWIDE和Visual Genome等几个广泛使用的数据集上,该框架的性能始终优于所有之前的工作。
