半监督学习之DTC(Semi-supervised Medical Image Segmentation through Dual-task Consistency)
半监督学习之DTC
不同于MixMatch这类使用“数据增强后的结果一致性(consistency)”,改方法使用“任务一致性”来约束模型(正则化)。由于第一类方法的无监督信号的构建需要模型的预测流程,即每一步由“训练+预测”构成,所以相对的带来了训练时间的大大增加(在Keras框架里面可以通过构建类似GAN一样的结构,即将一个模型complie两次,但是预测的哪个模型的所有参数的设置为不可训练)。而DTC(
开创性的工作)提出一种新思路——“任务一致性正则”,通过将分割分为pixel-wise分类任务和level-set(水平集)函数回归任务(标签是通过一个符号函数转换后的图片)。在这里水平集函数回归思想比较妙,利用了神经网络的本质是一个“万能的函数逼近器”的概念。
skimage level set
Module: segmentation - skimage v0.19.0.dev0 docs
笔记:
Notion
基于双任务一致性的半监督医学图像分割
不同于使用通过扰动数据和网络来规范(regularize)模型训练,使用任务正则化。
一个是水平集的分割图和直接的分割
半监督学习框架通过直接从有限数量的标记数据和大量的未标记数据中学习,获得高质量的分割结果。
现有方法都是通过损失中的正则化项来加强未标记数据预测的一致性。
pixel-wise 和shape-aware多任务
不同任务的结果应该映射、转换到同一个预定义的空间
建立水平集的回归任务和像素点的分类任务的一致性。
分为3部分:
1.第一部分是双任务分割网络:
将分割任务建模为两个问题:1.预测一个像素分类图;2.获得一个全局级水平集函数,其中零级是分割轮廓
2.第二部分是将水平集函数转换为一个分割的概率图
3.第三部分混合监督和无监督的损失函数。这个可以加速全监督的学习也可以有效地利用未标签的数据。
结果:
1.在完全监督的设置下,我们的双任务一致性正则化优于双任务的单独和联合监督。
2.在半监督设置下,所提出的框架在几个临床数据集上优于最先进的半监督医学图像分割框架。
3.与现有的方法相比,该框架需要较少的训练时间和计算成本。同时,它直接适用于任何半建议的医学图像分割场景,由于任务之间存在可微的变换,可以很容易地扩展到使用附加任务。
The consistency regularization plays a vital role in computer vision and image processing, especially in semi-supervised learning.
stochasitc transformations and perturbations 随机变换和扰动
不同于图像变换的结果一致性,一次训练需要前向两次,任务一致性只需要一次。
双任务一致性网络结构

下面是像素分类任务的head,什么上面是水平集回归任务的head。使用encoder-decoder作为backbone。模型在标注的数据上使用最小化监督损失L_{Dice},L_{LSF},在标注数据和非标注使用双任务一致性损失L_{DTC}。函数T在监督学习中将GT标签转换为水平集表示,函数T^{-1}将水平集函数变为概率图来计算L_{DTC}
方法
原论文使用3D图片作为输入
为了建立一致性,使用一个transform layer来将水平集函数转化为一个像素级别的概率图,通过实现smooth后的阶跃函数。
Heaviside函数,即阶跃函数
双任务一致性 :通常来说,一致性损失鼓励在数据集层面的预测(比如,同一个数据的不同变换的预测应该相同),不同于数据集的一致性,使用任务一致性。水平集函数是一个捕获活动轮廓和距离信息的传统方法。定义如下:

x,y是两个在分割mask中的不同的像素和体素,\partial S是zero水平集同时表示目标的轮廓。S_{in} \ S_{out}目标对象是内部区域和外部区域。
通过公式T(x)作为任务转换函数,将分割图变为水平集图。如何将水平集图变为分割图,很容易想到啊使用逆,但是T^{-1}不可微。使用一个平滑逼近来实现逆。

