Advertisement

【CVPR 2021】基于样本间关系的知识蒸馏:Complementary Relation Contrastive Distillation

阅读量:

CVPR 2021

CVPR 2021

论文链接:
主要问题:
基本思路:
算法优势:
论证分析:
基本符号:
优化目标:
优化目标的下限:
分布近似:
关系对比损失:
互补关系:
关系度量:

  • 具体实现:
    • 评判函数:
    • 采样策略:
    • 损失函数:

论文地址:

https://arxiv.org/abs/2103.16367

主要问题:

知识蒸馏主要关注每个样本的特征独立性(即让teacher和student模型对每个样本的特征进行近似),然而,作者进一步认识到不同样本之间的特征距离同样承载着重要的结构化信息(如图所示),因此提出了CRCD蒸馏算法。

在这里插入图片描述

主要思路:

在蒸馏过程中,对每个样本随机采样其neighbor样本,构建其对应的anchor-teacher和anchor-student关系,分别用于teacher model和student model的特征关系表示。在蒸馏阶段,通过设计机制促进这两个关系保持一致。其中,特征关系不仅包含特征本身,还同时利用特征的梯度信息进行表示,通过分别设计的子网络MT和MT,S进行学习,最终利用Relation Contrastive Loss损失函数来保持其一致关系(如图所示)。

在这里插入图片描述

算法优点:

a) 能够自动优化样本的特征表示以及样本间关系;
b) KD类蒸馏算法使用student-student关系表达样本间关系,这种表示很不稳定,因为student models都是同时训练而且没有很好地优化,因此使用anchor-student的表示更为合理;
c) 由于训练时anchor可以从某个样本的neighbor中随机选取,增加了蒸馏模型的鲁棒性;

算法论证:

基本符号:

教师/学生模型:Ω^T/Ω^S

输入x的输出:\phi^T(x)/\phi^S(x)

样本集Xx_i,x_j在教师模型中的关系(anchor-teacher):r^T_{i,j}

其中该关系通过子网络M^T得到:r^T_{i,j}=M^T(\phi^T(x_i),\phi^T(x_j))

同样r^{T,S}_{i,j}(anchor-student)通过子网络M^{T,S}得到:r^{T,S}_{i,j}=M^{T,S}(\phi^T(x_i),\phi^S(x_j))

我们将\phi^T(x_i)视为一个锚点表示。因此,r^T_{i,j}r^{T,S}_{i,j}应该尽可能保持一致。这不仅有助于保留x_ix_j之间的关系信息,同时也促使\phi^S(x_j)\phi^T(x_j)保持一致。

简单起见,我们将anchor-reacher和anchor-student分别记为:R^TR^{T,S}

优化目标:

如果我们用P(R^T)P(R^{T,S})来近似于R^TR^{T,S}在样本集X下的条件概率分布P(R^T|X)P(R^{T,S}|X),那么我们的优化目标就是使R^TR^{T,S}的概率分布尽可能接近,即最大化这两个概率分布之间的互信息(MI):

I(R^T,R^{T,S})=\mathbb{E}_{p(R^T,R^{T,S})}\log\frac{p(R^T,R^{T,S})}{p(R^T)p(R^{T,S})}

优化目标的下限:

我们引入一个带有隐变量C的分布q

C=1表示r^Tr^{T,S}是通过相同的样本x_i,x_j计算得到的,即:

x_i,x_j \sim p(X)

r^T_{i,j}=M^T(\phi^T(x_i),\phi^T(x_j))

r^{T,S}_{i,j}=M^{T,S}(\phi^T(x_i),\phi^S(x_j))

C=0表示r^Tr^{T,S}是分别通过独立样本x_i,x_jx_m,x_n计算得到的,即:

x_i,x_j \sim p(X)

r^T_{i,j}=M^T(\phi^T(x_i),\phi^T(x_j))

x_m,x_n \sim p(X)

r^{T,S}_{m,n}=M^{T,S}(\phi^T(x_m),\phi^S(x_n))

那么我们可以记:

q(R^T,R^{T,S}|C=1)=p(R^T,R^{T,S})

q(R^T,R^{T,S}|C=0)=p(R^T)p(R^{T,S})

我们假设1个相关关系对(C=1)带有N个不想关关系对(C=0),那么q(C=1)=1/(N+1),q(C=0)=N/(N+1)

基于贝叶斯先验概率,我们可以得出C=1的后验概率为:q(C=1|R^T,R^{T,S})=\frac {p(R^T,R^{T,S})}{p(R^T,R^{T,S})+Np(R^T)p(R^{T,S})}

结合MI可以得到:

\log q(C=1|R^T,R^{T,S}) \leq-\log(N)+\log(\frac{p(R^T,R^{T,S})}{p(R^T)p(R^{T,S})})

对两边关于p(R^T,R^{T,S})同时取期望(等价于q(R^T,R^{T,S}|C=1)),我们可以得到:

I(R^T,R^{T,S})\geq\log(N)+\mathbb{E}_{q(R^T,R^{T,S}|C=1)}\log q(C=1|R^T,R^{T,S})

分布近似:

鉴于真实分布难以确定,作者基于采样技术,构建一个子网络模型h:\{R^T,R^{T,S}\}\rightarrow[0,1],用于近似计算q(C=1|R^T,R^{T,S})的后验概率。

该模型下抽样数据的对数似然函数就可以定义为:

\mathcal{I}(h)=\mathbb{E}_{q(R^T,R^{T,S}|C=1)}[\log h(R^T,R^{T,S})]+N\mathbb{E}_{q(R^T,R^{T,S}|C=0)}[\log (1-h(R^T,R^{T,S}))]

