模型压缩-之知识蒸馏与迁移学习
模型压缩大体上可以分为 5 种:
- 模型剪枝:即移除对结果作用较小的组件,如减少 head 的数量和去除作用较少的层,共享参数等,ALBERT属于这种;
- 量化:比如将 float32 降到 float8;
- 知识蒸馏:将 teacher 的能力蒸馏到 student上,一般 student 会比 teacher 小。我们可以把一个大而深的网络蒸馏到一个小的网络,也可以把集成的网络蒸馏到一个小的网络上。
- 参数共享:通过共享参数,达到减少网络参数的目的,如 ALBERT 共享了 Transformer 层;
- 参数矩阵近似:通过矩阵的低秩分解或其他方法达到降低矩阵参数的目的;
大模型往往是单个复杂网络或者是若干网络的集合,拥有良好的性能和泛化能力。无法在移动端或嵌入式设备上。因此需要对模型进行压缩,且知识蒸馏是模型压缩中重要的技术之一。
知识蒸馏核心思想 是采取Teacher-Student模式:将复杂且大的模型作为Teacher,Student模型结构较为简单,用Teacher来辅助Student模型的训练,Teacher学习能力强,可以将它学到的知识迁移给学习能力相对弱的Student模型,以此来增强Student模型的泛化能力。复杂笨重但是效果好的Teacher模型不上线,就单纯是个导师角色,真正部署上线进行预测任务的是灵活轻巧的Student小模型。

学生模型怎么选择
1.选择模型结构:
- 简化版本:学生模型通常是教师模型的简化版本。例如,如果教师模型是一个深度卷积神经网络(CNN),学生模型可以是一个浅层的CNN。
- 不同架构:学生模型也可以选择不同于教师模型的架构。比如教师模型是Transformer,学生模型可以是LSTM或者较浅的CNN。
- 架构相似性:学生模型的架构最好与教师模型有一定的相似性,这有助于更好地迁移知识。例如,如果教师模型是卷积神经网络(CNN),学生模型也可以选择类似的CNN结构。
2. 模型类型选择
- 浅层网络 :对于深度较大的教师模型,可以选择较浅的网络作为学生模型。例如,ResNet50作为教师模型时,可以选择ResNet18或MobileNet作为学生模型。
- 专门优化的模型 :一些专门为高效推理设计的模型(如EfficientNet、SqueezeNet、ShuffleNet等)也是不错的选择,它们在保证一定准确率的同时大幅减少了参数量和计算量。
经测试,最好是同结构蒸馏效果比较好,如yolov5s去蒸馏yolov5n。
具体的蒸馏过程
定义教师模型和学生模型 :首先,定义一个参数量较大、学习能力强的教师模型(Teacher Model),和一个参数量较小、学习能力较弱的学生模型(Student Model)。
训练教师模型 :在原始数据集上训练教师模型,使其达到最佳性能。
教师模型生成软标签(Soft Labels)
-
硬标签(Hard Labels) :传统分类任务中,每个样本对应一个具体的类别标签(如0或1)。
-
软标签(Soft Labels) :教师模型输出的概率分布,表示每个类别的置信度。这些概率值通常是经过softmax函数处理后的结果。

