【KD】Correlation Congruence for Knowledge Distillation
Paper: Correlation Congruence for Knowledge Distillation
1, Motivation:
In conventional KD, the teacher's feature space is distilled without considering the intra-class/inter-class distribution, so the student model will also lack the desired intra-class/inter-class structure.
Usually, the embedding space of the teacher possesses the characteristic that intra-class instances cohere together while inter-class instances separate from each other. But the embedding space of a student trained only with instance congruence would lack this desired characteristic.
2, Contribution:
- Proposes Correlation Congruence Knowledge Distillation (CCKD), which matches not only instance congruence but also correlation congruence. (Instance congruence is the usual per-sample matching of teacher and student outputs; correlation congruence is enforced by a loss on the pairwise correlations between samples i and j within a mini-batch, drawn with P-K or cluster-based sampling.)
- Turns the pairwise correlation computation within a mini-batch into a single large matrix operation over the whole batch, reducing computation (see the sketch after this list).
- Adopts different mini-batch sampler strategies.
- Experiments on CIFAR-100, ImageNet-1K, person re-identification and face recognition.
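A minimal PyTorch sketch (my own illustration, not the authors' released code) of the batched correlation computation referred to in the second bullet: the pairwise correlations of a mini-batch are obtained with one matrix product instead of an explicit loop over all (i, j) pairs.

```python
import torch
import torch.nn.functional as F

def cosine_correlation_matrix(feats: torch.Tensor) -> torch.Tensor:
    """Pairwise cosine correlations of a mini-batch.

    feats: (n, d) embeddings -> (n, n) correlation matrix, computed as a
    single matrix product rather than a double loop over sample pairs.
    """
    feats = F.normalize(feats, dim=1)   # L2-normalize each embedding
    return feats @ feats.t()            # (n, n) matrix of dot products
```

With the settings reported below (batch size 40, 256-d embeddings), this yields a 40 × 40 matrix for the teacher and for the student; the correlation-congruence loss then compares the two.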
3, Framework:

3.3. Correlation Congruence
Correlation-congruence knowledge distillation proceeds in three steps:
- Extract features (from both the teacher and the student).
- Map the features into the embedding feature space.
- Compute the correlation matrix. The mapping function can be any correlation metric; three metrics for capturing the correlation between instances are introduced in the next section.

Correlation congruence loss: L_CC = (1/n²) · Σ_{i,j} (c^t_ij − c^s_ij)², the squared difference between the teacher's and the student's correlation matrices, averaged over the mini-batch of size n.

Gaussian RBF is more flexible and powerful in capturing the complex non-linear relationships between instances. (The paper ultimately adopts the Gaussian kernel to compute correlations, but its computational cost is really high.)
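As a concrete illustration of the Gaussian-RBF variant, here is a short sketch of the correlation matrix and the congruence loss between teacher and student; the bandwidth `gamma` is an assumed placeholder value, and the exact kernel is shown rather than any cheaper approximation.

```python
import torch

def rbf_correlation_matrix(feats: torch.Tensor, gamma: float = 0.4) -> torch.Tensor:
    """Gaussian-RBF correlations k(x_i, x_j) = exp(-gamma * ||x_i - x_j||^2).

    feats: (n, d) mini-batch embeddings -> (n, n) matrix.
    `gamma` is a bandwidth hyperparameter chosen here only for illustration.
    """
    sq_dists = torch.cdist(feats, feats).pow(2)   # pairwise squared L2 distances
    return torch.exp(-gamma * sq_dists)

def correlation_congruence_loss(feats_t: torch.Tensor, feats_s: torch.Tensor) -> torch.Tensor:
    """L_CC: mean squared difference of teacher vs. student correlation matrices."""
    c_t = rbf_correlation_matrix(feats_t)
    c_s = rbf_correlation_matrix(feats_s)
    return (c_t - c_s).pow(2).mean()   # (1/n^2) * sum_{i,j} (c^t_ij - c^s_ij)^2
```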
Loss function:

(Compared with conventional KD, there is one additional constraint term: the correlation-congruence loss.)
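A hedged sketch of how the three terms could be combined in PyTorch; the weighting scheme and the hyperparameter values (T, alpha, beta) below are illustrative assumptions, not the paper's tuned settings.

```python
import torch.nn.functional as F

def cckd_loss(logits_s, logits_t, feats_s, feats_t, labels,
              T: float = 4.0, alpha: float = 0.5, beta: float = 0.01):
    """Cross-entropy + temperature-scaled KD + correlation congruence."""
    # Hard-label cross-entropy on the student logits
    ce = F.cross_entropy(logits_s, labels)

    # Instance congruence: KL divergence between softened teacher/student outputs
    kd = F.kl_div(F.log_softmax(logits_s / T, dim=1),
                  F.softmax(logits_t / T, dim=1),
                  reduction="batchmean") * (T * T)

    # Correlation congruence over the mini-batch (cosine version for brevity;
    # the Gaussian-RBF version sketched above can be substituted)
    z_t = F.normalize(feats_t, dim=1)
    z_s = F.normalize(feats_s, dim=1)
    cc = (z_t @ z_t.t() - z_s @ z_s.t()).pow(2).mean()

    return alpha * ce + (1.0 - alpha) * kd + beta * cc
```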
4, Experimental results:



It can be seen that with the correlation-congruence constraint, intra-class instances are pulled closer together while the inter-class distance becomes larger.
5, Settings:
On CIFAR-100, ImageNet-1K and MSMT17, original knowledge distillation (KD) [15] and cross-entropy (CE) are chosen as the baselines. For face recognition, the ArcFace loss [5] and the L2-mimic loss [21, 23] are adopted. We compare CCKD with several state-of-the-art distillation-related methods, including attention transfer (AT) [37], deep mutual learning (DML) [39] and the conditional adversarial network (Adv) [35]. For attention transfer, we add it to the last two blocks as suggested in [37]. For adversarial training, the discriminator consists of FC(128 × 64) + BN + ReLU + FC(64 × 2) + Sigmoid layers, and it is trained with a binary cross-entropy loss.
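For reference, the adversarial baseline's discriminator described above could be written as the following nn.Sequential sketch (the 128-d input is read off the FC(128 × 64) spec):

```python
import torch.nn as nn

# Discriminator for the adversarial (Adv) baseline as described in the text
discriminator = nn.Sequential(
    nn.Linear(128, 64),
    nn.BatchNorm1d(64),
    nn.ReLU(inplace=True),
    nn.Linear(64, 2),
    nn.Sigmoid(),
)
bce_loss = nn.BCELoss()   # trained with binary cross-entropy, as stated above
```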
ResNet-50 is used as the teacher network and ResNet-18 as the student network. The dimension of the feature representation is set to 256. We set the weight decay to 5e-4 and the batch size to 40, and use stochastic gradient descent with momentum. The learning rate is set to 0.0003 and divided by 10 at epochs 45 and 60, for 90 epochs in total.
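A minimal sketch of this training setup in PyTorch, assuming standard torchvision ResNets as stand-ins for the teacher and student backbones; the momentum value is my assumption, since the text only says "SGD with momentum".

```python
from torch.optim import SGD
from torch.optim.lr_scheduler import MultiStepLR
from torchvision.models import resnet18, resnet50

teacher = resnet50()   # teacher backbone
student = resnet18()   # student backbone

optimizer = SGD(student.parameters(), lr=3e-4,
                momentum=0.9,            # assumed value; text just says "with momentum"
                weight_decay=5e-4)
# learning rate divided by 10 at epochs 45 and 60, 90 epochs in total
scheduler = MultiStepLR(optimizer, milestones=[45, 60], gamma=0.1)

for epoch in range(90):
    # ... train one epoch with mini-batches of size 40 ...
    scheduler.step()
```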
