【提示学习论文】KDPL:Improving Zero-shot Generalization of Learned Prompts via Unsupervised Knowledge Distil
发布时间
阅读量:
阅读量
Improving Zero-shot Generalization of Learned Prompts via Unsupervised Knowledge Distillation(ECCV 2024)
- 利用无监督知识蒸馏改进学习提示的zero-shot泛化
- image+text
- 佛罗伦萨大学、比萨大学
- 代码:https://github.com/miccunifi/KDPL
1 KDPL

教师模型
- 冻结的text encoder:输入 a photo of a class,得到文本特征ψi,T,B
- 冻结的image encoder:输入图像,得到图像特征ψI,B
- 计算教师概率pT

学生模型
- text encoder:输入class+learnable prompts,得到ψi,T,S
- image encoder:输入图像和prompts,得到ψI,S
- 计算学生概率pS

知识蒸馏
将教师模型概率pT与学生模型的概率pS进行对比,通过对称KL散度损失函数进行知识蒸馏,更新学生模型的提示γ
2 标签不可知的提示学习
1 KDPL overview
可以在没有类别名称或标签信息的情况下,与任意现有的提示学习方法无缝集成。
- 标签不可知:不使用真实标签,但假设知道训练数据集中的类别名称
- 类别不可知:更进一步,假设训练类别名称也是不可知的。此时从包含大约20k个类别名称的大词典中自动筛选类别。
2 KDPL
- 训练过程中不使用图像的真实标签,但我们知道训练数据集中存在的类别名称。
- 例如我们知道fish、cat、cow
- 使用教师模型进行zero-shot分类,计算每个类别的概率pT
- 学生模型计算已知类别的概率分布pT
- 通过对齐KL散度损失函数优化学生模型提示
怎么对齐的?损失函数计算的部分,比如fish的与哪个对齐??
老师和学生模型都输入了训练集类别名称,一样的class进行对齐就好。
使用KL散度,计算老师预测概率与学生预测概率之间的损失:
![[KDPLg3.png]]
- 蒸馏损失的计算,取决于老师的固定预测、学生的即时预测和类别集合C。
- KDPL可以用于标签不可知和类别不可知的适应场景。我们实验发现,对称KL散度略优于任何一种非对称选项。
3 类别不可知的提示学习
我们不知道训练数据集中所有类别的名称。
从一个大型词汇表Open Images V7 dataset(包含20k类别)中自动选择每个批次最相关的类别。
给定一个图像批次X=Ii N,和所有类别名称C,使用教师模型对所有图像和词汇表中的所有类别(20k)进行推理,得到概率pT
生成概率矩阵PT:对于图像批次的每张图,计算每个类别的概率,沿着批次维度堆叠,得到概率矩阵:
沿批次轴计算平均概率,得到
,表示每个类别在整个批次中的平均概率
根据平均概率,选择K个最高类别,作为学生模型的输入
4 实验
实现细节:
- CLIP ViT-H-14作为老师模型
- ViT-B/32作为学生模型
- K=1000
域泛化:

全部评论 (0)
还没有任何评论哟~
