
Medical Image Classification | Alzheimer's Disease Recognition (2D + 3D Models) (Dataset: 3D MRI Scans)


This project classifies 3D MRI head scans into three categories: healthy samples, mild cognitive impairment (MCI) samples, and Alzheimer's disease (AD) samples. It covers dataset preparation, model construction, and training.
Dataset
The dataset consists of 3D MRI head scans, roughly 900 scans in total across the three classes. Each MRI sequence is made up of many slices (height × width × number of slices), so a single MRI sequence is a three-dimensional tensor.
Model Architecture

  • The 2D model is built on the ResNet50 architecture; the 79×95×79 volume is fed in as a 79×95 image whose 79 depth slices serve as channels.
  • The 3D model is based on a LeNet-style 3D network (LeNet3D); its input is a (1, 79, 95, 79) tensor.

Training and Optimization

  • The 2D model is trained with the Adam optimizer and the 3D model with SGD plus momentum; both use a cross-entropy loss.
  • The goal is to separate the three classes on an independent test set with as high an accuracy as possible.
  • The 2D model reaches about 86% accuracy on the training set; the 3D model reaches about 84% accuracy on the validation set.

Results and Analysis

  • Predictions for the three classes are generated on the two independent test sets.
  • Train loss and validation loss curves are plotted, and the prediction plots and CSV files are saved for further analysis.

Alzheimer's Disease Classification

    • Project Overview
    • Training Set Format
    • Visualizing the Dataset and Saving as GIF
  • 2D Model
  • 3D Model
    • Test Set Format

Project Overview

The dataset of 3D MRI head scans is divided into three classes: healthy samples, mild cognitive impairment samples, and Alzheimer's disease samples. The models are trained on these scans to maximize classification accuracy on an independent test set; each sample is a full three-dimensional volume.

MRI data: each MRI sequence is a 3D image made up of a large number of slices, with dimensions of height, width, and number of slices. A single MRI sequence therefore carries three dimensions of information (height, width, and depth) and can be treated as a three-dimensional tensor.
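
To make the layout concrete, here is a minimal sketch for inspecting the training file (assuming the train/train_pre_data.h5 layout used throughout this post, i.e. a data dataset of shape (N, 1, 79, 95, 79)):

    import h5py

    with h5py.File('train/train_pre_data.h5', 'r') as f:
        data = f['data']
        print(data.shape)     # e.g. (300, 1, 79, 95, 79): samples, channel, slices, height, width
        volume = data[0, 0]   # a single MRI sequence as a 3D tensor
        print(volume.shape)   # (79, 95, 79)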

Training Set Format

Visualizing the Dataset and Saving as GIF

    import h5py as h5
    from PIL import Image
    import imageio

    train = h5.File('train/train_pre_data.h5', 'r')   # open the training data
    one_sample = train['data'][0, 0]                  # one MRI sequence: shape (79, 95, 79)

    frames = []
    for layer_img in one_sample:
        img = Image.fromarray(layer_img).convert('L')             # convert the slice to a grayscale image
        img.resize((79*5, 95*5), Image.LANCZOS).save('temp.jpg')  # upscale 5x and save as temp.jpg
        frames.append(imageio.imread('temp.jpg'))                 # append the frame
    imageio.mimsave('{0}.gif'.format('idx'), frames, 'GIF', duration=0.1)  # save all frames as a GIF
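
A side note: the temp.jpg round trip through the file system is not strictly necessary. A variant of the loop (reusing one_sample, Image, and imageio from the block above) can hand the resized slices to imageio directly as arrays:

    import numpy as np

    frames = []
    for layer_img in one_sample:
        img = Image.fromarray(layer_img).convert('L')
        frames.append(np.array(img.resize((79*5, 95*5), Image.LANCZOS)))  # keep the frame in memory
    imageio.mimsave('idx.gif', frames, 'GIF', duration=0.1)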


2D Model

    import os
    import h5py
    import numpy as np
    import pandas as pd
    from sklearn.model_selection import train_test_split
    import tensorflow.keras as keras
    from tensorflow.keras.applications import resnet
    from tensorflow.keras.layers import Dense, GlobalAvgPool2D
    from tensorflow.keras.utils import to_categorical
    import matplotlib.pyplot as plt

    os.environ["CUDA_VISIBLE_DEVICES"] = "1"
    train_dir = 'train'
    train_data = 'train_pre_data.h5'
    train_label = 'train_pre_label.csv'
    train_epochs = 100

    # load the training data
    train = h5py.File(os.path.join(train_dir, train_data), 'r')
    # load the labels
    labels = pd.read_csv(os.path.join(train_dir, train_label))

    # preprocess the data and split it into training and test sets
    features = np.array(train['data'])
    features = features.reshape(300, 79, 95, 79)   # drop the channel axis: (300, 1, 79, 95, 79) -> (300, 79, 95, 79)
    X_train, X_test, y_train, y_test = train_test_split(features, labels['label'].values, test_size=0.3, random_state=42)
    # one-hot encode the three classes
    y_train = to_categorical(y_train, num_classes=3)
    y_test = to_categorical(y_test, num_classes=3)

    # network: the ResNet50 architecture, trained from scratch (weights=None;
    # the ImageNet weights expect 3-channel input and cannot be loaded for 79 channels)
    num_classes = 3
    inputdim = (79, 95, 79)   # 79x95 images with the 79 depth slices as channels
    base_model = resnet.ResNet50(include_top=False, weights=None, input_shape=inputdim)

    x = base_model.output
    # GlobalAvgPool2D averages each channel of the feature map down to a single value
    x = GlobalAvgPool2D()(x)

    # three fully connected layers
    x = Dense(64, activation='relu')(x)
    x = Dense(32, activation='relu')(x)
    x = Dense(num_classes, activation='softmax')(x)

    model = keras.Model(inputs=base_model.input, outputs=x)
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    model.summary()

    print('Training---------')
    history = model.fit(X_train, y_train, epochs=train_epochs, batch_size=32)

    # plot the training loss
    plt.figure()
    plt.plot(history.epoch, history.history['loss'], label='loss')
    plt.legend()
    plt.savefig("train_loss.png")
    plt.close()
    # plot the training accuracy
    plt.figure()
    plt.plot(history.epoch, history.history['accuracy'], label='Accuracy')
    plt.legend()
    plt.savefig("Accuracy.png")
    plt.close()

    print('\nTesting---------')
    loss, accuracy = model.evaluate(X_test, y_test)
    print('\ntest loss', loss)
    print('\ntest accuracy', accuracy)
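
A design note on this model: the 79 depth slices of each volume are treated as the channels of a single 79×95 "image", which is also why the 3-channel ImageNet weights cannot be loaded here. Once trained, predictions for new volumes follow the usual Keras pattern; a sketch reusing X_test from above (the saved filename is hypothetical):

    import numpy as np

    probs = model.predict(X_test)             # (N, 3) softmax probabilities
    pred_classes = np.argmax(probs, axis=1)   # class indices 0/1/2, as in train_pre_label.csv
    model.save('resnet50_2d.h5')              # persist the trained model for later inference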

3D Model

models.py

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class LeNet3D(nn.Module):
        def __init__(self, num_classes=3):
            super(LeNet3D, self).__init__()
            self.conv1 = nn.Conv3d(1, 16, kernel_size=3, stride=2, padding=1)
            self.pool1 = nn.MaxPool3d(2, 2)
            self.conv2 = nn.Conv3d(16, 32, kernel_size=3, stride=2, padding=1)
            self.pool2 = nn.MaxPool3d(2, 2)
            self.fc1 = nn.Linear(4800, 120)   # 32 * 5 * 6 * 5 = 4800 for input (1, 79, 95, 79)
            self.fc2 = nn.Linear(120, 84)
            self.fc3 = nn.Linear(84, num_classes)
            # no softmax layer: F.cross_entropy in train.py expects raw logits

        def forward(self, x):
            # input: (B, 1, 79, 95, 79)
            x = F.relu(self.conv1(x))    # (B, 16, 40, 48, 40)
            x = self.pool1(x)            # (B, 16, 20, 24, 20)
            x = F.relu(self.conv2(x))    # (B, 32, 10, 12, 10)
            x = self.pool2(x)            # (B, 32, 5, 6, 5)
            x = x.view(x.size(0), -1)    # (B, 4800)
            x = F.relu(self.fc1(x))      # (B, 120)
            x = F.relu(self.fc2(x))      # (B, 84)
            x = self.fc3(x)              # (B, num_classes)
            return x

    def main_3d():
        model = LeNet3D(num_classes=3)
        model = nn.DataParallel(model, device_ids=None)
        print(model)
        # the input must match the fc1 size: (batch, channel, slices, height, width)
        input_var = torch.randn(16, 1, 79, 95, 79)
        output = model(input_var)
        print(output.shape)  # torch.Size([16, 3])
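
The hard-coded fc1 input size of 4800 is just the flattened shape after conv2/pool2 for a (1, 79, 95, 79) input: 32 × 5 × 6 × 5 = 4800. If the input size ever changes, one way to avoid recomputing this by hand (a sketch, not part of the original code) is a dummy forward pass through the convolutional part:

    import torch
    import torch.nn as nn

    # rebuild just the convolutional stack and measure its flattened output size
    conv = nn.Sequential(
        nn.Conv3d(1, 16, kernel_size=3, stride=2, padding=1), nn.ReLU(), nn.MaxPool3d(2, 2),
        nn.Conv3d(16, 32, kernel_size=3, stride=2, padding=1), nn.ReLU(), nn.MaxPool3d(2, 2),
    )
    with torch.no_grad():
        n_flat = conv(torch.zeros(1, 1, 79, 95, 79)).numel()
    print(n_flat)  # 4800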

train.py

    import pandas as pd
    import torch
    from torch.utils import data as torch_data
    import torch.nn.functional as F
    from tensorboardX import SummaryWriter
    from models import LeNet3D
    import h5py
    import os
    import numpy as np
    import matplotlib.pyplot as plt

    os.environ["CUDA_VISIBLE_DEVICES"] = "2"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def validation(valid_loader, path_ckpt):
        # reload the checkpointed weights into a fresh model for evaluation
        model = LeNet3D()
        model_ckpt = torch.load(path_ckpt)
        model.load_state_dict(model_ckpt['model_state_dict'])
        model.eval()
        model.to(device)

        loss_sum = 0
        acc_sum = 0

        with torch.no_grad():
            for data, label in valid_loader:
                img = data.to(device)
                targets = label.to(device)
                outputs = model(img)

                # targets are one-hot, so torch.max(..., 1)[1] recovers the class index
                loss = F.cross_entropy(outputs, torch.max(targets, 1)[1])
                loss_sum += loss.item()

                pred_y = torch.max(outputs, 1)[1].cpu().numpy()
                target_y = torch.max(targets, 1)[1].cpu().numpy()
                acc_sum += sum((pred_y - target_y) == 0)

        loss_avg = loss_sum / len(valid_loader)
        return loss_avg, acc_sum


    class DataFromH5CSVFile(torch_data.Dataset):
        def __init__(self, data, label):
            self.hr = label
            self.lr = data

        def __getitem__(self, idx):
            # one-hot encode the three classes
            if self.hr[idx] == 0:
                label = torch.from_numpy(np.array([1, 0, 0])).float()
            elif self.hr[idx] == 1:
                label = torch.from_numpy(np.array([0, 1, 0])).float()
            else:
                label = torch.from_numpy(np.array([0, 0, 1])).float()
            data = torch.from_numpy(self.lr[idx]).float()
            return data, label

        def __len__(self):
            assert self.hr.shape[0] == self.lr.shape[0], "Wrong data length"
            return self.hr.shape[0]

    def train():
        MAX_EPOCH = 100
        ITR_PER_CKPT_VAL = 1
        train_loss = []
        val_acc = []
        val_loss = []
        h5File = h5py.File("train/train_pre_data.h5", 'r')
        labels = pd.read_csv("train/train_pre_label.csv")
        # disjoint split: the first 250 samples for training, the remainder for validation
        train_data = DataFromH5CSVFile(np.array(h5File['data'][:250]), np.array(labels['label'].values[:250]))
        print("train_data:", len(train_data))
        valid_data = DataFromH5CSVFile(np.array(h5File['data'][250:]), np.array(labels['label'].values[250:]))
        print("valid_data:", len(valid_data))
        train_loader = torch_data.DataLoader(train_data, batch_size=64,
                                             shuffle=True, num_workers=4, pin_memory=False)
        valid_loader = torch_data.DataLoader(valid_data, batch_size=1,
                                             shuffle=False, num_workers=4, pin_memory=False)

        model = LeNet3D()
        model.train()
        model.to(device)
        print(model)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)

        best_valid_score = 0
        writer = SummaryWriter(comment='Linear')
        os.makedirs("checkpoints", exist_ok=True)

        for i_epoch in range(1, MAX_EPOCH + 1):
            loss_sum = 0
            for data, label in train_loader:
                img = data.to(device)
                targets = label.to(device)
                outputs = model(img)
                loss = F.cross_entropy(outputs, torch.max(targets, 1)[1])
                loss_sum += loss.detach().item()
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            loss_avg = loss_sum / len(train_loader)
            print("[Epoch " + str(i_epoch) + " | " + "train loss = " + ("%.7f" % loss_avg) + "]")
            writer.add_scalar('scalar/train_loss', loss_avg, i_epoch)
            train_loss.append(loss_avg)

            if i_epoch % ITR_PER_CKPT_VAL == 0:
                # save a checkpoint, then validate it
                path_ckpt = "checkpoints/" + str(i_epoch) + ".pth.tar"
                torch.save({"epoch": i_epoch, "model_state_dict": model.state_dict(),
                            "optimizer_state_dict": optimizer.state_dict()}, path_ckpt)

                loss_val, acc_sum = validation(valid_loader, path_ckpt)
                accuracy = acc_sum * 100 / len(valid_loader)

                print("[Epoch " + str(i_epoch) + " | " + "val loss = " + ("%.7f" % loss_val) + "  accuracy = " + ("%.3f" % accuracy) + "%]")
                writer.add_scalar('scalar/val_loss', loss_val, i_epoch)
                writer.add_scalar('scalar/val_acc', accuracy, i_epoch)
                val_acc.append(accuracy)
                val_loss.append(loss_val)

                # keep the best checkpoint by validation accuracy
                if best_valid_score < accuracy:
                    path_ckpt_best = "checkpoints/best_acc.pth.tar"
                    torch.save({"epoch": i_epoch, "model_state_dict": model.state_dict(),
                                "optimizer_state_dict": optimizer.state_dict()}, path_ckpt_best)
                    best_valid_score = accuracy
        writer.close()

        # plot train and validation loss
        plt.figure()
        plt.plot(range(1, MAX_EPOCH + 1), train_loss, label='train_loss')
        plt.plot(range(1, MAX_EPOCH + 1), val_loss, label='val_loss')
        plt.legend()
        plt.savefig("train_loss.png")
        plt.close()
        # plot validation accuracy
        plt.figure()
        plt.plot(range(1, MAX_EPOCH + 1), val_acc, label='val_accuracy')
        plt.legend()
        plt.savefig("val_accuracy.png")
        plt.close()
        print("best_valid_score:", best_valid_score)

    if __name__ == '__main__':
        train()
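
Overall accuracy can hide how the errors distribute over the three classes. For a per-class view of the validation predictions, a small helper like the following sketch works (the target/prediction lists here are hypothetical; in practice they would be collected inside the validation loop above):

    import numpy as np

    def confusion_matrix_3(targets, preds):
        """3x3 confusion matrix: rows = true class, columns = predicted class."""
        cm = np.zeros((3, 3), dtype=int)
        for t, p in zip(targets, preds):
            cm[t, p] += 1
        return cm

    # hypothetical example values
    targets = [0, 1, 2, 2, 0]
    preds   = [0, 1, 2, 1, 0]
    print(confusion_matrix_3(targets, preds))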

Test Set Format

test.py

    import torch
    from torch.utils import data as torch_data
    from models import LeNet3D
    import h5py
    import os
    import numpy as np
    import matplotlib.pyplot as plt
    import csv

    os.environ["CUDA_VISIBLE_DEVICES"] = "2"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def validation(test_loader, path_ckpt):
        # run the checkpointed model over an unlabeled loader and collect class predictions
        model = LeNet3D()
        model_ckpt = torch.load(path_ckpt)
        model.load_state_dict(model_ckpt['model_state_dict'])
        model.eval()
        model.to(device)
        pred_list = []
        with torch.no_grad():
            for data in test_loader:
                img = data.to(device)
                outputs = model(img)
                prediction = torch.max(outputs, 1)[1]
                pred_list.append(prediction.item())  # batch_size is 1: one class index per sample
        return pred_list


    class DataFromH5File(torch_data.Dataset):
        def __init__(self, data):
            self.lr = data

        def __getitem__(self, idx):
            return torch.from_numpy(self.lr[idx]).float()

        def __len__(self):
            return self.lr.shape[0]

    def test():
        path_ckpt = "checkpoints/best_acc.pth.tar"
        h5Filea = h5py.File("test/testa.h5", 'r')
        h5Fileb = h5py.File("test/testb.h5", 'r')
        test_data_a = DataFromH5File(np.array(h5Filea['data']))
        test_loader_a = torch_data.DataLoader(test_data_a, batch_size=1,
                                              shuffle=False, num_workers=4, pin_memory=False)
        test_data_b = DataFromH5File(np.array(h5Fileb['data']))
        test_loader_b = torch_data.DataLoader(test_data_b, batch_size=1,
                                              shuffle=False, num_workers=4, pin_memory=False)

        pred_a = validation(test_loader_a, path_ckpt)
        pred_b = validation(test_loader_b, path_ckpt)

        # plot the predictions
        print("Plotting results for testa.h5")
        plt.figure()
        plt.title("Forecast result testa.h5")
        plt.scatter(np.array(range(1, len(pred_a) + 1)), np.array(pred_a))
        plt.xlabel("number")
        plt.ylabel("category")
        plt.savefig("Forecast result testa.h5.png")
        plt.close()

        print("Plotting results for testb.h5")
        plt.figure()
        plt.title("Forecast result testb.h5")
        plt.scatter(np.array(range(1, len(pred_b) + 1)), np.array(pred_b))
        plt.xlabel("number")
        plt.ylabel("category")
        plt.savefig("Forecast result testb.h5.png")
        plt.close()

        # save the predictions to CSV
        def writeCsv(path, row):
            with open(path, "a", newline="") as out:
                csv.writer(out, dialect="excel").writerow(row)

        print("Saving predictions for testa.h5")
        writeCsv("Forecast result testa.h5.csv", ["number", "category"])
        for nu, ca in enumerate(pred_a, start=1):
            writeCsv("Forecast result testa.h5.csv", [nu, ca])
        print("Saving predictions for testb.h5")
        writeCsv("Forecast result testb.h5.csv", ["number", "category"])
        for nu, ca in enumerate(pred_b, start=1):
            writeCsv("Forecast result testb.h5.csv", [nu, ca])

    if __name__ == '__main__':
        print("Starting inference")
        test()
        print("Inference finished")


Experimental Results

[Figures: training/validation loss curves and validation accuracy curve]

Readers who would like the dataset can download it from these two links:
