Advertisement

NIPS2019《Cross Attention Network for Few-shot Classification》

阅读量:

本论文提出了一种改进的小样本分类方法,通过引入交叉注意模块(Cross Attention Module, CAM)和直推推理算法来提升分类性能。主要贡献包括:
交叉注意模块(CAM):用于解决不可见类问题和低数据问题。

  • 不可见类问题:通过计算类特征与查询特征之间的语义相关性,并生成交叉注意力图来突出显示目标对象。
  • 低数据问题:利用无标签查询样本进行直推推理算法以缓解类别分布不一致的问题。
    直推推理算法:通过迭代预测查询样本的标签并选择伪标签样本来扩大支持集,使类别特征更具有代表性。
    实验结果:对比实验表明所提出的方法在SOTA上表现优异,并且消融实验进一步验证了各组件的有效性。
    总结而言,该方法通过增强特征表示能力和优化数据利用方式,在小样本分类任务中取得了显著效果。
在这里插入图片描述

论文已发布于第2019年神经网络与机器学习 sympoium(简称NIPS)!相关研究者可访问以下资源获取完整论文:https://proceedings.neurips.cc/paper/2019/file/01894d6f048493d2cacde3c579c315a3-Paper.pdf;与此同时其对应的GitHub存储库也已托管了完整的代码实现:https://github.com/blue-blue272/fewshot-CAN

1. 动机

在这里插入图片描述

虽然值得期待的是这一领域的发展前景良好,但鲜少有人会对所提取出的关键属性可识别性给予足够的关注与重视,这可能导致一些关键属性缺乏必要的区分度

2. 贡献

在本工作中,提出了一种新的**交叉注意网络(CAN)**来提高小样本分类的特征可鉴别性。
1)首先,引入交叉注意模块(CAM)来解决不可见类问题 。交叉注意的想法是受人类少样本学习行为的启发。为了从一个未被发现的类别中识别出一个样本,人类倾向于首先在标记和未标记的样本对中定位最相关的区域。类似地,给定一个类特征图和一个查询样例特征图,CAM为每个特征生成一个交叉注意图来突出显示目标对象。为了达到这一目的,采用了相关估计和元融合方法。这样可以使测试样本中的目标对象获得注意,交叉注意图加权的特征具有更强的判别性。如图1 (e)所示,利用CAM提取的特征可以对目标物体幕区域进行粗略的定位。
2)其次,我们引入了一个直推推理算法,利用整个无label查询集来缓解低数据问题 。该算法迭代预测查询样本的标签,并选择伪标签查询样本来扩大支持集。每个类支持样本越多,得到的类特征就越有代表性,从而缓解了低数据问题。

3. 方法

3.1 问题定义

少样本分类通常包括一个训练集、一个支持集和一个查询集。训练集包含大量的类和标注的样本。少数标注样本的支持集和无标注样本的查询集共享同一个标注空间,而标注空间与训练集的标注空间是不相连的。少样本分类的目的是对给定训练集和支持集的无标记查询样本进行分类。如果支持集由C类和每个类的K个标记样本组成,则目标少样本问题称为C-way K-shot。
根据已有经验,本文也采用episode训练机制,该机制已被证明是一种有效的少样本学习方法。训练中使用的episode模拟了测试中的设置。每个episode是由随机抽样C类和每个类K个标记样本作为支持组\mathcal{S} = \{ (x^s_a, y^s_a)\}^{n_s}_{a=1} (n_s = C \times K)C类中剩余样本的一小部分作为查询集\mathcal{Q} = \{ (x^q_b, y^q_b)\}^{n_q}_{b=1}组成。我们将\mathcal{S}^k表示为第k类的支持子集。如何表示每个支持类\mathcal{S}^k和查询样本x^q_b,并度量它们之间的相似性是少样本分类的关键问题。

3.2 Cross Attention Module

在本工作中利用度量学习从每一对支持类与查询样本中提取合适的特征表示。本文开发了交叉注意模块(Cross Attention Module, CAM),该模块能够建立类特征与查询特征间语义关联模型,在引起目标物体注意力的同时提升了后续匹配效果。

在这里插入图片描述

