Advertisement

【论文阅读】Unsupervised Learning of Image Segmentation Based on Differentiable Feature Clustering

阅读量:

文章目录

    • 摘要
      • 目标
      • 贡献点
  • 概述

  • 相关领域的研究现状

  • 传统方法综述

    • 深度学习技术

    • 基于人工标注数据的图像分割

    • 弱监督分割方法

    • 端到端可微分化的CNN模型

    • 方法

      • 问题建模
      • 网络结构
      • 损失函数
      • 网络更新

实验结果表明

复制代码
* 源代码解析
* 加scribble的运行结果

摘要

目标

  • 具有相似特性的像素应当将被划分为同类别。
  • 在空间上位置较接近的像素应当将被划分为同类别。
  • 类别的数量尽量保持较多的数量。
  • 这些指标之间存在相互制约的因素, 但需要实现整体协调管理。

贡献点

基于归一化处理和argmax运算实现可微分聚类。
基于空间连续性的损失函数。
增加了支持用户输入涂鸦的功能,并提升了结果的准确性。
引入了一个预训练模型以扩展现有方法。

介绍

详细阐述了一个扩展性的观点,并补充说明了之前研究中所采用的技术框架:超像素提取方法配合线性迭代聚类算法仅能保证空间连续性的特性

相关工作

经典方法

K-means:向量化的主要方案
*GS(基于图论的方法):采用特定的区域判别函数,在全局或局部特征上实施逐步优化的选择机制

深度学习

  • MsLRR: 基于标签的学习方法既适用于有标签数据也适用于无标签数据;但由于其依赖于超级像素划分这一特点,在某种程度上与之前工作中提出的方法存在相同的局限性。
  • W-Net: 采用一种完全独立于标签的方式估计分割后的图像并将其恢复为原始图像;这种设计使其能够不受边界限制。
  • Unsupervised learning of foreground object segmentation: 这是一种完全不依赖标签的学习方法;然而它仅专注于前景与背景的分割问题。

基于用户输入的图像分割

  • GraphCut 旨在优化将图像 pixels 与 nodes 相关联的成本,并可应用于 strokes 或 anchor boxes.
  • ImageMatting 将其视为提取对象并将其划分为前景与背景的过程;其中,在 ImageMatting 中, 图切割将每个 pixel 划分为前景或背景.
  • ConstrainedRandomWalks 通过利用 strokes 划分生成前/后景初始种子位置.

这些方法均仅生成单比特掩膜矩阵,在采用无监督的多标签分割技术的基础上,现有若干扩展算法方案可选。

α-expansion:确定一个局部最小值以减少α标签像素不易被添加
α-β swap:确定一种方法以避免两种标签容易发生互换

基于CNN的弱监督分割

常用的语义分割弱监督标注:常见于物体检测中的定位框、图像分类任务中的识别结果以及草图形式的涂鸦。一般流程则为:基于弱监督标注构建一个训练目标,并通过该训练目标对网络进行训练;这两步循环往复地执行以完成模型优化。

  • ScribbleSup:将涂鸦实例划分为超像素单元后覆盖整个图像区域并进行统一化的语义分割训练
    • e-SVM:基于约束条件的金字塔方法(CPMC)对锚定点所对应的像素区域进行精确分割然后结合多级分类器完成目标检测
    • 通过基于类别标签生成类特定显重要性图谱的方法在弱监督语义分割框架中构建全卷积CRF网络模型

这些传统弱监督方法可能存在无法准确收敛至预期结果的风险,并且包括一系列优化后的端到端卷积神经网络模型。

端到端的可微分割CNN

关于图像分割的深度学习研究一致围绕对图像特征的理解和提取。

  • deep embedded clustering (DEC):旨在通过最小化Kullback-Leibler散度损失进行学习。
  • 与之相比,本文提出的方法仅通过最优化交叉熵损失函数即可实现目标。
  • 最大边缘聚类法是一种基于最大边缘分布的学习方法。
  • 判别式聚类则是一种结合了判别信息与聚类目标的学习框架。

方法

问题建模

