Advertisement

Featured Based知识蒸馏及代码(3): Focal and Global Knowledge (FGD)

阅读量:

文章目录

    • 摘要
  • 蒸馏过程中的核心原理主要包含基于焦点与全局策略的两种主要方法。

  • 基于特征的传统蒸馏方法是通过将 teacher 的输出与 student 的特征进行匹配来实现知识转移。

  • 基于焦点蒸馏方法则侧重于在 teacher 输出空间中提取关键特征进行学习。

  • 基于全局蒸馏策略则考虑了 teacher 输出的整体分布特性。

  • 在蒸馏过程中采用的整体损失函数能够有效平衡各层次信息的学习需求。

    • 3. 实验
    • 完整代码
在这里插入图片描述
在这里插入图片描述

论文: https://arxiv.org/pdf/2111.11837.pdf

本文阐述了我们于CVPR2022会议上发布的目标检测知识蒸馏技术:焦点型与全局型知识蒸馏方法及其在单阶段与双阶段检测器中的应用。该方法仅需30行代码即可实现对anchor-base与anchor-free检测器的稳定性能提升。

1. 摘要

知识蒸馏在图像分类领域已展现出显著成效。然而,在目标检测领域中存在诸多挑战。我们发现,在目标检测任务中教师网络与学生网络在不同区域的特征差异显著。为了平衡这种差异带来的负面影响, 我们提出了一种新的蒸馏方法,即Focal and Global Distillation (FGD)策略,该方法通过将图像划分为前景区域和背景区域来实现对教师网络的关注度提升。具体而言,Focal 蒸馏将重点放在前景区域的关键像素上,而Global蒸馏则负责重建各像素间的空间关系并实现知识迁移,从而弥补Focal蒸馏在全局信息方面的不足。值得注意的是,我们提出的方法基于特征导向策略,因此能够广泛应用于各种目标检测模型中。我们在RetinaNet、Faster RCNN以及RepPoints等主流目标检测器上进行了大量实验验证,结果表明采用FGD策略后学生网络在mAP指标上均实现了明显提升(分别提升了约3.3%、3.6%及3.4个百分点)。

|

|

|
|---|---|

知识蒸馏技术能够展示其迁移能力,并将其复杂的教师网络的知识转移到紧凑轻量的学生网络中,并且无需额外推理成本。众所周知,在目标检测性能方面,前景与背景类别不平衡可能会影响蒸馏效果。

  • 如图1所示,在图1中展示的是特征对比可视化结果。从图中可以看出,在学习前景的关注焦点上两者存在显著差异,在学习背景的关注焦点上差异相对较小。这种差异性导致了学习前景与学习背景所面临的困难程度不同。
    • 为了深入探究前景与背景对知识蒸馏的影响机制,我们将前背景样本单独进行蒸馏实验分析。实验结果表明,未对前后背景进行区分处理时,蒸馏效果最差,其性能明显低于仅基于前景或仅基于背景进行蒸馏的效果;同时从图1可以看出,在像素级和通道级上均存在显著差异区域,这种不均衡性会导致蒸馏效果下降。针对这一问题,我们提出了一种新的蒸馏方法叫做focal蒸馏,即学生网络需重点关注教师网络的关键像素区域和通道连接。
    • 然而,仅有关键信息不足以满足需求,全局上下文信息在目标检测任务中同样扮演着重要角色。为弥补Focal蒸馏方法所带来的全局信息丢失问题,本文又提出了一种名为Global蒸馏的新方法。该方法通过GcBlock模块提取样本间关系并将其迁移到学生网络中。

主要创新点

  • (1) 研究发现教师和学生网络在聚焦于像素和通道上存在显著区别。若采用无差别蒸馏策略对所有像素及channels进行训练,则难以获得显著提升效果。
    • (2) 开发出了focal和global蒸馏方法,这些方法使学生不仅能够聚焦于教师网络的关键像素和通道及其间的关联性。
    • (3) 通过系统性实验分析验证了本文提出的蒸馏方法具有良好的适用性,在包含1996年的目标检测任务中均能达成这一目标。

2. Focal and Global 蒸馏的原理

在这里插入图片描述

图 2 FGD框架

FGD包含两种蒸馏方式:焦点蒸馏和全局蒸馏。焦点蒸馏不仅能够将前景区与背景区区分开来,并且能够聚焦于教师网络的特征图上提取的核心特征。全局蒸馏则弥补了教师网络与学生网络在全局语境信息上的差异

2.1 常规的feature based蒸馏算法

通常采用多尺度特征融合技术的目标检测器普遍运用了FPN架构,在此架构下,FPN通过将不同分辨率的空间特征进行深度学习融合,其在知识提取与迁移方面展现出显著优势,从而使得基础学习模型的知识能够被高效地迁移至目标检测任务中,并在此过程中实现了性能表现的有效提升.在现有研究中,基于特征提取的知识蒸馏方法通常可表示为:

在这里插入图片描述

其中F^TF^S分别代表教师网络的特征和学生网络的特征,在f的作用下(或完成),通过调整层使其形状与教师网络的feature map保持一致。

然而这种蒸馏方法未对背景区域和前景区别的关注进行重视,并未能充分捕捉不同像素之间的全局关联信息。

2.2 Focal Distillation

(1) 分离前景和背景

因为样本在前后景与后景之间存在不平衡的问题,在本文中提出了一种名为focal蒸馏的方法来进行区分,并指导学生网络更加注重关键的pixels以及channels之间的关系;随后采用一种基于二值Mask的技术来进行具体分割。

在这里插入图片描述

我们用符号r来表示'ground truth'区域,并将(i,j)视为特征图上的坐标点。当该位置位于'GT'区域内时,M_{i,j}设值为1;否则设值为0。

因为图片存在尺寸差异的原因是由于图像分辨率的不同而导致的。然而,在实际应用中发现这一问题可能导致计算损失显著增加,并且各幅图像之间"前景与背景"的比例也不尽相同。基于此考虑,在模型设计阶段我们引入了调节参数来平衡各方面的关系,并最终确定一个适合整体表现的最佳'Mask'比例值S作为调节因子。

在这里插入图片描述

当一个像素被包含在多个ground truth中时,在计算S的过程中,默认采用最小区域作为参考。

复制代码
    area = 1.0/(hmax[i].view(1,-1)+1-hmin[i].view(1,-1))/(wmax[i].view(1,-1)+1-wmin[i].view(1,-1))
    
    for j in range(len(gt_bboxes[i])):
    Mask_fg[i][hmin[i][j]:hmax[i][j]+1, wmin[i][j]:wmax[i][j]+1] = \
       torch.maximum(Mask_fg[i][hmin[i][j]:hmax[i][j]+1, wmin[i][j]:wmax[i][j]+1], area[0][j])
    
    Mask_bg[i] = torch.where(Mask_fg[i]>0, 0, 1)
    if torch.sum(Mask_bg[i]):
    Mask_bg[i] /= torch.sum(Mask_bg[i])

(2) 提取关键的pixes和channels

通常会采用基于自适应感知器(SENet)与通道注意力机制(CBAM)相结合的方式,在卷积神经网络模型中施加空间注意力机制并结合通道注意力机制。从而更加注重那些关键像素(pixels)及其关联的通道(channels)。在此基础上,我们借鉴这一思路,在模型架构设计上进行优化与改进。通过这种方式能够有效聚焦于那些需要重点关注的关键像素及其相关通道。

在这里插入图片描述
  • G^SG^C分别表示空间和通道的注意力map

然后将空间注意力map与通道注意力map结合,并采用softmax方法生成概率分布表,并采用参数T来扩展其影响范围。