如图所示的(a)图中所展示的是CAM机制。类特征映射P^k \in \mathbb{R}^{c \times h \times w}基于支持样本\mathcal{S}^k (k \in \{ 1, 2, \cdots, C\})提取得到,而查询特征映射Q^b \in \mathbb{R}^{c \times h \times w}则基于查询样本x^q_b (b \in \{ 1, 2, \cdots, n_q\})提取得到。其中chw分别代表通道数、高度和宽度这三个维度参数。该模型通过将位置相关的关注矩阵A^p (A^q)作用于对应的特征图来进行加权操作,在这一过程中实现了更具判别性的特征表示\bar{P}^k_b (\bar{Q}^b_k)生成。为了简化表达形式后文将输入类和查询特征映射统一表示为P和Q而输出类与查询相关联的特征表示则统一用\bar{P}\bar{Q}来表示

  • Correlation Layer
    如图(a)所示,首先设计一个Correlation Layer来计算PQ之间的关联图,然后利用关联图来指导交叉注意图的生成。为此,我们首先将PQ重塑为\mathbb{R}^{c \times m},即P = [p_1, p_2,…, p_m]Q = [q_1, q_2, \cdots, q_m],其中m (m = h \times w)为每个feature map上的空间位置个数。p_i, q_i \in \mathbb{R}^c分别是PQ中第i个空间位置的特征向量。关联层用余弦距离计算\{p_i \}^m_{i=1}\{q_i \}^m_{i=1}之间的语义相关性,得到关联映射R∈ \mathbb{R}^{m \times m}为:
在这里插入图片描述

在此基础上

  • Meta Fusion Layer
在这里插入图片描述

通过元融合层对相应的关联图进行分析后分别构建类注意力图和查询注意力图。以类注意力图为示例说明其工作原理:首先将输入特征映射到类相关空间中得到R^p矩阵;随后采用m\times1尺寸的卷积核对每一列进行计算得到m个局部相关向量\{r^p_i\};接着将每个局部相关向量r^p_i通过加权求和的方式转化为一个标量值;最后对所有标量值应用softmax函数进行归一化处理得到第i个位置的类注意力权重值

在这里插入图片描述

其中\tau代表温度超参数,在温度参数较低时会降低系统的熵值量(entropy measure),从而使得概率分布集中在有限个高置信度的类别范围内(class)。为了便于后续处理过程,在矩阵空间\mathbb{R}^{h \times w}中重新构造原始矩阵A^p(matrix A^p),从而得到类注意映射(class attention mapping)。值得注意的是,在特征融合过程中(feature fusion process),核心注意力机制组件(core attention mechanism component)起着关键作用(key role)。对于每个位置i而言,在集合局部类特征p_i与所有局部查询特征\{q_j\}_{j=1}^m之间计算出的相关系数可被视为该位置的注意力权重标量(attention scalar value)。值得注意的是,在加权聚合操作中(weight aggregation operation),注意力机制应聚焦于目标物体区域(target object region),而非仅仅突出显示跨支持类和查询示例之间的视觉相似区域(visually similar regions across support classes and query examples)。
基于以上分析与设计思路,在本文中我们提出了一种元学习器框架:通过分析类与查询特征之间的相关程度来自适应地生成内核向量(kernel vector)。为此,在矩阵空间\mathbb{R}^{h \times w}上应用全局平均池化操作(GAP operation),即对每一行求取平均值得到一个全局查询相关向量(global query-relevant vector),然后将该向量输入到元学习器中生成核心向量w \in \mathbb{R}^m

w = f_{\text{MLP}}(\text{GAP}(A^p))

其中f_{\text{MLP}}(\cdot)代表多层感知机网络层。

在这里插入图片描述

其中W_1 \in \mathbb{R}^{\frac{m}{r} \times m}W_2 \in \mathbb{R}^{m \times \frac{m}{r}}是元学习者的参数,r是还原比。σ为ReLU函数。元学习模型的非线性允许灵活的转换。对于每对类和查询特征,元学习器期望生成一个能引起对目标对象交叉注意的核心w。这是在元训练中通过最小化查询样本的分类错误来实现的。
以类似的方式,我们可以查询注意力图A^q \in \mathbb{R}^{h \times w}。最后,我们用剩余的注意机制,在初始特征图PQ 通过1 + A^p1 + A^q进行元素加权以分别形成更有识别力的特征图\bar{P}\in \mathbb{R}^{c \times h \times w}\bar{Q} \in \mathbb{R}^{c \times h \times w}

3.3 Cross Attention Network

在这里插入图片描述

交叉注意网络体系主要包含三个核心组件:嵌入模块、交叉注意力机制以及分类器。嵌入模块E则由一系列连续卷积层构成,在输入图像x的基础上生成其对应的特征表示E(x)∈ℝ^{c×h×w}。与prototypical network一致,在本文设定中类特征被定义为其支持集在嵌入空间中的平均值。如图3所示,在这一过程中,嵌入模块E会将支持集S以及一个查询样本x_q_b作为输入参数,并分别生成类特征映射P_k=1/|S_k| Σ_{x_s_a∈S_k} E(x_s_a)以及查询样本对应的查询特征映射Q_b=E(x_q_b)。在此基础上,在每一对特征映射(P_k与Q_b)之间都会激活交叉注意力机制进行处理,在这一过程中系统会自动提取出更具区分度的对应关系(¯P_k_b与¯Q_b_k),从而完成分类任务

通过最小化训练集上的查询样本分类损失来进行CAN的训练。该体系中的分类模块由局部近邻和全局分类器构成。其中,在每个支持类别c\mathcal{C}上分配一个概率值p_c(q^b_i)来表示第i个位置局部查询特征q^b_i被归类到该支持类别的可能性大小。具体而言,在第i个位置的局部查询特征q^b_i经过最近邻分类器后会生成一个带有softmax性质的概率分布向量\bm{p}^c_i ∈[0,1]^{\mathcal{C}}满足\sum_{c=1}^{|\mathcal{C}|}p_c(q^b_i)=1。其中,在第k个类别中对q^b_i的概率预测为:

在这里插入图片描述

其中(\bar{Q}_k^{b})_i 具体指代的是该序列在第i 个空间维度上的特征向量;GAP运算则用于从整个序列中提取全局平均类特征的操作步骤。值得注意的是,在此过程中,\bar{Q}_k^{b}\bar{Q}_j^{b} 共同代表查询样本x_q^{b} 的两个不同支持类别相关的子序列;因为它们分别对应于不同支持类别中的样本,所以在式(4)中,余弦距离d 被计算于CAM生成的特征空间之中;随后基于真实的类别标签

在这里插入图片描述

该分类器通过引入一个全连接层来构建特征表示,并借助softmax函数对各类别概率进行建模。假设训练数据集包含l个不同的类别标签,则对于每一个局部查询特征q^b_i而言,在经过线性变换后得到的概率向量z^b_i \in \mathbb{R}^l可以通过以下公式计算得到:z^b_i = \text{softmax}(W_c(\bar{Q}^b_{y^q_b})_i)。在此基础上构建的全局分类损失函数能够有效反映模型预测结果与真实标签之间的差异程度。

在这里插入图片描述

其中全连接层的权重参数W_c\in\mathbb{R}^{l\times c}属于实数矩阵空间。对于输入样本x^q_b而言,其真实全局类别标记为l^q_b\in\{1,2,\dots,l\}。为了整合不同损失项的影响程度,在分类任务的整体损失函数定义为L=λL_1+L_2的基础上引入平衡参数λ。通过梯度下降方法最小化L的目标函数,并采用端到端训练的方式对整个网络模型进行优化以提升性能。

  • 归纳推理
    在归纳推理环节中, 直接将嵌入模块应用于新任务, 提取了类别映射与查询特征映射. 随后将每对类别与对应的查询特征映射输入至CAM模块中, 从而获得经过注意力加权后的特征向量. 并并对CAM输出执行全局平均池化操作, 从而得到整体类别与查询特性的均值表示. 最后, 在余弦距离度量下进行比较后得出结果.
在这里插入图片描述

一种基于直推式的推理机制,在针对仅标注少量实例的分类问题中展现出显著的有效性。针对这种特定场景下的挑战,在每个类别仅标注了极少数实例的情况下(这使得类别特征难以充分反映真实类别分布的情况),我们提出了一种高效的解决方案来处理直推式推理问题。该方法通过无标签查询样本的数据来补充类别特征信息,并在此过程中实现了对分类任务的支持效率提升。

在这里插入图片描述

4. 部分实验结果

SOTA结果对比

在这里插入图片描述

消融实验

在这里插入图片描述
在这里插入图片描述

可视化

在这里插入图片描述

5. 结论

  • CAM的时间与空间成本主要体现在relation层。CAM的时间复杂度为O(h^2w^2c),空间复杂度为O(hwc)。由于二者均受输入特征图尺寸的影响,在本文中,在最后一个卷积层之后加入了CAM模块以优化资源分配。
  • 本文所提出的归纳与直推式训练方法能够显著提升性能水平。

全部评论 (0)

还没有任何评论哟~