-
Hard-target :原始数据集标注的 one-shot 标签,除了正标签为 1,其他负标签都是 0。
-
Soft-target :Teacher模型softmax层输出的类别概率,每个类别都分配了概率,正标签的概率最高。
在教师模型的输出上应用一个温度参数T的softmax函数,以生成软标签(soft targets)。温度参数T的作用是平滑预测概率分布,使得输出概率分布更平缓。较高的T值会使概率分布更加均匀,有助于学生模型更好地学习到教师模型的不确定性信息。
- 抑制过拟合 : 高蒸馏温度下的软目标概率分布更平滑,相比硬目标更容忍学生模型的小误差。这有助于防止学生模型在训练过程中对教师模型的一些噪声或细微差异过度拟合,提高了模型的泛化能力。
- 降低标签噪声的影响 : 在训练数据中存在标签噪声或不确定性时,平滑的软目标可以减少这些噪声的影响。学生模型更倾向于关注教师模型输出的分布,而不是过于依赖单一的硬目标。
- 提高模型鲁棒性 : 平滑的软目标有助于提高模型的鲁棒性,使其对输入数据的小变化更加稳定。这对于在实际应用中面对不同环境和数据分布时的模型性能至关重要。
损失函数 :学生模型的损失函数由两部分组成:一部分是真实标签的交叉熵损失,另一部分是教师模型输出的软标签与学生模型输出的交叉熵损失,两者通过一个权重因子α进行加权。
离线蒸馏
教师模型在学生模型训练之前已经完成训练,并且其参数在整个蒸馏过程中保持不变。这种方法是大部分知识蒸馏算法采用的方法,主要包含三个过程:
- 首先,教师模型在大规模数据集上进行训练,达到理想的性能水平。这个过程通常耗时较长且需要大量计算资源;
- 知识提取:将教师模型的知识提取出来,通常以教师模型对训练数据的输出(如概率分布或特征表示)的形式表示;
- 学生模型的训练:在学生模型的训练过程中,使用教师模型的输出作为指导。学生模型通过一个额外的蒸馏损失函数,学习如何模拟教师模型的输出。常见的蒸馏损失函数包括交叉熵损失和均方误差损失。
该方法主要侧重于知识迁移部分,教师模型通常参数量大,一些庞大复杂模型会通过这种方式得到较小模型,比如 BERT 通过蒸馏学习得到 tinyBERT。 它的主要优点在于能灵活选择预训练好的大型模型作教师,在蒸馏过程中教师模型不需要参数更新,而只需要关注学生模型的学习,这使得训练过程的部署简单可控,大大减少了知识蒸馏的资源消耗和成本,但这种方法的缺点是学生模型非常依赖教师模型。
在线知识蒸馏
教师模型和学生模型在同一训练过程中共同学习。教师模型不再是预先训练好的,而是与学生模型同步更新,教师模型和学生模型相互影响,共同提升性能,相互学习和调整。这种协同学习使得教师模型和学生模型可以动态适应数据变化和任务需求。
在线知识蒸馏能够在没有预训练模型的情况下,针对不同任务实现知识学习和蒸馏,有助于多个模型在学习过程中互相调整和更新学到的知识,实现优势互补。特别是对于多任务学习等特殊场景,具有很大优势。相比于模型压缩,在线学习更适合于知识融合以及多模态、跨领域等场景 。然而训练过程中,增加的模型数量可能会导致计算资源的消耗增加。
自蒸馏
是一种比较特殊的知识蒸馏模式,可以看作是的一种特例,即教师模型和学生模型采用相同的网络模型的在线蒸馏。自蒸馏过程中,学生模型从自身的输出中进行学习,这意味着学生模型将深层的信息传递给浅层,以指导自身的训练过程,而无需依赖外部的教师模型。用学习过程比喻,离线蒸馏是知识渊博的老师向学生传授知识;在线蒸馏是老师和学生一起学习、共同进步;自蒸馏是学生自学成才。
自蒸馏的提出主要是为了解决传统两阶段蒸馏方法的一些问题。传统方法需要预先训练大型教师模型,这会消耗大量的时间和计算资源。而且,教师模型和学生模型之间可能存在能力不匹配的问题,导致学生无法有效地学习教师模型的表征。
自蒸馏方法克服了这些问题,它不需要依赖教师模型进行指导,而是通过学生模型自身的输出来提升性能。这种方法使得学生模型能够在没有外部指导的情况下自我提升,并且可以更加高效地进行模型训练。
例子讲解

1. teacher模型训练,这里不做详细说明
2. teacher模型输出,
logits 输出 ,什么是Logits ,我们知道,对于一般的分类问题,比如图片分类,输入一张图片后,经过DNN网络各种非线性变换,在网络最后Softmax层之前,会得到这张图片属于各个类别的大小数值z i,某个类别的z i 数值越大,则模型认为输入图片属于这个类别的可能性就越大。这些汇总了网络内部各种信息后,得出的属于各个类别的汇总分值 z i ,就是Logits,i代表第i个类别,z i 代表属于第i类的可能性。因为Logits并非概率值,所以一般在Logits数值上会用Softmax函数进行变换,得出的概率值作为最终分类结果概率。Softmax一方面把Logits数值在各类别之间进行概率归一,使得各个类别归属数值满足概率分布;另外一方面,它会放大Logits数值之间的差异,使得Logits得分两极分化,Logits得分高的得到的概率值更偏大一些,而较低的Logits数值,得到的概率值则更小。
在高温 T 下生成 soft target 。softmax 函数计算方法如下:

其中 q i 表示不同类别的预测概率。这个预测结果是 soft target,而真实目标是 hard target,一般机器学习的目标就是让 soft target 逼近 hard target。
引入“蒸馏”的概念,在上式基础上添加一个温度系数 T,如果想让soft target更加平缓,高的降低,低的升高。 这时就要对soft target使用蒸馏温度。 让soft target更soft。 :

当 T = 1 时,就是标准的 softmax 函数。T 越大,得到的概率分布的熵越大,负标签携带的信息会被放大,负标签的概率分布会对损失函数有更明显的影响,模型训练会更关注这部分信息。

3. 学生训练