在这里插入图片描述
复制代码
     def get_attention(self, preds, temp):
        """ preds: Bs*C*W*H """
        N, C, H, W= preds.shape
    
        value = torch.abs(preds)
        # Bs*W*H
        fea_map = value.mean(axis=1, keepdim=True)
        S_attention = (H * W * F.softmax((fea_map/temp).view(N,-1), dim=1)).view(N, H, W)
    
        # Bs*C
        channel_map = value.mean(axis=2,keepdim=False).mean(axis=2,keepdim=False)
        C_attention = C * F.softmax(channel_map/temp, dim=1)
    
        return S_attention, C_attention

首先计算S_attention(空间注意力):

  • Spatial-based Attention(空间注意力)是一种通过对形状为(B,C,H,W)的特征图在通道维度进行归一化处理后得到形状为(B,H,W)的空间注意力映射的方法。
  • 接着将其转化为软化的概率分布:通过view操作将形状从(B,H,W)转置至(B,H\times W)后应用softmax函数进行标准化处理以生成软化的概率分布结果,并使用温度参数T来调节最终的输出。
  • 最终通过对结果再次转置并结合相关计算步骤成功生成S_{attention}映射。

C_attention(通道注意力):

  • Channels Attention(通道注意力)其主要作用是将输入特征图(feature map)中形状为(Bs,C,H,W)的数据,在高度(H)和宽度(W)两个维度上进行归一化处理,并最终生成形状为(Bs,C)的注意力权重矩阵。
    随后将其转化为更加平滑的概率分布:通过应用softmax函数对(Bs,C)维度的数据进行归一化处理,并引入温度参数T来调节锐度的变化程度。从而生成更加平滑的分类概率分布C_attention

(3) 计算focal蒸馏损失

在这里插入图片描述

其中,A^C\alpha\beta\beta\beta\beta\beta\beta\beta\beta\beta\beta\alpha}分别代表教师检测器在空间维度上的关注机制与通道级别的注意力机制,而\alpha,β则共同调节着平衡前景与背景损失之间的关系.

复制代码
    def get_fea_loss(self, preds_S, preds_T, Mask_fg, Mask_bg, C_s, C_t, S_s, S_t):
        loss_mse = nn.MSELoss(reduction='sum')
        
        Mask_fg = Mask_fg.unsqueeze(dim=1)
        Mask_bg = Mask_bg.unsqueeze(dim=1)
    
        C_t = C_t.unsqueeze(dim=-1)
        C_t = C_t.unsqueeze(dim=-1)
    
        S_t = S_t.unsqueeze(dim=1)
    
        fea_t= torch.mul(preds_T, torch.sqrt(S_t))
        fea_t = torch.mul(fea_t, torch.sqrt(C_t))
        fg_fea_t = torch.mul(fea_t, torch.sqrt(Mask_fg))
        bg_fea_t = torch.mul(fea_t, torch.sqrt(Mask_bg))
    
        fea_s = torch.mul(preds_S, torch.sqrt(S_t))
        fea_s = torch.mul(fea_s, torch.sqrt(C_t))
        fg_fea_s = torch.mul(fea_s, torch.sqrt(Mask_fg))
        bg_fea_s = torch.mul(fea_s, torch.sqrt(Mask_bg))
    
        fg_loss = loss_mse(fg_fea_s, fg_fea_t)/len(Mask_fg)
        bg_loss = loss_mse(bg_fea_s, bg_fea_t)/len(Mask_bg)
    
        return fg_loss, bg_loss
  • 首先定义feature的蒸馏损失为L2损失
复制代码
    loss_mse = nn.MSELoss(reduction='sum')

