半监督图卷积网络:在医学图像分析和疾病诊断中的应用
作者:禅与计算机程序设计艺术
1.简介
近年来,随着医疗影像技术的飞速发展、生物医学信息技术的快速崛起、科技对疾病诊断过程的广泛关注等方面的原因,越来越多的人开始重视医学图像相关的机器学习技术,尤其是如何提升医学图像分析模型的性能,从而更好地进行疾病诊断。其中一种有效的方法就是利用人工标注的数据来增强模型的鲁棒性和健壮性,通过这种方式可以让模型更好的识别到样本中的潜在模式,进而提升模型的准确率。半监督学习(Semi-supervised Learning)正是利用这种方式来解决数据量较小的问题。
在本篇博文中,我将向读者展示如何利用基于CNN的半监督图卷积网络(Squeeze-and-Excitation Graph Convolutional Network,SE-GCN)来处理临床医学图像数据集,并分析其优缺点。
Squeeze-and-Excitation Graph Convolutional Network (SE-GCN) 是由苏黎世联邦理工大学等人于2019年发表的一篇研究工作。该论文提出了一种新型的图卷积网络结构——SE-GCN,它能够充分考虑节点间相互联系的信息,并且能较好地保持图数据的全局一致性。它在医疗图像分析领域的应用有广阔的空间,可以帮助提高诊断精度、降低患者症状不稳定性、减少因训练数据不足带来的欠拟合、提升模型鲁棒性及预测能力。因此,SE-GCN在疾病诊断领域具有很大的实用价值。
2.相关研究背景及意义
2.1 图神经网络的发展
Graph Neural Networks (GNNs) 在深度学习领域占据了一席之地,它们借助图论和其特有的计算理论来处理复杂的异构网络数据。图神经网络旨在对图结构数据进行建模,通过对节点的特征和边的关联关系进行学习,能够有效的捕获全局或局部的上下文信息,并且能够处理具有多个属性的节点。
近年来,由于GNN模型在多种任务上都取得了显著的成果,许多学者开始寻求将GNN运用于医疗图像分析领域。随着医疗影像技术的飞速发展,越来越多的人开始关心如何利用这些数据来构建高精度、可靠的诊断系统,因此建立有效的图神经网络模型就变得十分重要。
2.2 半监督学习的兴起
近年来,随着医疗影像技术的飞速发展,越来越多的人开始关心如何利用这些数据来构建高精度、可靠的诊断系统。为了达到这个目标,人们提出了各种各样的方法,包括分类、回归、聚类、密度估计等等。但是,仅靠手动标记的数据往往不能完全满足需求,因为数据往往存在类别分布不均衡的问题、特征分布差异明显、样本质量差异等问题。因此,为了解决这些问题,很多学者开始寻找其他的方式来利用人工标注的数据,其中半监督学习是一种重要的方法。
半监督学习的基本思想是利用部分人工标注的数据来训练模型,同时利用整个数据集来获得模型的泛化能力。即使部分数据被错误标记或者丢失,也可以通过使用整个数据集来对模型进行微调,从而提升模型的准确率。然而,由于现有方法无法处理非凸学习问题,因此在实际生产环境中使用仍然面临一些困难。另外,由于半监督学习中需要收集和标注大量无效的样本,因此在实际场景中实现起来也比较困难。因此,半监督学习的发展也成为一个重要方向。
3.基本概念术语说明
3.1 图神经网络
图神经网络(Graph Neural Network,GNN)是一个用深度学习方法处理图结构数据的神经网络模型。图是由节点(node)和连接着节点的边(edge)组成的,每个节点代表图中的一个实体,而每条边代表两个节点之间的关系。如同一般的神经网络一样,GNN也是一种参数化的函数,可以定义节点之间的映射关系,从而实现对图结构数据的学习。
3.2 图卷积网络
图卷积网络(Graph Convolutional Network,GCN)是GNN中的一种最流行的网络结构。GCN利用图的邻接矩阵进行信号的传递,把输入信号转换成输出信号。GCN的主要思想是从图结构中提取局部特征和全局特征,然后结合两者的联系得到最终的输出结果。具体来说,GCN首先利用图的邻接矩阵对节点之间进行正交变换,然后用卷积核进行卷积操作,从而完成节点特征的提取。再者,GCN还利用图自身的全局特性进行特征的整合,从而提升模型的表达能力。
3.3 SE模块
SE模块(Squeeze-and-Excitation Module,SE),是指一种可以扩张神经元的结构。SE模块的基本思想是通过池化和激活函数来减少特征图的冗余信息,然后利用额外的信息进行特征增强。具体来说,SE模块的基本结构为先执行全局平均池化(Global Average Pooling)操作,将全局的特征压缩到一个单独的向量;然后在输出通道上执行1x1的卷积操作,增加感受野;最后利用sigmoid激活函数得到权重系数,再将权重乘以原始特征,得到新的特征图。SE模块通过控制神经元的激活情况,可以帮助提升神经网络的特征抽象能力,并增强模型的非线性表示能力。
3.4 半监督学习
半监督学习(Semi-supervised Learning)是一种通过利用部分人工标注的数据来训练模型,同时利用整个数据集来获得模型的泛化能力的机器学习技术。这一技术旨在解决数据量较小的问题,通过利用全部的训练数据来实现模型的优化。其基本思路是利用一部分已经标注的样本数据和某些没有标注的样本数据,利用这部分数据和整个数据集进行训练和测试。如果模型在这部分数据上表现良好,就可以认为模型已经收敛,并进入一个更加泛化的状态。如果模型在这部分数据上的表现出现偏差,则可以利用更多的未标注的数据来进行调整,直到模型的泛化能力达到最佳。半监督学习在医学图像分析领域的应用十分广泛。
4.核心算法原理和具体操作步骤以及数学公式讲解
4.1 图卷积网络模型概览
图卷积网络模型是利用图卷积操作来获取图数据的全局表示的模型。具体来说,图卷积网络模型包含三个主要组件,即节点更新层、边更新层和输出层。节点更新层负责更新节点的特征,边更新层负责更新边的特征,输出层负责预测分类标签。具体的操作步骤如下:
-
对输入图进行变换和预处理:首先对图的邻接矩阵进行拉普拉斯算子操作,然后使用对称正态分布随机初始化节点的特征。
-
通过图卷积操作学习节点的特征:然后利用图卷积操作(也叫做图傅里叶变换)来学习节点的特征。具体来说,图卷积操作将图的邻接矩阵乘以节点的特征矩阵,从而完成节点更新。图卷积操作的公式如下:
- \hat{A}: 谱规范化后的邻接矩阵
- D: 度矩阵
- ilde{A}: 拉普拉斯矩阵
- U, \Sigma, V^T: SVD分解出的左右奇异矩阵
这里, ilde{A} 表示拉普拉斯矩阵,是A 的半正定形式,即A 中只有对角元素为0,其余元素都为正。\hat{A} 是拉普拉斯矩阵A 的谱规范化版本,即L=D^{-\frac{1}{2}}AD^{-\frac{1}{2}}。图卷积操作的作用是完成节点的特征的学习,得到每个节点的融合信息。
- 使用注意力机制增强节点特征的通用性:此时,节点的特征已经完成了学习,但还不能够直接应用到下游任务中。所以,需要引入注意力机制来增强节点特征的通用性。注意力机制可以学习到不同位置的节点对于图的全局特征贡献大小,从而对每个节点的特征增强。具体来说,可以使用一个1×1的卷积核来产生一个注意力向量,然后对所有节点的特征进行加权,得到增强后的特征。注意力机制的公式如下:
e_{ij}=a_j^{ op}W_g h_i,\ a_j=\sigma( ilde{h}_j)^T W_e\ \hat{h}_{i'}=f(\sum_{j\in N(i')} e_{ij} h_j)\ f(x)= ext{ReLU}(Wx+b)\qquad x\in R^{n_{ ext{input}}}\ ``` ilde{h}_j=[W_1h_j;W_2h_j] ```
* $N(i')$:节点 $i'$ 的邻居节点集合
* $\sigma(\cdot)$:激活函数
* $f(\cdot)$:非线性函数
此处,a_j 表示节点 j 的注意力向量, ilde{h}_j 表示节点 j 的拓展特征,f(\cdot) 函数的目的是对节点的拓展特征进行非线性变换,以增强特征的复杂性。
4.2 Squeeze-and-Excitation模块
SE模块是在GCN的基础上引入的一种结构,目的是增加神经网络的非线性表示能力,并增强模型的特征抽象能力。具体来说,SE模块主要由池化操作和激活函数组成。池化操作将特征图的大小缩小到一个固定的值,激活函数则用来控制神经元的激活情况。SE模块的结构如下图所示:

其中,z 为一个1×1的卷积核,y 为节点的拓展特征,\alpha 为节点的注意力系数,\gamma 和\beta 分别为控制注意力系数大小的参数。y 可以表示为 y=\gamma s+x ,其中 s 是节点的全局表示,x 是节点的原始表示。
4.3 半监督图卷积网络模型
半监督图卷积网络模型是在图卷积网络模型的基础上,采用了半监督学习的思想,同时加入了注意力机制的功能。具体来说,半监督图卷积网络模型包含三种不同的模块,分别为特征提取器、训练器和分类器。
4.3.1 特征提取器
特征提取器用来提取图的全局表示,具体来说,特征提取器包括一个初始化层、一个全局池化层和两个GCN块。其中,初始化层负责初始化节点的特征,全局池化层将全局特征作为整个图的表示。GCN块则用于提取局部特征和全局特征。GCN块由多个图卷积层和SE模块组成。每个图卷积层用于提取局部特征,而SE模块用于增加非线性度量。
4.3.2 训练器
训练器利用已经标注的样本数据和未标注的样本数据进行训练,具体来说,训练器包括一个辅助损失函数,一个优化器,一个联合损失函数。辅助损失函数的目的是增加模型的鲁棒性,防止过拟合;优化器负责完成模型参数的更新;联合损失函数包括训练器自己定义的损失函数和辅助损失函数的组合。
4.3.3 分类器
分类器根据节点的全局表示和节点的标签进行分类,并给出预测结果。具体来说,分类器包含一个全连接层和softmax激活函数。全连接层将节点的全局表示与节点的标签进行连接,并对两个特征进行一个线性组合,从而得到分类结果。softmax激活函数将分类结果转化成属于各个类别的概率分布。
总体来说,半监督图卷积网络模型包括以下几个主要模块:
-
初始化层:该层负责初始化节点的特征,并将初始特征矩阵保存在某个变量中。
-
全局池化层:该层用于计算全局池化的特征。
-
GCN块:该块包括两个图卷积层和两个SE模块,前两个图卷积层用于提取局部特征,后两个SE模块用于增加非线性度量。
-
训练器:该模块包含了一个辅助损失函数、一个优化器、一个联合损失函数。
-
分类器:该模块包含了一个全连接层和一个softmax激活函数。
4.4 模型代码实现和实验结果
4.4.1 数据集介绍
在本文中,我们将使用一个名为“微观照片”的医学图像数据集。该数据集包含60,000张微型X光胸片的原始图像,分为五类,每个类别包括1000张图。每张图的大小为1024x1024,且都是灰度图。这些图像主要是来自肿瘤学实验室。
4.4.2 数据准备
(1)下载数据集
首先,需要下载数据集。数据集可以从作者提供的网站https://sites.google.com/view/miccai2019neurips下载。下载完成之后,需要解压zip文件。解压之后,文件夹的名称为“Microscopy”(大小写敏感)。
mkdir Microscopy
mv *.zip Microscopy/
cd Microscopy
unzip [file name].zip
rm [file name].zip # 删除压缩包
(2)划分数据集
接下来,需要将数据集划分为训练集、验证集和测试集。通常情况下,训练集和验证集的比例为8:2。
import os
from sklearn.model_selection import train_test_split
root_dir = 'Microscopy'
class_names = sorted([c for c in os.listdir(os.path.join(root_dir)) if not c.startswith('.')])
train_val_classes, test_classes = train_test_split(class_names, test_size=0.2, random_state=42)
train_classes, val_classes = train_test_split(train_val_classes, test_size=0.25, random_state=42)
print('Training classes:', len(train_classes), ', Validation classes:', len(val_classes), ', Test classes:', len(test_classes))
运行以上代码之后,会得到训练集、验证集和测试集的类别数量。
(3)加载数据集
接下来,可以加载训练、验证和测试数据集。为了节约内存,可以只读取必要的图片。
from PIL import Image
import numpy as np
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset
class MicroscopyDataset(Dataset):
def __init__(self, root_dir, class_names, split='train', transform=None):
self.root_dir = root_dir
self.class_names = class_names
self.split = split
self.transform = transform
self.files = []
for cls in self.class_names:
img_dir = os.path.join(self.root_dir, cls)
files = os.listdir(img_dir)
num_samples = int(len(files) * 0.1) if split == 'train' else len(files) // 2
self.files += [(os.path.join(cls, file), cls) for file in files[:num_samples]]
def __len__(self):
return len(self.files)
def __getitem__(self, idx):
path, label = self.files[idx]
image = Image.open(os.path.join(self.root_dir, path)).convert('L').resize((1024, 1024))
if self.transform:
image = self.transform(image)
return image, label
transform_dict = {
'train': transforms.Compose([transforms.ToTensor()]),
'val': transforms.Compose([transforms.ToTensor()])
}
dataset_dict = {}
for split in ['train', 'val']:
dataset_dict[split] = MicroscopyDataset(root_dir, train_classes + val_classes if split == 'train' else test_classes, split, transform=transform_dict[split])
print('{} set size: {}'.format(split, len(dataset_dict[split])))
4.4.3 定义模型
(1)导入库
首先,导入需要使用的库。
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import TopKPooling, SGConv
from torch_geometric.utils import degree
(2)定义模型结构
然后,可以定义模型结构。
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = SGConv(1, 16, K=2)
self.pool1 = TopKPooling(16, ratio=0.8)
self.conv2 = SGConv(16, 32, K=2)
self.pool2 = TopKPooling(32, ratio=0.8)
self.fc1 = nn.Linear(2*math.ceil(math.sqrt(512))+32, 64)
self.fc2 = nn.Linear(64, 5)
def forward(self, data):
x, edge_index, batch = data.x, data.edge_index, data.batch
x = F.relu(self.conv1(x, edge_index))
x, edge_index, _, batch, _ = self.pool1(x, edge_index, None, batch)
x1 = torch.cat([global_mean_pool(x, batch), global_max_pool(x, batch)], dim=-1).squeeze(-1)
x = F.relu(self.conv2(x, edge_index))
x, edge_index, _, batch, _ = self.pool2(x, edge_index, None, batch)
x2 = torch.cat([global_mean_pool(x, batch), global_max_pool(x, batch)], dim=-1).squeeze(-1)
x = torch.cat([x1, x2], dim=1)
x = F.relu(self.fc1(x))
x = F.dropout(x, p=0.5, training=self.training)
x = self.fc2(x)
return F.log_softmax(x, dim=1)
模型的基本结构为一个包含三个模块:图卷积层、池化层和全连接层。其中,图卷积层和池化层分别用来提取局部特征和全局特征,全连接层用来进行分类。
(3)定义损失函数和优化器
接下来,可以定义损失函数和优化器。
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
(4)模型训练
最后,可以训练模型。
def train():
model.train()
total_loss = correct_nodes = total_nodes = 0
for i, data in enumerate(loader):
optimizer.zero_grad()
data = data.to(device)
out = model(data)
loss = criterion(out, data.y)
pred = out.max(dim=1)[1]
correct_nodes += pred.eq(data.y).sum().item()
total_nodes += data.num_nodes
total_loss += float(loss) * data.num_graphs
loss.backward()
optimizer.step()
return total_loss / len(loader.dataset), correct_nodes / total_nodes
def evaluate():
model.eval()
correct_nodes = total_nodes = 0
with torch.no_grad():
for data in loader:
data = data.to(device)
out = model(data)
pred = out.max(dim=1)[1]
correct_nodes += pred.eq(data.y).sum().item()
total_nodes += data.num_nodes
return correct_nodes / total_nodes
