大模型知识蒸馏核心技术(4)—— 关系型知识蒸馏
发布时间
阅读量:
阅读量
版权声明
- 本文原创作者:谷哥的小弟
- 作者博客地址:

大模型知识蒸馏的核心技术旨在将教师模型的知识高效迁移至学生模型。其中,样本间关系建模(RKD,Relational Knowledge Distillation)是一种重要的方法,它通过迁移样本间的距离或角度关系,增强学生模型的泛化能力。
RKD 的核心思想
RKD 认为样本之间的关系是一种更高级别的信息,这种关系信息比单个样本的输出信息更有助于学生模型的学习。具体来说,RKD 通过以下两种损失函数来实现样本间关系的迁移:
1. 距离蒸馏损失(Distance-wise Distillation Loss)
距离蒸馏损失用于匹配教师模型和学生模型输出特征之间的距离关系。具体公式如下:

2. 角度蒸馏损失(Angle-wise Distillation Loss)
角度蒸馏损失用于匹配教师模型和学生模型输出特征之间的角度关系。具体公式如下:

RKD 的优势
- 增强泛化能力 :通过迁移样本间的距离和角度关系,学生模型能够学习到更丰富的结构化信息,从而在未见过的数据上表现更好。
- 适应不同维度 :即使教师模型和学生模型的输出维度不同,RKD 仍然可以通过距离和角度关系进行有效的知识迁移。
- 简单高效 :RKD 的损失函数设计简单,易于实现,并且在多个任务上都取得了显著的效果。
RKD 的应用场景
RKD 在多个领域都有广泛的应用,包括但不限于:
- 图像分类 :通过迁移样本间的距离和角度关系,学生模型能够更好地学习到图像特征的结构化信息,从而提高分类准确率。
- 度量学习 :在度量学习任务中,RKD 可以帮助学生模型学习到更有效的特征表示,使得相似样本之间的距离更接近,不相似样本之间的距离更远。
- 少样本学习 :在少样本学习任务中,RKD 可以通过迁移教师模型的结构化知识,帮助学生模型在有限的数据上更好地泛化。
RKD 的实现代码
以下是一个简单的 RKD 损失函数的实现代码(使用 PyTorch):
import torch
import torch.nn as nn
import torch.nn.functional as F
class RKDLoss(nn.Module):
def __init__(self, w_d=25, w_a=50):
super(RKDLoss, self).__init__()
self.w_d = w_d
self.w_a = w_a
def forward(self, f_s, f_t):
student = f_s.view(f_s.shape[0], -1)
teacher = f_t.view(f_t.shape[0], -1)
# RKD distance loss
with torch.no_grad():
t_d = self.pdist(teacher, squared=False)
mean_td = t_d[t_d > 0].mean()
t_d = t_d / mean_td
d = self.pdist(student, squared=False)
mean_d = d[d > 0].mean()
d = d / mean_d
loss_d = F.smooth_l1_loss(d, t_d)
# RKD Angle loss
with torch.no_grad():
td = (teacher.unsqueeze(0) - teacher.unsqueeze(1))
norm_td = F.normalize(td, p=2, dim=2)
t_angle = torch.bmm(norm_td, norm_td.transpose(1, 2)).view(-1)
sd = (student.unsqueeze(0) - student.unsqueeze(1))
norm_sd = F.normalize(sd, p=2, dim=2)
s_angle = torch.bmm(norm_sd, norm_sd.transpose(1, 2)).view(-1)
loss_a = F.smooth_l1_loss(s_angle, t_angle)
loss = self.w_d * loss_d + self.w_a * loss_a
return loss
@staticmethod
def pdist(e, squared=False, eps=1e-12):
e_square = e.pow(2).sum(dim=1)
prod = e @ e.t()
res = (e_square.unsqueeze(1) + e_square.unsqueeze(0) - 2 * prod).clamp(min=eps)
if not squared:
res = res.sqrt()
res = res.clone()
res[range(len(e)), range(len(e))] = 0
return res
python

总结
样本间关系建模(RKD)通过迁移样本间的距离和角度关系,为知识蒸馏提供了一种新的视角。它不仅能够增强学生模型的泛化能力,还能够适应不同维度的教师和学生模型。RKD 在多个任务中都取得了显著的效果,是一种非常有前景的知识蒸馏方法。
全部评论 (0)
还没有任何评论哟~