用于提取特征的机制为f, 负责分配标签的任务为g, 标签变量为c。在无监督方法中, 在固定不变的f和g作用下(或环境下), c需要被学习或确定; 在有监督方法中, 在固定不变的c条件下(或环境下), f和g需要被学习; 将整个问题分解为两个独立的子问题。

  1. 用固定的f和g优化c
  2. 再用固定的c优化f和g

网络结构

在这里插入图片描述

将RGB图像经过特征提取后,在q维聚类空间(其中q=3)上应用1x1卷积进行转换。沿着该聚类空间中的各个维度方向上实施批量归一化方法对转换后的特征向量进行标准化处理。采用argmax函数确定每个像素对应的潜在类别维度,并基于所获得的所有像素伪标签计算特征相似度损失与空间一致性损失的基础上实现模型优化训练


忽略批次影响的情况下,在我的理解中维度变化过程如下:H×W⇒提取特征并进行空间转换;H×W×q⇒确定伪标签之后;恢复到原始尺寸

在这里插入图片描述

当进行网络训练时,在初始阶段应设定一个较大的参数值 q 。随着模型损失函数值的降低,该参数值会逐渐减小以适应优化过程的需求。为了避免出现 q 趋近于1的情况,并确保模型输出的有效性,在完成优化后应对响应图进行归一化处理以维持数值稳定性。

损失函数

  • 基本:
在这里插入图片描述
  • 加入涂鸦:
在这里插入图片描述
  • 特征相似误差:
在这里插入图片描述

通过计算r_n的最大值来确定c_n。只有在遍历到对应于c_nr_rn时才会对ln值进行累加操作。由于经过归一化处理后所有r_rn的取值范围限定在0至1之间,在此计算中会引入一个负号。

  • 空间连续性损失:
在这里插入图片描述

计算每个像素上下左右的的response map上的值的差别

  • 涂鸦损失:
在这里插入图片描述

网络更新

前面所述的将复杂问题划分为两个子任务, 其实质即为卷积神经网络(CNN)的前向传播过程以及反向传播机制. 在优化过程中采用随机梯度下降算法, 并应用Xavier权重初始化策略.

其核心理念在于,在设计神经网络层时应确保输入与输出均服从相同均值与方差的正态分布状态。https://www.zhihu.com/search?type=content&q=Xavier%20%E5%88%9D%E5%A7%8B%E5%8C%96

当神经网络每一层的输入与输出均能保持正态分布且方差相近时,在训练过程中能够有效规避了梯度弥散的风险。

另一个关键点在于相较于带明确标签的监督学习方法,在本技术中,在最后一层卷积神经网络(CNN)与argmax层之间施加了一个批标准化(batch norm)操作至关重要。该操作将response map中的每个轴归一化至均值为零、方差为一的状态。这一处理步骤使得各轴能够均等地参与后续比较过程,并最终准确地确定相应的类别标签。

实验结果

连续性损失的有效性

针对特征相似性损失与连续性损失之间的比率问题,在不同数据集下基于所需的分割精度设定不同的比率能够显著提升性能表现。

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

用户输入涂鸦的分割

在这里插入图片描述

参考图像预训练的效果

在这里插入图片描述

源代码解析

github地址:https://github.com/kanezaki/pytorch-unsupervised-segmentation-tip/blob/master/demo.py

注意:此处无法进一步优化文字内容