k,z是乘系数和像素/体素。这个非常像sigmoid函数可以作为激活函数。
我们将dual-task-consistency loss定义为\mathcal L_{DTC}来强化task1和task2的一致性

Dual Task Consistency SSL :令 D_l,D_u分别为labeled和unlabelled数据。D作为整个提供的数据集。Seg任务使用Dice loss。
水平集函数损失

最终的损失

\mathcal{L_{Seg},L_{LSF}}是全监督学习用到的损失,\mathcal{L_{DTC}}是无标签数据使用的损失。
\lambda_d的设置使用Guass warm up 函数
\lambda _d(t)=e^{-5(1-\frac{t}{t_{max}})^2}
t,t_{max}是当前步和最大训练步,图像如下

算法

源码阅读
Q1:这个函数这么实现的?

A:使用有符号的距离场,in和out使用tanh函数
def compute_sdf(img_gt, out_shape):
"""
compute the signed distance map of binary mask
input: segmentation, shape = (batch_size, x, y, z)
output: the Signed Distance Map (SDM)
sdf(x) = 0; x in segmentation boundary
-inf|x-y|; x in segmentation
+inf|x-y|; x out of segmentation
normalize sdf to [-1,1]
"""
img_gt = img_gt.astype(np.uint8)
normalized_sdf = np.zeros(out_shape)
for b in range(out_shape[0]): # batch size
posmask = img_gt[b].astype(np.bool)
if posmask.any():
negmask = ~posmask
posdis = distance(posmask)
negdis = distance(negmask)
boundary = skimage_seg.find_boundaries(posmask, mode='inner').astype(np.uint8)
sdf = (negdis-np.min(negdis))/(np.max(negdis)-np.min(negdis)) - (posdis-np.min(posdis))/(np.max(posdis)-np.min(posdis))
sdf[boundary==1] = 0
normalized_sdf[b] = sdf
# assert np.min(sdf) == -1.0, print(np.min(posdis), np.max(posdis), np.min(negdis), np.max(negdis))
# assert np.max(sdf) == 1.0, print(np.min(posdis), np.min(negdis), np.max(posdis), np.max(negdis))
return normalized_sdf
函数数据出来的是距离图

然后计算MSE损失
with torch.no_grad():
gt_dis = compute_sdf(label_batch[:].cpu(
).numpy(), outputs[:labeled_bs, 0, ...].shape)
gt_dis = torch.from_numpy(gt_dis).float().cuda()
loss_sdf = mse_loss(outputs_tanh[:labeled_bs, 0, ...], gt_dis)
Q2:为什么可以看作回归任务

按照这张图理解,sdf函数生成距离的就是这个图。模型作为一个万能的函数逼近器,去拟合一个水平集函数,所以可以当作回归任务。
Q3.逆转化怎么实现
如论文所说,sigmoid函数实现
outputs_soft = torch.sigmoid(outputs)
...
loss_seg_dice = losses.dice_loss(
outputs_soft[:labeled_bs, 0, :, :, :], label_batch[:labeled_bs] == 1)
dis_to_mask = torch.sigmoid(-1500*outputs_tanh)
consistency_loss = torch.mean((dis_to_mask - outputs_soft) ** 2)
实验
实验细节
train with 20% labeled and 80% unlabeled data
normalize: zero mean and unit variance (去均值,归一化 )
VNet→ for 3D seg
SGD optimizer,6000 iter,lr=0.01,decay 0.1 every 2500 iter,batch size=4, 2 label imgs and 2 unlabeled imgs(大图)
on-the-fly data aug,训练时数据增强
数据增强方式:random flipping,rotatiing with 90,180,270
水平集数据在训练之前已经的变换好了,因为水平集函数是变换不变的
评估指标:Dice Jaccard ASD(average surface distance), 95HD(95% Hausdorff Distance)
不同unlabeled data比例下,三种方法的Dice
实际实验的结果和论文的结果差不了太多,但是需要进行染色归一化。

