Advertisement

Focal and Global Knowledge Distillation for Detectors--FGD论文解读

阅读量:

论文:Focal and Global Knowledge Distillation for Detectors

论文:https://arxiv.org/abs/2111.11837

代码:https://github.com/yzd-v/FGD

一,针对问题

1. 目标检测中前背景不平衡问题

知识蒸馏的目标是让学生模仿教师的行为以获取知识以便生成与教师相同的结果从而提高自身的性能水平。研究者首先通过可视化手段展示了学生和教师在特征图上的异同从可视化结果可以看出 在空间和通道注意力机制方面 学生与教师之间存在明显的差异其中 在空间注意力机制中 学生与教师在前景区域的表现差异较大而在背景区域则表现较为接近 这将导致不同层次的学习挑战

为了进一步探讨前景与后景在知识蒸馏中的作用, 研究者将前景与后景分开进行蒸馏实验. 若将全图统一进行蒸馏, 则会导致蒸馏性能显著下降. 分离前后景可使处理效果更加理想.

考虑到学生与教师注意力之间的差异性以及前景与背景之间的异质性, 作者提出了一种叫做Focal Distillation的重点蒸馏方法: 首先它通过实现前背景信息的有效区分, 并结合教师模型在空间维度上的位置感知能力和通道维度上的特征提取能力, 最终指导学习者完成知识迁移过程, 同时构建了基于区分度加权的重点蒸馏损失函数

二,方法

整体蒸馏损失计算方式:

C,H,W:feature map的通道时和高宽。

F^T

F^{S}

为教师和学生模型的输出。

2.1 分离前背景

前、背景Mask

设置一个二值MASK:

r代表GT bbox,如果feature map的点落在bbox内则该点为1,否则为0.

2.2 尺度

尺度Mask

大小目标focal,前、背景

Hr和Wr分别代表bounding box的高度和宽度,在存在遮挡的情况下(即同一个体若同时被归类到多个目标类别中),我们会优先采用具有最小面积的目标区域来计算S

2.2 空间与通道注意力

空间与通道注意力

C,H,W:feature map的通道时和高宽。

G

G^S ,G^{C}

代表空间注意立和通道注意力机制,

Attention MASK:

T为蒸馏温度 ,论文设置为0.5

2.3 全局蒸馏

全局信息的丢失

Focal Distillation将前景和背景分别进行蒸馏,在切断前后景之间的联系中丢失了特征级别的整体信息。为此提出了一种基于全局蒸馏的方法:通过GcBlock分别提取学生与教师的信息,并用于计算整体蒸馏损失。

通过GCBlock获取全局信息,并使学生模型能够在教室模型的基础上学习前背景之间的联系

损失计算如下:

复制代码
  
    
     self.conv_mask_t = nn.Conv2d(teacher_channels, 1, kernel_size=1)
    
     self.channel_add_conv_s = nn.Sequential(
    
         nn.Conv2d(teacher_channels, teacher_channels//2, kernel_size=1),
    
         nn.LayerNorm([teacher_channels//2, 1, 1]),
    
         nn.ReLU(inplace=True),  # yapf: disable
    
         nn.Conv2d(teacher_channels//2, teacher_channels, kernel_size=1))
    
     self.channel_add_conv_t = nn.Sequential(
    
         nn.Conv2d(teacher_channels, teacher_channels//2, kernel_size=1),
    
         nn.LayerNorm([teacher_channels//2, 1, 1]),
    
         nn.ReLU(inplace=True),  # yapf: disable
    
         nn.Conv2d(teacher_channels//2, teacher_channels, kernel_size=1))
    
  
    
     def spatial_pool(self, x, in_type):
    
     batch, channel, width, height = x.size()
    
     input_x = x
    
     # [N, C, H * W]
    
     input_x = input_x.view(batch, channel, height * width)
    
     # [N, 1, C, H * W]
    
     input_x = input_x.unsqueeze(1)
    
     # [N, 1, H, W]
    
     if in_type == 0:
    
         context_mask = self.conv_mask_s(x)
    
     else:
    
         context_mask = self.conv_mask_t(x)
    
     # [N, 1, H * W]
    
     context_mask = context_mask.view(batch, 1, height * width)
    
     # [N, 1, H * W]
    
     context_mask = F.softmax(context_mask, dim=2)
    
     # [N, 1, H * W, 1]
    
     context_mask = context_mask.unsqueeze(-1)
    
     # [N, 1, C, 1]
    
     context = torch.matmul(input_x, context_mask)
    
     # [N, C, 1, 1]
    
     context = context.view(batch, channel, 1, 1)
    
  
    
     return context
    
  
    
    
    
     def get_rela_loss(self, preds_S, preds_T):
    
     loss_mse = nn.MSELoss(reduction='sum')
    
  
    
     context_s = self.spatial_pool(preds_S, 0)
    
     context_t = self.spatial_pool(preds_T, 1)
    
  
    
     out_s = preds_S
    
     out_t = preds_T
    
  
    
     channel_add_s = self.channel_add_conv_s(context_s)
    
     out_s = out_s + channel_add_s
    
  
    
     channel_add_t = self.channel_add_conv_t(context_t)
    
     out_t = out_t + channel_add_t
    
  
    
     rela_loss = loss_mse(out_s, out_t)/len(out_s)
    
     
    
     return rela_loss
    
  
    
    
    
    
    python
    
    
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-08-18/gZOumyTCdNcD3G8pVtjP2oar6XLM.png)

2.4 最终Loss

alpha=0.001,beta=0.0005

除此之外,利用

L_{at}

通过注意力损失项引导学生模型模仿教师模型的空间和通道注意力Mask。

gamma=0.0005.

最终loss

lambda=0.000005

关于超参

最终效果:

全部评论 (0)

还没有任何评论哟~