上面是已经训练好的教师网络。 把数据输入到教师网络,在输出时使用蒸馏温度为T的softmax。再把数据输入到学生网络,学生网络可能是还没有训练的网络,也可能是训练一半的半成品网络。
关于数据集的使用
通常情况下,学生模型会使用与教师模型相同的训练数据集,这是因为教师模型已经在该数据集上进行了充分的训练,能够提供高质量的软标签(Soft Labels),这些软标签包含了比硬标签(Hard Labels)更多的信息,有助于学生模型更好地学习。
在某些情况下,学生模型也可以使用与教师模型不同的数据集进行训练。这通常发生在以下几种场景:
数据集差异 :如果学生模型需要适应不同的数据分布或领域,可以使用不同的数据集进行训练,同时利用教师模型提供的知识(软标签)来辅助学习。
数据增强 :通过数据增强技术生成新的训练样本,学生模型可以在这些增强样本上进行训练,同时利用教师模型在原始数据集上的知识。
迁移学习 :在迁移学习中,学生模型可能需要在不同的任务或领域上进行训练,这时可以使用与教师模型不同的数据集,但仍然利用教师模型的知识来提高学习效率。
在某些高级应用中,学生模型可能会结合使用与教师模型相同的数据集和不同的数据集。例如,学生模型可以在与教师模型相同的数据集上进行初步训练,然后在不同的数据集上进行微调,以适应特定的任务或领域。
损失函数
学生网络既要在蒸馏温度等于T时与教师网络的结果相接近。也要保证不使用蒸馏温度时的结果与真实结果相接近。
- 蒸馏损失 : 把教师网络使用蒸馏温度为t的输出结果 与 学生网络蒸馏温度为t的输出结果做损失。 让这个损失越小越好。
- 学生损失 : 学生网络蒸馏温度为1(即不使用蒸馏网络)时的预测结果和真实的标签做loss。
最后对这两项加权求和。
import torch.nn.functional as F
def knowledge_distillation_loss(student_outputs, teacher_outputs, labels, temperature, alpha):
"""
计算知识蒸馏损失
:param student_outputs: 学生模型的输出
:param teacher_outputs: 教师模型的输出
:param labels: 真实标签
:param temperature: 温度参数
:param alpha: 平衡因子
:return: 总损失
"""
# 计算学生模型的硬标签损失
hard_labels_loss = F.cross_entropy(student_outputs, labels)
# 计算软标签损失
soft_teacher_outputs = F.softmax(teacher_outputs / temperature, dim=1)
soft_student_outputs = F.log_softmax(student_outputs / temperature, dim=1)
soft_labels_loss = F.kl_div(soft_student_outputs, soft_teacher_outputs, reduction='batchmean')
# 总损失
total_loss = alpha * hard_labels_loss + (1 - alpha) * temperature**2 * soft_labels_loss
return total_loss
python

学生模型训练实现:
# 定义超参数
temperature = 4.0
alpha = 0.5
num_epochs = 10
batch_size = 64
learning_rate = 0.001
# 数据加载器
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# 优化器
optimizer = optim.Adam(student_model.parameters(), lr=learning_rate)
# 训练循环
for epoch in range(num_epochs):
student_model.train()
running_loss = 0.0
for images, labels in train_loader:
optimizer.zero_grad()
# 获取教师模型的输出
with torch.no_grad():
teacher_outputs = teacher_model(images)
# 获取学生模型的输出
student_outputs = student_model(images)
# 计算损失
loss = knowledge_distillation_loss(student_outputs, teacher_outputs, labels, temperature, alpha)
# 反向传播与优化
loss.backward()
optimizer.step()
running_loss += loss.item()
print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')
python

评估学生模型代码实现:
在验证集上评估学生模型的性能,确保其继承了教师模型的主要特征。
# 测试数据加载器
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
def evaluate(model, dataloader):
model.eval()
correct = 0
total = 0
with torch.no_grad():
for images, labels in dataloader:
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
accuracy = 100 * correct / total
return accuracy
# 评估学生模型
student_accuracy = evaluate(student_model, test_loader)
print(f'Student Model Accuracy: {student_accuracy:.2f}%')
python

微调
如果需要进一步提升性能,可以在少量真实标签数据上对学生模型进行微调。
# 假设我们有一些微调数据
fine_tune_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
fine_tune_loader = DataLoader(fine_tune_dataset, batch_size=batch_size, shuffle=True)
# 微调循环
for epoch in range(num_epochs):
student_model.train()
running_loss = 0.0
for images, labels in fine_tune_loader:
optimizer.zero_grad()
# 获取教师模型的输出
with torch.no_grad():
teacher_outputs = teacher_model(images)
# 获取学生模型的输出
student_outputs = student_model(images)
# 计算损失
loss = knowledge_distillation_loss(student_outputs, teacher_outputs, labels, temperature, alpha)
# 反向传播与优化
loss.backward()
optimizer.step()
running_loss += loss.item()
print(f'Fine-tune Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(fine_tune_loader):.4f}')
# 再次评估学生模型
student_accuracy = evaluate(student_model, test_loader)
print(f'Fine-tuned Student Model Accuracy: {student_accuracy:.2f}%')
python