复制代码
    import argparse
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim
    from torchvision import datasets, transforms
    from torch.autograd import Variable
    import cv2
    import sys
    import numpy as np
    import torch.nn.init
    import random
    
    use_cuda = torch.cuda.is_available()
    
    parser = argparse.ArgumentParser(description='PyTorch Unsupervised Segmentation')
    parser.add_argument('--scribble', action='store_true', default=False, 
                    help='use scribbles')
    parser.add_argument('--nChannel', metavar='N', default=100, type=int, 
                    help='number of channels')
    parser.add_argument('--maxIter', metavar='T', default=1000, type=int, 
                    help='number of maximum iterations')
    parser.add_argument('--minLabels', metavar='minL', default=3, type=int, 
                    help='minimum number of labels')
    parser.add_argument('--lr', metavar='LR', default=0.1, type=float, 
                    help='learning rate')
    parser.add_argument('--nConv', metavar='M', default=2, type=int, 
                    help='number of convolutional layers')
    parser.add_argument('--visualize', metavar='1 or 0', default=1, type=int, 
                    help='visualization flag')
    parser.add_argument('--input', metavar='FILENAME',
                    help='input image file name', required=True)
    parser.add_argument('--stepsize_sim', metavar='SIM', default=1, type=float,
                    help='step size for similarity loss', required=False)
    parser.add_argument('--stepsize_con', metavar='CON', default=1, type=float, 
                    help='step size for continuity loss')
    parser.add_argument('--stepsize_scr', metavar='SCR', default=0.5, type=float, 
                    help='step size for scribble loss')
    args = parser.parse_args()
    
    # CNN model
    class MyNet(nn.Module):
    def __init__(self,input_dim):
        super(MyNet, self).__init__()
        self.conv1 = nn.Conv2d(input_dim, args.nChannel, kernel_size=3, stride=1, padding=1 )
        self.bn1 = nn.BatchNorm2d(args.nChannel)
        self.conv2 = nn.ModuleList()
        self.bn2 = nn.ModuleList()
        # 参数里面nConv设为2,所以这里的conv2也只包含一个卷积层,输入和输出通道都是100
        for i in range(args.nConv-1):
            self.conv2.append( nn.Conv2d(args.nChannel, args.nChannel, kernel_size=3, stride=1, padding=1 ) )
            self.bn2.append( nn.BatchNorm2d(args.nChannel) )
        # 最后一层是1x1的卷积核,输出为100,即q=100,分类数如参数设置为3-100,随着网络更新动态变化
        self.conv3 = nn.Conv2d(args.nChannel, args.nChannel, kernel_size=1, stride=1, padding=0 )
        self.bn3 = nn.BatchNorm2d(args.nChannel)
    
    def forward(self, x):
        x = self.conv1(x)
        x = F.relu( x )
        x = self.bn1(x)
        for i in range(args.nConv-1):
            x = self.conv2[i](x)
            x = F.relu( x )
            x = self.bn2[i](x)
        x = self.conv3(x)
        x = self.bn3(x)
        return x
    
    # load image
    im = cv2.imread(args.input)
    # 把图像(H,W,C)变成(C,H,W),并把像素值归一化到0-1之间
    data = torch.from_numpy( np.array([im.transpose( (2, 0, 1) ).astype('float32')/255.]) )
    if use_cuda:
    data = data.cuda()
    data = Variable(data)
    
    # load scribble
    if args.scribble:
    	# 这里是读取之前准备好的二值的涂鸦图片
    mask = cv2.imread(args.input.replace('.'+args.input.split('.')[-1],'_scribble.png'),-1)
    # reshape成一维,长度=HxW
    mask = mask.reshape(-1)
    # 去除重复数字
    mask_inds = np.unique(mask)
    # 删掉255,剩下的就是涂鸦上的颜色
    mask_inds = np.delete( mask_inds, np.argwhere(mask_inds==255) )
    # 返回mask中=255的索引(空白)
    inds_sim = torch.from_numpy( np.where( mask == 255 )[ 0 ] )
    # 返回mask中!=255的索引(画了涂鸦的像素)
    inds_scr = torch.from_numpy( np.where( mask != 255 )[ 0 ] )
    # mask的int型,这里要把源代码的np.int改成np.int64(因为我会报一个类型不统一的错)
    target_scr = torch.from_numpy( mask.astype(np.int64) )
    if use_cuda:
        inds_sim = inds_sim.cuda()
        inds_scr = inds_scr.cuda()
        target_scr = target_scr.cuda()
    target_scr = Variable( target_scr )
    # set minLabels
    # 根据涂鸦上的颜色类别确定最小的,按照readme里面的测试,剩下0和8两种值,我猜想8应该是代表那些涂鸦线条的边缘
    args.minLabels = len(mask_inds)
    
    # train
    model = MyNet( data.size(1) )
    if use_cuda:
    model.cuda()
    model.train()
    
    # similarity loss definition
    loss_fn = torch.nn.CrossEntropyLoss()
    
    # scribble loss definition
    loss_fn_scr = torch.nn.CrossEntropyLoss()
    
    # continuity loss definition
    loss_hpy = torch.nn.L1Loss(size_average = True)
    loss_hpz = torch.nn.L1Loss(size_average = True)
    
    HPy_target = torch.zeros(im.shape[0]-1, im.shape[1], args.nChannel)
    HPz_target = torch.zeros(im.shape[0], im.shape[1]-1, args.nChannel)
    if use_cuda:
    HPy_target = HPy_target.cuda()
    HPz_target = HPz_target.cuda()
    
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
    # 随机生成100种颜色作为标签颜色
    label_colours = np.random.randint(255,size=(100,3))
    
    for batch_idx in range(args.maxIter):
    # forwarding
    optimizer.zero_grad()
    output = model( data )[ 0 ]
    # 从(C,H,W)转到(H,W,C),再用contiguous做一个拷贝和之前的output区分开,再view成(HxW,C)的形状
    output = output.permute( 1, 2, 0 ).contiguous().view( -1, args.nChannel )
    	# 不太明白这里为什么用reshape而不是permute,这里得到的就是100个通道的response map
    outputHP = output.reshape( (im.shape[0], im.shape[1], args.nChannel) )
    # 求空间连续性误差,先y和z方向分别求出右边像素-左边像素(左右指的是索引)
    HPy = outputHP[1:, :, :] - outputHP[0:-1, :, :]
    HPz = outputHP[:, 1:, :] - outputHP[:, 0:-1, :]
    lhpy = loss_hpy(HPy,HPy_target)
    lhpz = loss_hpz(HPz,HPz_target)
    	# 返回的ignore代表每个像素对应的100个通道中值最大的那个通道的值,target返回的是那个通道对应的索引(也就是标签)
    ignore, target = torch.max( output, 1 )
    # target的形状应该是(HxW,1),即每一个像素都被分配了标签
    im_target = target.data.cpu().numpy()
    # 去除重复的标签得到标签总数
    nLabels = len(np.unique(im_target))
    if args.visualize:
    	# 按照HxW的形状,对每一个像素按照标签赋值之前随机初始化的100种颜色
        im_target_rgb = np.array([label_colours[ c % args.nChannel ] for c in im_target])
        # 把(HxW,3)变成(H,W,3)
        im_target_rgb = im_target_rgb.reshape( im.shape ).astype( np.uint8 )
        cv2.imshow( "output", im_target_rgb )
        cv2.waitKey(10)
    
    # loss 
    if args.scribble:
        loss = args.stepsize_sim * loss_fn(output[ inds_sim ], target[ inds_sim ]) + args.stepsize_scr * loss_fn_scr(output[ inds_scr ], target_scr[ inds_scr ]) + args.stepsize_con * (lhpy + lhpz)
    else:
        loss = args.stepsize_sim * loss_fn(output, target) + args.stepsize_con * (lhpy + lhpz)
        
    loss.backward()
    optimizer.step()
    
    print (batch_idx, '/', args.maxIter, '|', ' label num :', nLabels, ' | loss :', loss.item())
    
    if nLabels <= args.minLabels:
        print ("nLabels", nLabels, "reached minLabels", args.minLabels, ".")
        break
    
    # save output image
    if not args.visualize:
    output = model( data )[ 0 ]
    output = output.permute( 1, 2, 0 ).contiguous().view( -1, args.nChannel )
    ignore, target = torch.max( output, 1 )
    im_target = target.data.cpu().numpy()
    im_target_rgb = np.array([label_colours[ c % args.nChannel ] for c in im_target])
    im_target_rgb = im_target_rgb.reshape( im.shape ).astype( np.uint8 )
    cv2.imwrite( "output.png", im_target_rgb )
    
    
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
    
    代码解读

加scribble的运行结果

复制代码
    python demo.py --input ./PASCAL_VOC_2012/2007_001774.jpg --scribble
    
    
      
    
    代码解读
在这里插入图片描述

计划后续使用其他图片搭配scribble进行测试。实际上发现scribble具有独特的特性:它不仅是在图片表面画线段那么简单的内容物点出来吗?具体来说,在scribble上绘制的线条应当是数值为零的区域;通过一个介于0和255之间的数值来定义轮廓边界;特别的是这个边界值决定了对象与背景之间的区分度区域在哪里呢?目前对于这一机制的具体工作原理还存在疑问;待进一步学习后才能更好地理解其工作原理。

全部评论 (0)

还没有任何评论哟~