医疗影像分割 | 使用 Swin UNETR 训练自己的数据集(3D医疗影像分割教程)

代码地址:unetr
论文地址:https://arxiv.org/pdf/2201.01266
一、下载代码
从GitHub处获取代码后,依次访问SWINUNETR平台,在最初的两个涉及BRATS21和BTCV的数据集上执行相关操作,请注意仅需完成BRATS21文件夹的操作即可完成整个流程。
- BRATS21 专注于精确定位和分析脑肿瘤的结构特征,在融合多种影像学信息的同时提供细致的解剖学标记, 并被广泛应用于研究与评估脑部肿瘤自动分割算法.
- BTCV 则专门针对腹部 CT 成像数据进行多器官分割任务, 通过系统化的标注过程为多个腹部解剖结构提供标准标注, 是评价多器官自动分割技术的重要基准.
cd BRATS21
二、创建环境
创建编译环境
conda create -n SwinUNETR python==3.8
conda activate SwinUNETR
pip install -r requirements.txt
在遇到无法通过git+https://github.com/Project-MONAI/MONAI.git@07de215c下载依赖项的情况时,请按照以下步骤进行操作:首先将requirements.txt重新配置为以下内容;随后运行pip install -r requirements.txt命令即可完成安装。
# git+https://github.com/Project-MONAI/MONAI.git@07de215c
monai==1.1.0
nibabel==3.1.1
tqdm==4.59.0
einops==0.4.1
tensorboardX==2.1
tensorboard
# conda install pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 pytorch-cuda=11.8 -c pytorch -c nvidia
三、数据集准备
1. 准备数据集文件
数据格式规定:你所使用的 MONAI 数据加载器需遵循特定的数据组织规范。根据代码文档的说明,在处理医学影像时,默认情况下应准备好一组图像与对应标注信息。这一规范通常意味着你需要准备好一组图像与标注对(pair)。其中的图像文件一般采用 .nii 或 .nii.gz 格式存储;而标注信息则可能采用相同格式或其他二进制或标签文件格式的形式存储
images:属于输入的医学图像数据集。这些数据通常以.nii格式或.nii.gz格式存储。labels:属于这些医学图像对应的标签数据集。这些标签通常用于模型训练和验证过程,并且通常作为医学图像分割任务中的标注信息。
在本系统中,该函数读取的信息构成一个JSON文件。其中包含训练与验证数据的相关路径信息。请确认您提供的JSON文件具有正确的结构与格式,请参考以下示例:
{
"training": [
{
"image": "path_to_image1.nii.gz",
"label": "path_to_label1.nii.gz",
"fold": 0
},
{
"image": "path_to_image2.nii.gz",
"label": "path_to_label2.nii.gz",
"fold": 1
},
...
],
"validation": [
{
"image": "path_to_image3.nii.gz",
"label": "path_to_label3.nii.gz",
"fold": 0
},
...
]
}
- image 和 label 表示它们是图像和标签文件的路径,并且通常采用
.nii.gz格式。- fold 常用于指定数据集划分的方式,在交叉验证中被广泛应用。当进行多轮折叠(fold)时,请确保合理分配训练集与验证集以避免数据泄漏。
2. 目录结构
数据集的目录结构应如下所示:
/data
/image1.nii.gz
/label1.nii.gz
/image2.nii.gz
/label2.nii.gz
/...
/annotations
annotations.json
3. 或者选择根据自己的数据集修改代码
基于我的数据集位于同一个文件夹内,并且遵循以下命名格式:xxxxx.nii.gz及其对应的seg版本xxxxx_seg.nii.gz。经过分析与调整后。
首先,写一个自己数据集的json文件,generate_json.py :
import os
import json
from collections import OrderedDict
import random
# 设置划分比例
train_ratio = 0.8
validation_ratio = 0.1
test_ratio = 0.1
# 假设数据集路径在datapath下,数据文件名形式为xxx.nii.gz和xxx_seg.nii.gz
datapath = r"" # 输入数据路径
image_files = [f for f in os.listdir(datapath) if f.endswith('.nii.gz') and '_seg' not in f]
label_files = [f for f in os.listdir(datapath) if f.endswith('_seg.nii.gz')]
# 确保图像和标签文件数量相同
assert len(image_files) == len(label_files), "The number of images and labels must be the same."
# 计算各个数据集的大小
total_size = len(image_files)
train_size = int(total_size * train_ratio)
validation_size = int(total_size * validation_ratio)
test_size = total_size - train_size - validation_size
# 划分数据集
train_images = image_files[:train_size]
train_labels = label_files[:train_size]
validation_images = image_files[train_size:train_size + validation_size]
validation_labels = label_files[train_size:train_size + validation_size]
test_images = image_files[train_size + validation_size:]
test_labels = label_files[train_size + validation_size:]
# 打印结果
print(f"Training set: {len(train_images)} images")
print(f"Validation set: {len(validation_images)} images")
print(f"Test set: {len(test_images)} images")
# 创建json字典
json_dict = OrderedDict()
json_dict['name'] = "" # 输入数据集名称
json_dict['tensorImageSize'] = "3D"
json_dict['release'] = "0.0"
json_dict['modality'] = {
"0": "MR" # 你可以根据实际情况修改
}
json_dict['labels'] = {
"0": "Background",
"1": "1",
"2": "2", # 标签根据实际情况修改
}
json_dict['numTraining'] = len(train_images)
# 假设使用 5-fold 交叉验证,这里添加一个随机划分 `fold`
random.seed(42) # 保证结果的可重复性
folds = [i % 5 for i in range(len(train_images))] # 按照模5划分折叠
# 生成训练集数据
json_dict['training'] = []
for i, img_file in enumerate(train_images):
fold = folds[i]
label_file = img_file.replace('.nii.gz', '_seg.nii.gz')
json_dict['training'].append({
'image': f"{datapath}/{img_file}",
'label': f"{datapath}/{label_file}",
'fold': fold # 添加fold信息
})
# 生成验证集数据
json_dict['validation'] = []
for i, img_file in enumerate(validation_images):
fold = 0 # 验证集通常属于fold 0
label_file = img_file.replace('.nii.gz', '_seg.nii.gz')
json_dict['validation'].append({
'image': f"{datapath}/{img_file}",
'label': f"{datapath}/{label_file}",
'fold': fold # 添加fold信息
})
# 生成测试集数据
json_dict['testing'] = []
for i, img_file in enumerate(test_images):
fold = 0 # 测试集通常属于fold 0
label_file = img_file.replace('.nii.gz', '_seg.nii.gz')
json_dict['testing'].append({
'image': f"{datapath}/{img_file}",
'label': f"{datapath}/{label_file}",
'fold': fold # 添加fold信息
})
# 写入json文件
json_file_path = os.path.join("dataset.json")
with open(json_file_path, 'w') as f:
json.dump(json_dict, f, indent=4, sort_keys=True)
print(f"JSON file saved to {json_file_path}")
然后修改 main.py 文件中预训练模型的名称、路径:


四、下载预训练模型
点击这个model就可以下载了,自己选一个model吧。

注意修改预训练模型的路径:

五、训练开始
现在一切准备就绪,开始训练!
python main.py --use_checkpoint