随后将通道注意力模块中t时刻特征图Ct从形状大小(Bs,C)扩展至(Bs,C,1,1),同时将空间注意力模块中t时刻特征图St从形状大小(Bs,H,W)扩展至(Bs,1,H,W)。这样便于它们能够与目标预测分支输出特征图pred形状大小一致(Bs,C,H,W)进行点积运算。
接着教师分支输出特征图pred_T分别与经过开平方处理的空间注意力St以及通道注意Ct进行矩阵乘法操作从而得到经过空间注意和通道注意双重作用后的feature_map。
随后通过上述注意机制后得到的feature_map分别作用于前景掩码Mask_fg和背景掩码Mask_bg从而成功分离出前景分支特征图fg_fea_t 和背景分支特征图bg_fea_t 。
同理学生分支预测输出特征图preds_S在经历了通道注意St 和空间注意 Ct 的双重作用后生成了feature_map同样地利用Mask_fg和Mask_bg对该feature_map施加作用最终分离出student分支下的前景特征fg_fea_s 和背景部分bg_feas 。
最后计算前景分支蒸馏损失L2 即 pred_featT 与 pred_featS 的差异损失同时计算背景分支蒸馏损失L2 即 bg_featT 与 bg_featS 的差异损失

同时,在模型训练过程中,我们评估了注意力损失量L_{at}, 从而确保了mimic学生在空间维度和通道维度的空间与通道注意力映射上与教师模型的高度一致性。

在这里插入图片描述

其中ts分别代表teacher和student, l表示L1损失函数, r是平衡参数。

复制代码
    def get_mask_loss(self, C_s, C_t, S_s, S_t):
    
    mask_loss = torch.sum(torch.abs((C_s-C_t)))/len(C_s) + torch.sum(torch.abs((S_s-S_t)))/len(S_s)
    
       	return mask_loss

最终的focal蒸馏损失,是由特征损失L_{focal}和注意力损失L_{at}组成

在这里插入图片描述

2.3 Global Distillation

像像素之间所蕴含的信息具有重要的价值一样

像pixel-wise information such as pixel-wise information such as pixel-wise information such as pixel-wise information such as pixel-wise information such as pixel-wise information such as pixel-wise information

在这里插入图片描述
  • 通过GcBlock提取单张图片中的全局上下文关联信息,并让student从teacher中模仿这种关联信息。
  • 该模块源自于学生与教师在神经元层面上的特征映射提取。
  • 其损失函数l_global定义如下:
在这里插入图片描述

其中W_k, W_{u1}, W_{u2}代表卷积层的各种权重参数,并通过非线性激活函数进行激活计算。
其中LN技术代表其在神经网络中采用的标准化处理方法。
具体来说,N_p指的是在特征图中每个像素点的数量,并通过统计特征图的空间维度来计算。
\lambda则用于调节平衡损失函数的关键参数。

复制代码
    		self.conv_mask_s = nn.Conv2d(teacher_channels, 1, kernel_size=1)
        self.conv_mask_t = nn.Conv2d(teacher_channels, 1, kernel_size=1)
    def spatial_pool(self, x, in_type):
        batch, channel, width, height = x.size()
        input_x = x
        # [N, C, H * W]
        input_x = input_x.view(batch, channel, height * width)
        # [N, 1, C, H * W]
        input_x = input_x.unsqueeze(1)
        # [N, 1, H, W]
        if in_type == 0:
            context_mask = self.conv_mask_s(x)
        else:
            context_mask = self.conv_mask_t(x)
        # [N, 1, H * W]
        context_mask = context_mask.view(batch, 1, height * width)
        # [N, 1, H * W]
        context_mask = F.softmax(context_mask, dim=2)
        # [N, 1, H * W, 1]
        context_mask = context_mask.unsqueeze(-1)
        # [N, 1, C, 1]
        context = torch.matmul(input_x, context_mask)
        # [N, C, 1, 1]
        context = context.view(batch, channel, 1, 1)
    
        return context
    
    
    def get_rela_loss(self, preds_S, preds_T):
        loss_mse = nn.MSELoss(reduction='sum')
    
        context_s = self.spatial_pool(preds_S, 0)
        context_t = self.spatial_pool(preds_T, 1)
    
        out_s = preds_S
        out_t = preds_T
    
        channel_add_s = self.channel_add_conv_s(context_s)
        out_s = out_s + channel_add_s
    
        channel_add_t = self.channel_add_conv_t(context_t)
        out_t = out_t + channel_add_t
    
        rela_loss = loss_mse(out_s, out_t)/len(out_s)
        
        return rela_loss

通过get_rela_loss计算Global 蒸馏损失

利用self.spatial_pool也就是上图所示的context_module模块,在通过一次1x1卷积将输入特征图从teacher_channels通道数压缩至单通道后(对应形状由(N,C,H,W)转换为(N,1,H,W)),随后利用view操作将形状转换为(N,1,H*W);接着通过应用softmax函数将其转化为概率分布矩阵,并与原始输入特征图进行矩阵乘法运算得到形状为(N,C,1,1)的学生特征图;这就是该模块的功能实现过程,并对应于spatial_pool函数。
接着经过spatial_pool处理后获得带全局上下文信息的学生特征图s_, 然后通过本模块中的自适应通道加法层(即self.channel_add_conv_s)对其进行处理:具体包括两次1x1卷积操作、一个LayerNorm层以及一个Relu激活函数;随后将计算结果与原始输入特征图进行深度wise加法运算得到输出t_
类似地也可以计算出学生网络中的输出特征s_ ; 最后分别计算t_ s_ 之间的L2蒸馏损失即可得到全局损失值。

复制代码
    self.channel_add_conv_s = nn.Sequential(
            nn.Conv2d(teacher_channels, teacher_channels//2, kernel_size=1),
            nn.LayerNorm([teacher_channels//2, 1, 1]),
            nn.ReLU(inplace=True),  # yapf: disable
            nn.Conv2d(teacher_channels//2, teacher_channels, kernel_size=1))
        self.channel_add_conv_t = nn.Sequential(
            nn.Conv2d(teacher_channels, teacher_channels//2, kernel_size=1),
            nn.LayerNorm([teacher_channels//2, 1, 1]),
            nn.ReLU(inplace=True),  # yapf: disable
            nn.Conv2d(teacher_channels//2, teacher_channels, kernel_size=1))

2.4 total loss

在这里插入图片描述

原始损失 L_{original} 由检测器进行计算得出;通过提取特征图中的关键信息进行建模训练得到蒸馏损失;这一技术进一步支持了其他类型检测器在实际应用中的可行性与高效性。

复制代码
    fg_loss, bg_loss = self.get_fea_loss(preds_S, preds_T, Mask_fg, Mask_bg, 
                   C_attention_s, C_attention_t, S_attention_s, S_attention_t)
    mask_loss = self.get_mask_loss(C_attention_s, C_attention_t, S_attention_s, S_attention_t)
    rela_loss = self.get_rela_loss(preds_S, preds_T)
    
    
    loss = self.alpha_fgd * fg_loss + self.beta_fgd * bg_loss \
       + self.gamma_fgd * mask_loss + self.lambda_fgd * rela_loss

完整的蒸馏损失代码

  • fgd.py
复制代码
    import torch.nn as nn
    import torch.nn.functional as F
    import torch
    from mmcv.cnn import constant_init, kaiming_init
    from ..builder import DISTILL_LOSSES
    
    @DISTILL_LOSSES.register_module()
    class FeatureLoss(nn.Module):
    
    """PyTorch version of `Focal and Global Knowledge Distillation for Detectors`
       
    Args:
        student_channels(int): Number of channels in the student's feature map.
        teacher_channels(int): Number of channels in the teacher's feature map. 
        temp (float, optional): Temperature coefficient. Defaults to 0.5.
        name (str): the loss name of the layer
        alpha_fgd (float, optional): Weight of fg_loss. Defaults to 0.001
        beta_fgd (float, optional): Weight of bg_loss. Defaults to 0.0005
        gamma_fgd (float, optional): Weight of mask_loss. Defaults to 0.001
        lambda_fgd (float, optional): Weight of relation_loss. Defaults to 0.000005
    """
    def __init__(self,
                 student_channels,
                 teacher_channels,
                 name,
                 temp=0.5,
                 alpha_fgd=0.001,
                 beta_fgd=0.0005,
                 gamma_fgd=0.001,
                 lambda_fgd=0.000005,
                 ):
        super(FeatureLoss, self).__init__()
        self.temp = temp
        self.alpha_fgd = alpha_fgd
        self.beta_fgd = beta_fgd
        self.gamma_fgd = gamma_fgd
        self.lambda_fgd = lambda_fgd
    
        if student_channels != teacher_channels:
            self.align = nn.Conv2d(student_channels, teacher_channels, kernel_size=1, stride=1, padding=0)
        else:
            self.align = None
        
        self.conv_mask_s = nn.Conv2d(teacher_channels, 1, kernel_size=1)
        self.conv_mask_t = nn.Conv2d(teacher_channels, 1, kernel_size=1)
        self.channel_add_conv_s = nn.Sequential(
            nn.Conv2d(teacher_channels, teacher_channels//2, kernel_size=1),
            nn.LayerNorm([teacher_channels//2, 1, 1]),
            nn.ReLU(inplace=True),  # yapf: disable
            nn.Conv2d(teacher_channels//2, teacher_channels, kernel_size=1))
        self.channel_add_conv_t = nn.Sequential(
            nn.Conv2d(teacher_channels, teacher_channels//2, kernel_size=1),
            nn.LayerNorm([teacher_channels//2, 1, 1]),
            nn.ReLU(inplace=True),  # yapf: disable
            nn.Conv2d(teacher_channels//2, teacher_channels, kernel_size=1))
    
        self.reset_parameters()
    
    
    def forward(self,
                preds_S,
                preds_T,
                gt_bboxes,
                img_metas):
        """Forward function.
        Args:
            preds_S(Tensor): Bs*C*H*W, student's feature map
            preds_T(Tensor): Bs*C*H*W, teacher's feature map
            gt_bboxes(tuple): Bs*[nt*4], pixel decimal: (tl_x, tl_y, br_x, br_y)
            img_metas (list[dict]): Meta information of each image, e.g.,
            image size, scaling factor, etc.
        """
        assert preds_S.shape[-2:] == preds_T.shape[-2:],'the output dim of teacher and student differ'
    
        if self.align is not None:
            preds_S = self.align(preds_S)
        
        N,C,H,W = preds_S.shape
    
        S_attention_t, C_attention_t = self.get_attention(preds_T, self.temp)
        S_attention_s, C_attention_s = self.get_attention(preds_S, self.temp)
    
        Mask_fg = torch.zeros_like(S_attention_t)
        Mask_bg = torch.ones_like(S_attention_t)
        wmin,wmax,hmin,hmax = [],[],[],[]
        for i in range(N):
            new_boxxes = torch.ones_like(gt_bboxes[i])
            new_boxxes[:, 0] = gt_bboxes[i][:, 0]/img_metas[i]['img_shape'][1]*W
            new_boxxes[:, 2] = gt_bboxes[i][:, 2]/img_metas[i]['img_shape'][1]*W
            new_boxxes[:, 1] = gt_bboxes[i][:, 1]/img_metas[i]['img_shape'][0]*H
            new_boxxes[:, 3] = gt_bboxes[i][:, 3]/img_metas[i]['img_shape'][0]*H
    
            wmin.append(torch.floor(new_boxxes[:, 0]).int())
            wmax.append(torch.ceil(new_boxxes[:, 2]).int())
            hmin.append(torch.floor(new_boxxes[:, 1]).int())
            hmax.append(torch.ceil(new_boxxes[:, 3]).int())
    
            area = 1.0/(hmax[i].view(1,-1)+1-hmin[i].view(1,-1))/(wmax[i].view(1,-1)+1-wmin[i].view(1,-1))
    
            for j in range(len(gt_bboxes[i])):
                Mask_fg[i][hmin[i][j]:hmax[i][j]+1, wmin[i][j]:wmax[i][j]+1] = \
                        torch.maximum(Mask_fg[i][hmin[i][j]:hmax[i][j]+1, wmin[i][j]:wmax[i][j]+1], area[0][j])
    
            Mask_bg[i] = torch.where(Mask_fg[i]>0, 0, 1)
            if torch.sum(Mask_bg[i]):
                Mask_bg[i] /= torch.sum(Mask_bg[i])
    
        fg_loss, bg_loss = self.get_fea_loss(preds_S, preds_T, Mask_fg, Mask_bg, 
                           C_attention_s, C_attention_t, S_attention_s, S_attention_t)
        mask_loss = self.get_mask_loss(C_attention_s, C_attention_t, S_attention_s, S_attention_t)
        rela_loss = self.get_rela_loss(preds_S, preds_T)
    
    
        loss = self.alpha_fgd * fg_loss + self.beta_fgd * bg_loss \
               + self.gamma_fgd * mask_loss + self.lambda_fgd * rela_loss
            
        return loss
    
    
    def get_attention(self, preds, temp):
        """ preds: Bs*C*W*H """
        N, C, H, W= preds.shape
    
        value = torch.abs(preds)
        # Bs*W*H
        fea_map = value.mean(axis=1, keepdim=True)
        S_attention = (H * W * F.softmax((fea_map/temp).view(N,-1), dim=1)).view(N, H, W)
    
        # Bs*C
        channel_map = value.mean(axis=2,keepdim=False).mean(axis=2,keepdim=False)
        C_attention = C * F.softmax(channel_map/temp, dim=1)
    
        return S_attention, C_attention
    
    
    def get_fea_loss(self, preds_S, preds_T, Mask_fg, Mask_bg, C_s, C_t, S_s, S_t):
        loss_mse = nn.MSELoss(reduction='sum')
        
        Mask_fg = Mask_fg.unsqueeze(dim=1)
        Mask_bg = Mask_bg.unsqueeze(dim=1)
    
        C_t = C_t.unsqueeze(dim=-1)
        C_t = C_t.unsqueeze(dim=-1)
    
        S_t = S_t.unsqueeze(dim=1)
    
        fea_t= torch.mul(preds_T, torch.sqrt(S_t))
        fea_t = torch.mul(fea_t, torch.sqrt(C_t))
        fg_fea_t = torch.mul(fea_t, torch.sqrt(Mask_fg))
        bg_fea_t = torch.mul(fea_t, torch.sqrt(Mask_bg))
    
        fea_s = torch.mul(preds_S, torch.sqrt(S_s))
        fea_s = torch.mul(fea_s, torch.sqrt(C_s))
        fg_fea_s = torch.mul(fea_s, torch.sqrt(Mask_fg))
        bg_fea_s = torch.mul(fea_s, torch.sqrt(Mask_bg))
    
        fg_loss = loss_mse(fg_fea_s, fg_fea_t)/len(Mask_fg)
        bg_loss = loss_mse(bg_fea_s, bg_fea_t)/len(Mask_bg)
    
        return fg_loss, bg_loss
    
    
    def get_mask_loss(self, C_s, C_t, S_s, S_t):
    
        mask_loss = torch.sum(torch.abs((C_s-C_t)))/len(C_s) + torch.sum(torch.abs((S_s-S_t)))/len(S_s)
    
        return mask_loss
     
    
    def spatial_pool(self, x, in_type):
        batch, channel, width, height = x.size()
        input_x = x
        # [N, C, H * W]
        input_x = input_x.view(batch, channel, height * width)
        # [N, 1, C, H * W]
        input_x = input_x.unsqueeze(1)
        # [N, 1, H, W]
        if in_type == 0:
            context_mask = self.conv_mask_s(x)
        else:
            context_mask = self.conv_mask_t(x)
        # [N, 1, H * W]
        context_mask = context_mask.view(batch, 1, height * width)
        # [N, 1, H * W]
        context_mask = F.softmax(context_mask, dim=2)
        # [N, 1, H * W, 1]
        context_mask = context_mask.unsqueeze(-1)
        # [N, 1, C, 1]
        context = torch.matmul(input_x, context_mask)
        # [N, C, 1, 1]
        context = context.view(batch, channel, 1, 1)
    
        return context
    
    
    def get_rela_loss(self, preds_S, preds_T):
        loss_mse = nn.MSELoss(reduction='sum')
    
        context_s = self.spatial_pool(preds_S, 0)
        context_t = self.spatial_pool(preds_T, 1)
    
        out_s = preds_S
        out_t = preds_T
    
        channel_add_s = self.channel_add_conv_s(context_s)
        out_s = out_s + channel_add_s
    
        channel_add_t = self.channel_add_conv_t(context_t)
        out_t = out_t + channel_add_t
    
        rela_loss = loss_mse(out_s, out_t)/len(out_s)
        
        return rela_loss
    
    
    def last_zero_init(self, m):
        if isinstance(m, nn.Sequential):
            constant_init(m[-1], val=0)
        else:
            constant_init(m, val=0)
    
    
    def reset_parameters(self):
        kaiming_init(self.conv_mask_s, mode='fan_in')
        kaiming_init(self.conv_mask_t, mode='fan_in')
        self.conv_mask_s.inited = True
        self.conv_mask_t.inited = True
    
        self.last_zero_init(self.channel_add_conv_s)
        self.last_zero_init(self.channel_add_conv_t)

3. 实验

FGD蒸馏实验通过调节α、β、γ、λ参数来协调前景与背景之间的损失(见公式9-11),其中T=0.5用于优化注意力分配。在所有two-stage检测器中,grass参数被设定为α=5×10⁻⁵, β=2.5×10⁻⁵, γ=5×10⁻⁵, λ=5×10⁻⁷;而anchor base one-stage检测器采用了不同的超参数配置:α=1×10⁻³, β=5×10⁻⁴, γ=1×10⁻³, λ=5×10⁻⁶;最后 anchor free one-stage检测器则采用了更为激进的配置:α=1.6×10⁻³, β=8×10⁻⁴, γ=8×10⁻³, λ=8×10⁻⁶。

在这里插入图片描述

我们在基于锚框的方法与完全锚框方法的单靶点检测系统与双靶点检测系统上实施了对比实验,在COCO2017评估框架下实现学生目标检测系统的显著AP和AR提升。

我们采用了具有更强大的检测模块对学生进行蒸馏,在实验中发现当以更强的模型作为教师执行蒸馏操作时能够实现更好的性能提升效果。具体而言,在ResNet-101与ResNeXt-101两位老师的蒸馏过程中中训练得到RetinaNet-R50分别达到了39.7与40.7的mAP值。

采用基于FGD的知识蒸馏方法训练后生成的学生模型,在进行一次完整的推理过程后仍需进一步优化以提高其准确性与稳定性。随后我们对模型中的注意力机制进行了详细分析。观察发现,在空间注意力分布方面与教师模型高度一致,在通道注意力分布方面则呈现一定的差异性特征。这种差异性可能源于数据集划分的影响因素综合作用所致。观察发现,在空间注意力分布方面与教师模型高度一致,在通道注意力分布方面则呈现一定的差异性特征。这种差异性可能源于数据集划分的影响因素综合作用所致。这表明学生模型成功模仿了教师的行为模式,并提取了更具代表性的特征从而显著提升了模型在相关任务上的性能指标

在这里插入图片描述

完整代码

链接:https://pan.baidu.com/s/1V7H3GjeEBxqMK0dGx90KDw?pwd=4vgn
提取码:4vgn

github: https://github.com/yzd-v/FGD

全部评论 (0)

还没有任何评论哟~