因此,为了实现对上述分布q(C=1|R^T,R^{T,S})的较为准确的近似目标,我们需要优化上述的对数似然函数。

考虑+右边总是小于等于0,我们可以得到:

I(R^T,R^{T,S})至少等于\log(N),加上平均地,基于条件概率q(R^T,R^{T,S}|C=1)的对数概率,再加上基于条件概率q(R^T,R^{T,S}|C=0)的加权对数损失。

即:

I(R^T,R^{T,S})\geq\log(N)+\mathcal{I}(h)

通过优化目标等价于在参数化模型h的框架下最大化下界值\log(N)+\mathcal{I}(h)

关系对比损失:

在作者的方法中,函数h的输入项r^T(教师模型间的空间关系)和r^{T,S}(跨空间的关联)是由教师模型与学生模型Ω^TΩ^S,以及两个子网络M^TM^{T,S}构成。

其中Ω^SM^{T}M^{T,S}都需要在蒸馏过程中优化

如前所述,我们的目标旨在提升相关信息量,等同于最小化关系间的对比差距(relation contrastive loss),记作:

\mathcal{L_{RC}}(h,Ω^S,M^{T},M^{T,S})=-\sum_{q(C=1)}\log h(r^T,r^{T,S})-N\sum_{q(C=0)}\log [1-h(r^T,r^{T,S})]

其中\{((r^T,r^{T,S}|C=1)\}是正关系对,\{((r^T,r^{T,S}|C=0)\}是负关系对

根据上文不近似,通过最小化关系对比损失来近似地遵循该分布q(C|R^T,R^{T,S}),从而提升这些网络相关信息的下界。同时,对这三个网络进行联合优化。

互补关系:

子网络M^{T,S}用来计算表示\phi^T(x_i),\phi^S(x_j)的anchor-student关系:

r^{T,S}_{i,j}=M^{T,S}(\phi^T(x_i),\phi^S(x_j))

即:

r^{T,S}_{i,j}=W^{A}(\sigma(W^A_i\phi^T(x_i)-W^A_j\phi^S(x_j)))

其中,W^{A}_iW^{A}_j是用于线性变换操作的参数,其主要作用是解决输入与输出维度不匹配的问题;\sigma被定义为ReLU激活函数,用于引入非线性特性;同时,W^{A}也被用于进行变换操作,以确保模型的参数共享机制能够得到有效应用。

通过这种方式,r^{T,S}_{i,j}可以利用子网络M^{T}的输出r^{T}_{i,j},其中锚点-教师机制(anchor-teacher)被采用,进行监督式的学习:

r^{T}_{i,j}=W^{B}(\sigma(W^B_i\phi^T(x_i)-W^B_j\phi^S(x_j)))

关系度量:

注意对于\phi^T(x)/\phi^S(x),作者既使用了特征f,又使用了梯度g

对于特征:

\phi^T(x)/\phi^S(x)直接对应于输入x的激活f^T(x)/f^S(x)

对于梯度:

g(x)=\frac{\partial}{\partial f}L_{cls}(\Omega,x) 反映了优化的动力学方向

\phi^T(x)/\phi^S(x)对应于输入x的梯度g^T(x)/g^S(x)

具体实现:

评判函数:

参数化模型h用于判断关系对(r^T,r^{T,S})是否属于同一个联合概率分布p(R^T,R^{T,S}),而不是边缘分布的乘积p(R^T)p(R^{T,S})

这中表达方式跟NCE很类似:

h(r^T,r^{T,S})=\frac{e^{h_1(r^T)h_2(r^{T,S})/\tau}}{e^{1/\tau}}

其中 \tau 是温度超参数,h1h2操作是一个线性变换加一个\mathcal{l2}正则

采样策略:

在每次正向传播过程中,通过当前mini-batch中的任意两个样本x_ix_j,可以计算得到anchor关系r^{T}_{i,j}和正向关系r^{T,S}_{i,j};而对于负向关系r^{T,S}_{i,k},则基于x_i的anchor表示以及缓存区中第k个表示进行计算。

在处理过程中,我们设置mini-batch大小为B,这样对于每个样本对,我们需要统计B^2个关系;此外,从缓冲区中随机选取N个负样本,用于对比学习任务。

考虑到随机采样在映射当前网络状态方面存在局限性,作者开发了一种队列采样策略。该策略不仅存储了之前在前向传播过程中记录的N个样本索引,而且在每次前向传播结束后,会将当前的mini-batch数据替换掉最旧的索引,以确保队列的更新和数据的有效利用。

损失函数:

作者同时也加入了原始知识蒸馏的KD损失:

\mathcal{L_{kd}}=\rho^2\mathcal{H}(\sigma(z^T/\rho),\sigma(z^S/\rho))

其中 \mathcal{H} 是交叉熵损失,\sigmasoftmax激活函数

这样完整的损失函数就可以写作:

该损失函数组合包含了四个部分:分类分支损失、知识分支的Distillation损失,以及分别对应于RC分支的前向和后向损失。具体来说,损失函数的计算公式为:\mathcal{L}=\mathcal{L_{cls}}+\alpha\mathcal{L_{kd}}+\beta_1\mathcal{L_{RC}^f}+\beta_2\mathcal{L_{RC}^g}。其中,\mathcal{L_{cls}}代表分类分支损失,\alpha\mathcal{L_{kd}}代表知识分支的Distillation损失,\beta_1\mathcal{L_{RC}^f}\beta_2\mathcal{L_{RC}^g}分别对应于RC分支的前向和后向损失。

其中\alpha,\beta_1,\beta_2是超参数,默认为:1,0.5,0.5

在这里插入图片描述

全部评论 (0)

还没有任何评论哟~