论文笔记:SPATIO-TEMPORAL STRUCTURE CONSISTENCY FOR SEMI-SUPERVISED MEDICAL IMAGE CLASSIFICATION (ICASSP)
论文链接:https://arxiv.org/pdf/2303.01707
论文代码:暂无
写在前面:
在看论文和查阅资料发现写医学图像半监督学习方面的博文不是很多,所以就写下了这篇。花了一个周末仔细读了这篇论文,这也是我第一次写blog,如果有什么不对的地方请多多指教。
目录
Abstract
Introduction
Method
Spatio-Temporal Structure Consistent Framework
Spatial and Temporal Structure Consistency
Temporal Sub-structure Consistency
Experiment
Comparison Results
Ablation Studies
Conclusion
个人总结
Abstract
基于全标注的大规模数据集的智能医疗诊断已经取得了显著的进展。然而,由于专家的注释非常昂贵,很少有带标签的图像可用。为了有效利用大量的未标记数据,我们提出了一种新的时空结构一致性学习框架,将时空结构一致性结合起来。具体来说,我们推导了一个gram矩阵来捕捉不同训练样本在表示空间中的结构相似性。在空间层面上,我们的框架明确地增强了扰动下不同样本之间结构相似性的一致性。在时间层次上,我们希望通过挖掘关系图中稳定的子结构,在不同的训练迭代中保持结构相似度的一致性。 在两个医学图像数据集(ISIC 2018和ChestX-ray14)上的实验表明,我们的方法优于最先进的半监督学习方法。此外,本文还利用Grad-CAM对Gram矩阵和热图进行了广泛的定性分析,以验证本文方法的有效性。
Introduction
深度学习在大规模标注的医学图像分析中取得成功,然而获得大量准确的医学图像标注是困难的,所以半监督学习(SSL)的方法被应用于提高性能。半监督可以分为self-training和consistency regularization,后者是SSL的关键方法之一。目前方法主要通过研究空间结构关系来利用未标记样本,缺乏对时间一致性的有效探索。
从医疗人员经常查阅之前的样本帮助诊断得到启发,本文提出了一种新的时空结构一致(STSC)半监督框架(见图1),以同时探索不同样本之间的时空结构关系。特别地,本文推导出一个gram matrix来描述不同样本在表示空间中的相似性。然后将训练样本的图结构转化为邻接矩阵来表示。在训练过程中,通过激励gram matrix在不同扰动下的一致性,可以获得稳定的空间结构。在训练后期,提出了一种时间子结构一致性(TSC)方法来保持结构关系的时间一致性,并在训练过程中进一步捕获关系图中稳定的子结构,从无标记数据之间的关系中学到更多有区别的语义信息。

(补充资料——gram matrix :格拉姆矩阵(Gram matrix)详细解读 - 知乎)
主要贡献:
(1) 提出了一种新的STSC半监督学习框架,有效地利用无标签数据,降低了单标签和多标签任务对有标签数据的需求。
(2) 提出了一种时序子结构一致性(TSC)方法来研究关系图中的稳定子结构。在训练的同时,可以有效地捕捉稳定的样本结构。
(3) 在两个公共医学图像数据集(ISIC 2018和ChestX-ray14)上进行的实验表明,与最先进的(SOTA)方法相比具有更好的性能。
Method
Spatio-Temporal Structure Consistent Framework
半监督学习的总损失函数:

其中,Ls表示标记集DL上的监督损失(交叉熵损失),Lu为迫使相同输入在不同扰动下的一致预测的无监督一致性损失。本文利用相同主干的师生结构,分别以θ和θ'作为参数。η和η'表示应用于同一输入图像的不同摄动。λ是一个超参数,它控制监督损失和非监督损失之间的权衡。
教师模型的权重θ'被更新为学生模型权重θ的指数移动平均(EMA):

其中,α是控制更新速率的超参数。
Consistency Loss:

以教师模型预测概率和学生模型预测概率之间的均方误差作为损失函数。
def softmax_mse_loss(input_logits, target_logits):
"""Takes softmax on both sides and returns MSE loss
4. Note:
- Returns the sum over all examples. Divide by the batch size afterwards
if you want the mean.
- Sends gradients to inputs but not the targets.
"""
# input是学生模型全连接层输出概率; target是教师模型全连接层的输出概率
assert input_logits.size() == target_logits.size()
input_softmax = F.softmax(input_logits, dim=1)
target_softmax = F.softmax(target_logits, dim=1)
mse_loss = (input_softmax-target_softmax)**2 * CLASS_WEIGHT
return mse_loss
(补充资料:
EMA: 【炼丹技巧】指数移动平均(EMA)的原理及PyTorch实现 - 知乎
Mean Teacher 论文 : https://proceedings.neurips.cc/paper/2017/file/68053af2923e00204c3ca7c6a3150cf7-Paper.pdf)
Spatial and Temporal Structure Consistency
首先,根据网络最后的feature map得到Dl矩阵,形状为[Batch, HWC],对其计算Gram Matrix得到Ml矩阵:

然后对Ml矩阵的每一行进行归一化得到,final sample relation matrix Rl:

最后分别对学生和教师模型计算均方误差,得出Spatial structure consistency loss:

def relation_mse_loss(activations, ema_activations):
"""Takes softmax on both sides and returns MSE loss
4. Note:
- Returns the sum over all examples. Divide by the batch size afterwards
if you want the mean.
- Sends gradients to inputs but not the targets.
"""
# activations和ema_activations分别是是学生和教师模型经过FC之前最后的feature map
assert activations.size() == ema_activations.size()
activations = torch.reshape(activations, (activations.shape[0], -1))
ema_activations = torch.reshape(ema_activations, (ema_activations.shape[0], -1))
similarity = activations.mm(activations.t()) # Gram Matrix
norm = torch.reshape(torch.norm(similarity, 2, 1), (-1, 1))
norm_similarity = similarity / norm # normalization
ema_similarity = ema_activations.mm(ema_activations.t()) # Gram Matrix
ema_norm = torch.reshape(torch.norm(ema_similarity, 2, 1), (-1, 1))
ema_norm_similarity = ema_similarity / ema_norm # normalization
similarity_mse_loss = (norm_similarity-ema_norm_similarity)**2
return similarity_mse_loss
(补充资料—— SRC-MT论文: https://arxiv.org/pdf/2005.07377)
Temporal Sub-structure Consistency
首先通过一个阈值对Gram Matrix R二值化得到邻接矩阵 A:

然后由邻接矩阵A产生图G(X, E),X是样本集合,E是边的集合,下一个步骤我不是特别熟悉,所以直接翻译:
训练过程中,我们识别出G中稳定的子结构s(X¯),其中X¯中的所有元素在t和t + 1时间内都是连接的。稳定子结构集合表示为S = {si}k i=1,其中si为第i个稳定子结构,k为稳定子结构的数量。
最后,对前一训练迭代和本次训练迭代的稳定子结构计算Temporal structure consistency loss:

(时间结构一致性这一部分是本文的一个创新点,我对Graph这部分不是特别熟悉,所以我暂时还不知道怎么在代码上进行体现,怎么去算它的稳定子结构,有知道的朋友多多指教啦~)
总的损失函数:

跟上面那个半监督学习的总损失函数一样,这里展示了Lu的组成。
Experiment
评价指标:AUC、准确性、敏感性和特异性
Comparison Results

从表1可以看出,self-training方法比其他方法具有更高的特异性,这得益于阴性样本。与之前的SOTA方法SRC-MT相比,本文的方法在AUC、准确度和灵敏度方面分别优于3.58%、1.78%和7.10%。值得注意的是,STSC通过增强不同样本之间时空关系的一致性,在SRC-MT上实现了所有度量的改进,表明了方法的有效性。

在表2中,GraphXNET是基线模型,它的性能与不同的标签数据百分比有很大的差异。随着标记数据百分比的改变,本文的方法表现出更稳定的性能。SRC-MT是之前的SOTA方法。此外,在标记数据设置为2%和5%时,STSC方法的AUC低于SRC-MT方法,在标记数据百分比增加时表现出更好的性能。这种现象可能是由于标记的数据受益于更可靠的关系结构。

由Grad-CAM进行的可视化显示如图2所示,其中前两行是ISIC 2018数据集。我们的模型得到的注意区域与医生经验得到的病变区域是一致的,如前两排的注意地图几乎都与病变区域重叠。此外,在最后两行注意力图中,突出了症状明显的胸部区域。

Ablation Studies

在表3中,研究了不同百分比标记数据的影响。本文的方法在10%、20%、30%和50%标记数据的情况下,可以获得优于基线和SOTA方法的性能。此外,STSC在只有20%标记数据的情况下可以获得更高的AUC和准确率,接近100%标记数据训练的上限。使用20%标记数据训练的模型与使用50%标记数据训练的监督基线模型的性能相当,这进一步验证了方法的有效性。

不同一致性项的消融研究如表4所示。当只使用监督损失项时,模型表现不佳。每个无监督损失项都可以显著提高最终性能,特别是对于AUC。此外,非监督损失条款的不同组合也能很好地发挥作用。例如,时间一致性和空间一致性损失都可以获得更高的AUC。当将所有无监督损失项组合在一起时,该模型在所有指标中进一步实现了最佳性能。
Conclusion
本工作研究半监督医学图像分类,以减轻对标记数据的需要,用来训练深度神经网络。我们提出了一种考虑样本时空结构稳定性的STSC框架。在两个公共基准医学图像分类数据集上进行了大量的实验,以证明我们的方法在单标签和多标签医学图像分类任务上的有效性。最后给出了可视化结果,验证了该方法的有效性。
个人总结
本文是基于mean teacher 框架的改进,其中无监督约束中的Consistency loss和Spatial structure consistency loss在SRC-MT中提出,所以本文的主要贡献提出了Temporal structure consistency loss。Ltc与Lsc类似先对feature map 计算相似性,然后结合图结构在时间上对相似性进行约束。终于写完了自己的第一篇博文啦,有啥不对的地方请及时指出。

