Medical Image Classification | Alzheimer's Disease Classification with 2D and 3D Models (3D MRI Scan Dataset)
This project classifies 3D head MRI brain scans into three categories: healthy samples, mild cognitive impairment (MCI) samples, and Alzheimer's disease (AD) samples. It covers dataset preparation, model construction, and training.
Dataset
The dataset consists of 3D head MRI scans, roughly 900 images in total across the three classes. Each MRI sequence is made up of many slices (length × width × number of slices), so a single MRI sequence is a three-dimensional tensor.
Model Architecture
- The 2D model is built on the ResNet50 architecture and takes a 79x95x79 tensor as input: the 79 depth slices are treated as image channels so a standard 2D CNN can consume the volume. (ImageNet weights cannot be loaded for a 79-channel input, so the code below initializes the backbone randomly with weights=None.)
- The 3D model uses the LeNet3D architecture, with a (1, 79, 95, 79) tensor as input, keeping depth as a spatial axis.
Training & Optimization
- Trained with cross-entropy loss; the 2D model uses the Adam optimizer and the 3D model uses SGD with momentum.
- The goal is to distinguish the three classes on a held-out test set with as high an accuracy as possible.
- The 2D model reaches about 86% accuracy on the training set; the 3D model reaches about 84% accuracy on the validation set.
Results & Analysis
- The trained models distinguish the three classes on the held-out test set with high accuracy.
- Train-loss and validation-loss curves are plotted, and prediction plots plus CSV files are saved for further analysis.
Alzheimer's Disease Classification
- Project Introduction
- Training Set Format
- Visualizing the Dataset (saved as GIF)
- 2D Model
- 3D Model
- Test Set Format
Project Introduction
The dataset of human head 3D MRI scans is divided into three classes: healthy samples, mild cognitive impairment samples, and Alzheimer's disease samples. The models are trained on these images to classify a held-out test set as accurately as possible; each sample is a three-dimensional representation of spatial information.
MRI data: each MRI sequence is a 3D image made up of a large number of slices, with dimensions length × width × number of slices. A single MRI sequence therefore carries three dimensions of information (length, width, and depth) and can be treated as a three-dimensional tensor.
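A quick sanity check of this layout (a sketch; the file path and the 'data' key come from the training code below, and the shape is inferred from the reshape to (300, 79, 95, 79) used there):

import h5py

# Inspect the HDF5 layout before training (expected shapes inferred from the code below)
with h5py.File('train/train_pre_data.h5', 'r') as f:
    print(list(f.keys()))    # expected: ['data']
    print(f['data'].shape)   # expected: (300, 1, 79, 95, 79) = (N, C, D, H, W)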

Training Set Format

Visualizing the Dataset (saved as GIF)
import h5py as h5
from PIL import Image
import imageio

train = h5.File('train/train_pre_data.h5', 'r')  # open the training data
one_sample = train['data'][0, 0]                 # first sample, shape (79, 95, 79)
frames = []
for layer_img in one_sample:
    # Convert the slice to a PIL image, then to grayscale
    img = Image.fromarray(layer_img).convert('L')
    # Upscale 5x and save as temp.jpg (Image.LANCZOS replaces the removed Image.ANTIALIAS)
    img.resize((79 * 5, 95 * 5), Image.LANCZOS).save('temp.jpg')
    frames.append(imageio.imread('temp.jpg'))    # collect the frame
imageio.mimsave('{0}.gif'.format('idx'), frames, 'GIF', duration=0.1)  # write the GIF
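A small variation (a sketch, not the original code) that skips the temporary JPEG round-trip by collecting the resized frames as arrays directly:

import numpy as np

# Resize each grayscale slice and keep it in memory instead of via temp.jpg
frames = [np.asarray(Image.fromarray(layer_img).convert('L')
                     .resize((79 * 5, 95 * 5), Image.LANCZOS))
          for layer_img in one_sample]
imageio.mimsave('idx.gif', frames, duration=0.1)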
2D Model
import os
import h5py
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import tensorflow.keras as keras
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.applications.resnet import ResNet50
from tensorflow.keras.layers import Dense, GlobalAvgPool2D

os.environ["CUDA_VISIBLE_DEVICES"] = "1"
train_dir = 'train'
train_data = 'train_pre_data.h5'
train_label = 'train_pre_label.csv'
train_epochs = 100

# Load the training data
train = h5py.File(os.path.join(train_dir, train_data), 'r')
# Load the labels
labels = pd.read_csv(os.path.join(train_dir, train_label))

# Preprocess: drop the singleton channel axis so the 79 depth slices act as
# image channels, then split into training and test sets
features = np.array(train['data'])
features = features.reshape(300, 79, 95, 79)
X_train, X_test, y_train, y_test = train_test_split(
    features, labels['label'].values, test_size=0.3, random_state=42)

# One-hot encode the three classes
y_train = to_categorical(y_train, num_classes=3)
y_test = to_categorical(y_test, num_classes=3)

# Network: the ResNet50 architecture. ImageNet weights cannot be loaded for a
# 79-channel input, so the backbone is randomly initialized (weights=None).
num_classes = 3
inputdim = (79, 95, 79)
base_model = ResNet50(include_top=False, weights=None, input_shape=inputdim)
x = base_model.output
# GlobalAvgPool2D averages each channel of the feature map down to one value
x = GlobalAvgPool2D()(x)
# Three fully connected layers
x = Dense(64, activation='relu')(x)
x = Dense(32, activation='relu')(x)
x = Dense(num_classes, activation='softmax')(x)
model = keras.Model(inputs=base_model.input, outputs=x)
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.summary()

print('Training---------')
# # Optional checkpointing (note: `period` is `save_freq` in newer Keras)
# checkpointer = keras.callbacks.ModelCheckpoint(
#     os.path.join("", 'model_{epoch:03d}.hdf5'),
#     verbose=1, save_weights_only=False, period=train_epochs)
history = model.fit(X_train, y_train, epochs=train_epochs, batch_size=32)

# Plot the training loss
plt.figure()
plt.plot(history.epoch, history.history['loss'], label='loss')
plt.legend()
plt.savefig("train_loss.png")
plt.close()

# Plot the training accuracy
plt.figure()
plt.plot(history.epoch, history.history['accuracy'], label='Accuracy')
plt.legend()
plt.savefig("Accuracy.png")
plt.close()

print('\nTesting---------')
loss, accuracy = model.evaluate(X_test, y_test)
print('\ntest loss', loss)
print('\ntest accuracy', accuracy)
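The fit call above trains on X_train only and evaluates the held-out split once at the end. To get the per-epoch validation curves mentioned in the summary, a minimal variant (a sketch, not the original run) passes the split as validation_data:

# Sketch: track held-out loss each epoch by passing validation_data to fit
history = model.fit(X_train, y_train, epochs=train_epochs, batch_size=32,
                    validation_data=(X_test, y_test))
plt.plot(history.epoch, history.history['val_loss'], label='val loss')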
3D Model
import torch
import torch.nn as nn
import torch.nn.functional as F

class LeNet3D(nn.Module):
    def __init__(self, num_classes=3):
        super(LeNet3D, self).__init__()
        self.conv1 = nn.Conv3d(1, 16, kernel_size=3, stride=2, padding=1)
        self.pool1 = nn.MaxPool3d(2, 2)
        self.conv2 = nn.Conv3d(16, 32, kernel_size=3, stride=2, padding=1)
        self.pool2 = nn.MaxPool3d(2, 2)
        # 4800 = 32 channels * 5 * 6 * 5, the feature-map size after the
        # convolutions and poolings below for a (1, 79, 95, 79) input
        self.fc1 = nn.Linear(4800, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        # input: torch.Size([16, 1, 79, 95, 79])
        x = F.relu(self.conv1(x))   # torch.Size([16, 16, 40, 48, 40])
        x = self.pool1(x)           # torch.Size([16, 16, 20, 24, 20])
        x = F.relu(self.conv2(x))   # torch.Size([16, 32, 10, 12, 10])
        x = self.pool2(x)           # torch.Size([16, 32, 5, 6, 5])
        x = x.view(x.size(0), -1)   # torch.Size([16, 4800])
        x = F.relu(self.fc1(x))     # torch.Size([16, 120])
        x = F.relu(self.fc2(x))     # torch.Size([16, 84])
        # raw logits (no softmax; F.cross_entropy applies log-softmax internally)
        x = self.fc3(x)             # torch.Size([16, 3])
        return x

def main_3d():
    model = LeNet3D(num_classes=3)
    model = nn.DataParallel(model, device_ids=None)
    print(model)
    # The input must match fc1's expected 4800 flattened features: (b, c, z, h, w)
    input_var = torch.randn(16, 1, 79, 95, 79)
    output = model(input_var)
    print(output.shape)
import os
import h5py
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch.utils import data as torch_data
from tensorboardX import SummaryWriter
from models import LeNet3D

os.environ["CUDA_VISIBLE_DEVICES"] = "2"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def validation(valid_loader, path_ckpt):
    # Load the checkpoint into a fresh model and evaluate on the validation set
    model = LeNet3D()
    model_ckpt = torch.load(path_ckpt, map_location=device)
    model.load_state_dict(model_ckpt['model_state_dict'])
    model.eval()
    model.to(device)
    loss_sum = 0
    acc_sum = 0
    with torch.no_grad():
        for step, (data, label) in enumerate(valid_loader):
            img = data.to(device)
            targets = label.to(device)
            outputs = model(img).squeeze(1)
            loss = F.cross_entropy(outputs, torch.max(targets, 1)[1])
            loss_sum += loss.detach().item()
            prediction = torch.max(outputs, 1)[1]
            pred_y = prediction.data.cpu().numpy()
            target = torch.max(targets, 1)[1]
            target_y = target.data.cpu().numpy()
            acc_sum += sum((pred_y - target_y) == 0)
    loss_avg = loss_sum / len(valid_loader)
    return loss_avg, acc_sum

class DataFromH5CSVFile(torch_data.Dataset):
    def __init__(self, data, label):
        self.hr = label
        self.lr = data

    def __getitem__(self, idx):
        # One-hot encode the three classes
        if self.hr[idx] == 0:
            label = torch.from_numpy(np.array([1, 0, 0])).float()
        elif self.hr[idx] == 1:
            label = torch.from_numpy(np.array([0, 1, 0])).float()
        else:
            label = torch.from_numpy(np.array([0, 0, 1])).float()
        data = torch.from_numpy(self.lr[idx]).float()
        return data, label

    def __len__(self):
        assert self.hr.shape[0] == self.lr.shape[0], "Wrong data length"
        return self.hr.shape[0]

def train():
    MAX_EPOCH = 100
    ITR_PER_CKPT_VAL = 1
    train_loss = []
    val_acc = []
    val_loss = []
    h5File = h5py.File("train/train_pre_data.h5", 'r')
    labels = pd.read_csv(os.path.join("train", "train_pre_label.csv"))
    # Keep the split disjoint: first 250 samples for training, last 50 for
    # validation (the original [200:] slice overlapped the training slice)
    train_data = DataFromH5CSVFile(np.array(h5File['data'][:250]),
                                   np.array(labels['label'].values[:250]))
    print("train_data:", len(train_data))
    valid_data = DataFromH5CSVFile(np.array(h5File['data'][250:]),
                                   np.array(labels['label'].values[250:]))
    print("valid_data:", len(valid_data))
    train_loader = torch_data.DataLoader(train_data, batch_size=64,
                                         shuffle=True, num_workers=4, pin_memory=False)
    valid_loader = torch_data.DataLoader(valid_data, batch_size=1,
                                         shuffle=False, num_workers=4, pin_memory=False)
    model = LeNet3D()
    model.train()
    model.to(device)
    print(model)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)
    best_valid_score = 0
    writer = SummaryWriter(comment='Linear')
    os.makedirs("checkpoints", exist_ok=True)
    for i_epoch in range(1, MAX_EPOCH + 1):
        loss_sum = 0
        for step, (data, label) in enumerate(train_loader):
            img = data.to(device)
            targets = label.to(device)
            outputs = model(img).squeeze(1)
            loss = F.cross_entropy(outputs, torch.max(targets, 1)[1])
            loss_sum += loss.detach().item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        loss_avg = loss_sum / len(train_loader)
        print("[Epoch " + str(i_epoch) + " | " + "train loss = " + ("%.7f" % loss_avg) + "]")
        writer.add_scalar('scalar/train_loss', loss_avg, i_epoch)
        train_loss.append(loss_avg)
        if i_epoch % ITR_PER_CKPT_VAL == 0:
            # Save a checkpoint, then validate it
            path_ckpt = "checkpoints/" + str(i_epoch) + ".pth.tar"
            torch.save({"epoch": i_epoch, "model_state_dict": model.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict()}, path_ckpt)
            loss_val, acc_sum = validation(valid_loader, path_ckpt)
            accuracy = acc_sum * 100 / len(valid_loader)
            print("[Epoch " + str(i_epoch) + " | " + "val loss = " + ("%.7f" % loss_val)
                  + " accuracy = " + ("%.3f" % accuracy) + "%]")
            writer.add_scalar('scalar/val_loss', loss_val, i_epoch)
            writer.add_scalar('scalar/val_acc', accuracy, i_epoch)
            val_acc.append(accuracy)
            val_loss.append(loss_val)
            if best_valid_score < accuracy:
                # Keep the best-accuracy checkpoint
                path_ckpt_best = "checkpoints/best_acc.pth.tar"
                torch.save({"epoch": i_epoch, "model_state_dict": model.state_dict(),
                            "optimizer_state_dict": optimizer.state_dict()}, path_ckpt_best)
                best_valid_score = accuracy
    writer.close()

    # Plot train/val loss
    plt.figure()
    plt.plot(range(1, MAX_EPOCH + 1), train_loss, label='train_loss')
    plt.plot(range(1, MAX_EPOCH + 1), val_loss, label='val_loss')
    plt.legend()
    plt.savefig("train_loss.png")
    plt.close()

    # Plot validation accuracy
    plt.figure()
    plt.plot(range(1, MAX_EPOCH + 1), val_acc, label='val_accuracy')
    plt.legend()
    plt.savefig("val_accuracy.png")
    plt.close()
    print("best_valid_score:", best_valid_score)

if __name__ == '__main__':
    train()
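A design note: F.cross_entropy accepts integer class indices directly, so the one-hot labels plus the torch.max(targets, 1)[1] round-trip above could be dropped. A minimal sketch of the alternative (a hypothetical simplification, not the original code):

import torch
import numpy as np
from torch.utils import data as torch_data

class DataFromH5CSVFileInt(torch_data.Dataset):  # hypothetical variant
    def __init__(self, data, label):
        self.lr, self.hr = data, label
    def __getitem__(self, idx):
        data = torch.from_numpy(self.lr[idx]).float()
        label = torch.tensor(int(self.hr[idx]), dtype=torch.long)  # class id, not one-hot
        return data, label
    def __len__(self):
        return self.hr.shape[0]

# ...then the loss in the training loop becomes simply:
# loss = F.cross_entropy(outputs, targets)   # targets: LongTensor of class ids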
Test Set Format

import os
import csv
import h5py
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils import data as torch_data
from models import LeNet3D

os.environ["CUDA_VISIBLE_DEVICES"] = "2"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def predict(test_loader, path_ckpt):
    # Inference only: load the best checkpoint and collect class predictions
    model = LeNet3D()
    model_ckpt = torch.load(path_ckpt, map_location=device)
    model.load_state_dict(model_ckpt['model_state_dict'])
    model.eval()
    model.to(device)
    pred_list = []
    with torch.no_grad():
        for step, data in enumerate(test_loader):
            img = data.to(device)
            outputs = model(img).squeeze(1)
            prediction = torch.max(outputs, 1)[1]
            pred_list.append(prediction.cpu().item())  # batch_size=1: one int per step
    return pred_list

class DataFromH5File(torch_data.Dataset):
    def __init__(self, data):
        self.lr = data

    def __getitem__(self, idx):
        data = torch.from_numpy(self.lr[idx]).float()
        return data

    def __len__(self):
        return self.lr.shape[0]

def test():
    path_ckpt = "checkpoints/best_acc.pth.tar"
    h5Filea = h5py.File("test/testa.h5", 'r')
    h5Fileb = h5py.File("test/testb.h5", 'r')
    test_data_a = DataFromH5File(np.array(h5Filea['data']))
    test_loader_a = torch_data.DataLoader(test_data_a, batch_size=1,
                                          shuffle=False, num_workers=4, pin_memory=False)
    test_data_b = DataFromH5File(np.array(h5Fileb['data']))
    test_loader_b = torch_data.DataLoader(test_data_b, batch_size=1,
                                          shuffle=False, num_workers=4, pin_memory=False)
    pred_a = predict(test_loader_a, path_ckpt)
    pred_b = predict(test_loader_b, path_ckpt)

    # Plot the predictions
    print("Plotting results for testa.h5")
    plt.figure()
    plt.title("Forecast result testa.h5")
    plt.scatter(np.array(range(1, len(pred_a) + 1)), np.array(pred_a))
    plt.xlabel("number")
    plt.ylabel("category")
    plt.savefig("Forecast result testa.h5.png")
    plt.close()
    print("Plotting results for testb.h5")
    plt.figure()
    plt.title("Forecast result testb.h5")
    plt.scatter(np.array(range(1, len(pred_b) + 1)), np.array(pred_b))
    plt.xlabel("number")
    plt.ylabel("category")
    plt.savefig("Forecast result testb.h5.png")
    plt.close()

    # Save the predictions to CSV
    def writeCsva(File, species):
        with open("Forecast result testa.h5.csv", "a", newline="") as out:
            csv.writer(out, dialect="excel").writerow([File, species])

    def writeCsvb(File, species):
        with open("Forecast result testb.h5.csv", "a", newline="") as out:
            csv.writer(out, dialect="excel").writerow([File, species])

    print("Saving testa.h5 predictions to CSV")
    writeCsva("number", "category")
    for nu in range(1, len(pred_a) + 1):
        writeCsva(nu, pred_a[nu - 1])
    print("Saving testb.h5 predictions to CSV")
    writeCsvb("number", "category")
    for nu in range(1, len(pred_b) + 1):
        writeCsvb(nu, pred_b[nu - 1])

if __name__ == '__main__':
    print("Starting inference")
    test()
    print("Done")
Results


Readers who would like the dataset can download it from these two links: