Advertisement

医学图像九分类

阅读量:

解决上一次九分类存在的问题
1.将所有图片名和图片的类别存在一个csv文件中,name对应图片的名字,label对应图片的标签。

复制代码
    def generate_csv(path, type, csv_path):
        """Scan path/type (one sub-folder per class) and write a CSV index.

        Header row is ['name', 'label']; each data row is the image file
        name without its extension plus the class folder name as label.
        """
        with open(csv_path, 'w', newline='') as csvfile:
            svwriter = csv.writer(csvfile, dialect="excel")
            svwriter.writerow(['name', 'label'])
            # Each sub-folder of path/type is one class (e.g. 0,1,2,3,4...);
            # the folder name itself is used as the label.
            for class_name in os.listdir(os.path.join(path, type)):
                class_dir = os.path.join(path, type, class_name)
                for file_name in os.listdir(class_dir):
                    # Strip the extension: 'img001.png' -> 'img001'
                    svwriter.writerow([file_name.split('.')[0], class_name])

2.读取csv文件中对应的图片。由于csv文件中保存的是图片名称和图片的label,我们需要读取到具体的图片,此时我们需要重写dataset方法,将图片名和label放在字典中,由图片地址获得图片名,再根据图片名去字典中找对应图片的label,返回图片名以及label:

复制代码
    class mydataset(Dataset):
        """Dataset mapping 'class/filename' image paths to labels via a name->label dict."""
        def __init__(self, data_folder, class_dict, transform=None):
            self.data_folder = data_folder  # root folder of one split (train/test/val)
            self.class_dict = class_dict    # name -> label dict built from the csv
            self.transform = transform
            self.imageclass = [s for s in os.listdir(data_folder)]  # class sub-folders
            self.imagelist = []             # relative paths of the form 'class/filename.ext'
            for cls in self.imageclass:
                cls_dir = os.path.join(data_folder, cls)
                for fname in os.listdir(cls_dir):
                    self.imagelist.append(cls + '/' + fname)
            # One label per image, looked up by bare file name (no folder, no extension).
            self.labels = [class_dict[p.split('.')[0].split('/')[1]] for p in self.imagelist]

        def __len__(self):
            return len(self.imagelist)

        def __getitem__(self, idx):
            image_path = os.path.join(self.data_folder, self.imagelist[idx])
            img = Image.open(image_path)
            # Fix: transform defaults to None, so only apply it when one was given
            # (the original called self.transform(img) unconditionally and crashed).
            if self.transform is not None:
                img = self.transform(img)
            img_name = self.imagelist[idx].split('.')[0].split('/')[1]
            return img_name, img, self.class_dict[img_name]

__init__() 方法用来初始化;继承 Dataset 时必须实现 __len__() 和 __getitem__() 方法:__getitem__() 每次返回一张图片及其标签,__len__() 返回数据集的总长度,决定 __getitem__() 会被调用多少次
由于class mydataset(Dataset)每次只能返回一个图片,因此我们需要使用dataloader来一次加载多张图片。

复制代码
    def data_prepare(path, type, batch_size, transform):
        """Read path/type.csv, build the name->label dict and return a shuffled DataLoader."""
        csv_data = pd.read_csv(path + '/' + type + '.csv')
        # 'name' column -> 'label' column, e.g. {'img001': 3, ...}
        class_dict = {k: v for k, v in zip(csv_data.name, csv_data.label)}
        dataset = mydataset(data_folder=path + '/' + type,
                            class_dict=class_dict, transform=transform)
        return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

使用resnet18进行迁移学习。
为了防止训练过拟合,引入了验证集。设置一个判断是否过拟合的条件,如果满足则不再训练,保存模型参数。进入测试阶段。

复制代码
    epoch = 50
    for i in range(epoch):
        trainacc = train(model=model, Loss=loss, optimizer=optimizer, dataloader=train_loader)
        valacc = val(model=model, Loss=loss, optimizer=optimizer, dataloader=val_loader)
        # Simple early-stopping guard against overfitting: stop once training
        # accuracy is high AND the train/val gap is small, then save the weights.
        if trainacc > 0.9 and abs(trainacc - valacc) <= 0.1:
            bestacc = valacc
            torch.save(model.state_dict(), '/home/cad429/code/yxy/Week2/codewek3/parameter.pkl')
            break

在训练和验证过程中为了计算每个类别对应的准确率,使用两个list,correct以及total,correct存放每个类别分类正确的数目,total每个类别的总数目。

使用pycm包来计算多分类的混淆矩阵。
完整代码:

复制代码
    import csv
    import pandas as pd
    import torch
    import os
    #将文件夹内的图片的名字存放在一个txt文件中。
    import os
    from pycm import *
    from torchvision import transforms,models
    import torch.nn as nn
    from torch.autograd import Variable
    from torch import optim
    from torch.utils.data import Dataset,DataLoader
    from PIL import Image
    import math
    from sklearn.metrics import confusion_matrix
    from tensorboardX import SummaryWriter
    # Dataset root; each split lives in a sub-folder (test / train2 / val).
    path1=r"/home/cad429/code/yxy/Week2"
    # path2=r"/home/cad429/code/yue/Week2/train2"
    # path3=r"/home/cad429/code/yue/Week2/val"
    # Per-split CSV index files written by generate_csv().
    csv_path="/home/cad429/code/yxy/Week2/test.csv"
    csv_path2="/home/cad429/code/yxy/Week2/train2.csv"
    csv_path3="/home/cad429/code/yxy/Week2/val.csv"
    
    # Standard ImageNet channel statistics, matching the pretrained ResNet-18.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    batch_size=4
    # Resize + light augmentation (rotation/flip), then tensor + normalization.
    transform = transforms.Compose([
    transforms.Resize(size=(227, 227)),
    transforms.RandomRotation(20),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize
    ])
    # Best validation accuracy seen so far (set when the model is saved).
    bestacc=0
    def generate_csv(path, type, csv_path):
        """Scan path/type (one sub-folder per class) and write a CSV index.

        Header row is ['name', 'label']; each data row is the image file
        name without its extension plus the class folder name as label.
        """
        with open(csv_path, 'w', newline='') as csvfile:
            svwriter = csv.writer(csvfile, dialect="excel")
            svwriter.writerow(['name', 'label'])
            # Each sub-folder of path/type is one class (e.g. 0,1,2,3,4...);
            # the folder name itself is used as the label.
            for class_name in os.listdir(os.path.join(path, type)):
                class_dir = os.path.join(path, type, class_name)
                for file_name in os.listdir(class_dir):
                    # Strip the extension: 'img001.png' -> 'img001'
                    svwriter.writerow([file_name.split('.')[0], class_name])
    
    
    class mydataset(Dataset):
        """Dataset mapping 'class/filename' image paths to labels via a name->label dict."""
        def __init__(self, data_folder, class_dict, transform=None):
            self.data_folder = data_folder  # root folder of one split (train/test/val)
            self.class_dict = class_dict    # name -> label dict built from the csv
            self.transform = transform
            self.imageclass = [s for s in os.listdir(data_folder)]  # class sub-folders
            self.imagelist = []             # relative paths of the form 'class/filename.ext'
            for cls in self.imageclass:
                cls_dir = os.path.join(data_folder, cls)
                for fname in os.listdir(cls_dir):
                    self.imagelist.append(cls + '/' + fname)
            # One label per image, looked up by bare file name (no folder, no extension).
            self.labels = [class_dict[p.split('.')[0].split('/')[1]] for p in self.imagelist]

        def __len__(self):
            return len(self.imagelist)

        def __getitem__(self, idx):
            image_path = os.path.join(self.data_folder, self.imagelist[idx])
            img = Image.open(image_path)
            # Fix: transform defaults to None, so only apply it when one was given
            # (the original called self.transform(img) unconditionally and crashed).
            if self.transform is not None:
                img = self.transform(img)
            img_name = self.imagelist[idx].split('.')[0].split('/')[1]
            return img_name, img, self.class_dict[img_name]
    
    def data_prepare(path, type, batch_size, transform):
        """Read path/type.csv, build the name->label dict and return a shuffled DataLoader."""
        csv_data = pd.read_csv(path + '/' + type + '.csv')
        # 'name' column -> 'label' column, e.g. {'img001': 3, ...}
        class_dict = {k: v for k, v in zip(csv_data.name, csv_data.label)}
        dataset = mydataset(data_folder=path + '/' + type,
                            class_dict=class_dict, transform=transform)
        return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)
    
    
    
    # Build one csv index per split, then the corresponding DataLoaders.
    generate_csv(path=path1, type='test', csv_path=csv_path)
    generate_csv(path=path1, type='train2', csv_path=csv_path2)
    generate_csv(path=path1, type='val', csv_path=csv_path3)
    test_loader = data_prepare(path=path1, type='test', batch_size=batch_size, transform=transform)
    train_loader = data_prepare(path=path1, type='train2', batch_size=batch_size, transform=transform)
    val_loader = data_prepare(path=path1, type='val', batch_size=batch_size, transform=transform)
    
    # Transfer learning: freeze the pretrained ResNet-18 backbone and
    # train only the new 9-way classification head.
    model = models.resnet18(pretrained=True)
    for para in model.parameters():
        # Fix: the attribute is 'requires_grad'. The original wrote
        # 'require_grad', which silently creates a new attribute and
        # leaves the whole backbone trainable.
        para.requires_grad = False
    model.fc = nn.Linear(512, 9)  # new head; its parameters require grad again
    # Only pass trainable parameters (the new fc head) to the optimizer.
    optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)
    loss = nn.CrossEntropyLoss()
    
    # model.eval()
    # #model.cuda()
    # for i,(img_name,img,label) in enumerate(test_loader):
    #     img=Variable(img)
    #     label=Variable(label)
    #     print("i 是:",i)
    #     print(label)
    
    def train(model, Loss, optimizer, dataloader):
        """Run one training epoch on dataloader; return overall accuracy."""
        model.train()
        model.cuda()
        s = 0              # correctly classified samples
        total_samples = 0  # real sample count (last batch may be smaller than batch_size)
        correcttrain = list(0 for i in range(9))  # per-class correct counts
        totaltrain = list(0 for i in range(9))    # per-class sample counts (by true label)
        for i, (img_name, img, label) in enumerate(dataloader):
            img = Variable(img).cuda()
            label = Variable(label).cuda()
            optimizer.zero_grad()
            outputs = model(img)
            loss = Loss(outputs, label)
            loss.backward()
            optimizer.step()
            _, pred = torch.max(outputs, 1)
            pred = pred.cpu().numpy()
            label = label.cpu().numpy()
            for x in range(len(pred)):
                # Fix: count per-class totals by the TRUE label (not the
                # prediction), so correct/total is per-class accuracy
                # (recall) rather than precision.
                t = label[x]
                totaltrain[t] += 1
                total_samples += 1
                if pred[x] == t:
                    s += 1
                    correcttrain[t] += 1
        # Fix: divide by the real sample count. The original used
        # len(dataloader)*batch_size, which over-counts a partial last batch.
        acc = s / total_samples
        print("train acc:", acc)
        return acc
    
    
    def val(model, Loss, optimizer, dataloader):
        """Evaluate on dataloader; print overall and per-class accuracy, return overall."""
        model.cuda()
        model.eval()
        s1 = 0             # correctly classified samples
        total_samples = 0  # real sample count (last batch may be smaller than batch_size)
        correctval = list(0 for i in range(9))  # per-class correct counts
        totalval = list(0 for i in range(9))    # per-class sample counts (by true label)
        # No backward pass in validation, so no gradients are needed
        # (the original's optimizer.zero_grad() here was a no-op and is dropped).
        with torch.no_grad():
            for i, (ima_name, img, label) in enumerate(dataloader):
                img = Variable(img).cuda()
                label = Variable(label).cuda()
                output = model(img)
                loss = Loss(output, label)
                _, pred = torch.max(output, 1)
                pred = pred.cpu().numpy()
                label = label.cpu().numpy()
                for j in range(len(pred)):
                    # Fix: index totals by the TRUE label so the printed
                    # per-class figure is accuracy (recall), not precision.
                    t = label[j]
                    totalval[t] += 1
                    total_samples += 1
                    if pred[j] == t:
                        s1 += 1
                        correctval[t] += 1
        # Fix: divide by the real sample count instead of len(dataloader)*batch_size.
        valacc = s1 / total_samples
        print("val acc :", valacc)
        for i in range(9):
            if totalval[i] != 0:
                print(i, "类的准确率是:", correctval[i] / totalval[i])
        return valacc
    
    epoch = 50
    for i in range(epoch):
        trainacc = train(model=model, Loss=loss, optimizer=optimizer, dataloader=train_loader)
        valacc = val(model=model, Loss=loss, optimizer=optimizer, dataloader=val_loader)
        # Simple early-stopping guard against overfitting: stop once training
        # accuracy is high AND the train/val gap is small, then save the weights.
        if trainacc > 0.9 and abs(trainacc - valacc) <= 0.1:
            bestacc = valacc
            torch.save(model.state_dict(), '/home/cad429/code/yxy/Week2/codewek3/parameter.pkl')
            break
    
    # ---- test phase: reload the saved weights and evaluate on the test set ----
    
    model.load_state_dict(torch.load('/home/cad429/code/yxy/Week2/codewek3/parameter.pkl'))
    model.cuda()
    model.eval()
    ss = 0             # correctly classified test samples
    test_samples = 0   # real sample count (last batch may be smaller than batch_size)
    predtest = []      # all predictions, for the confusion matrix
    labeltest = []     # all true labels, for the confusion matrix
    for x, (img_name, img, label) in enumerate(test_loader):
        img = Variable(img).cuda()
        label = Variable(label).cuda()
        output = model(img)
        testloss = loss(output, label)
        # (the original's optimizer.zero_grad() here did nothing: no backward pass)
        _, predd = torch.max(output, 1)
        predd = predd.cpu().numpy()
        label = label.cpu().numpy()
        for i in range(len(predd)):
            if predd[i] == label[i]:
                ss += 1
            test_samples += 1
            predtest.append(predd[i])
            labeltest.append(label[i])
    # Fix: divide by the real sample count instead of len(test_loader)*batch_size,
    # which over-counts when the last batch is partial.
    testacc = ss / test_samples
    
    print("test acc is :", testacc)
    # Multi-class confusion matrix (and derived metrics) via pycm.
    cm = ConfusionMatrix(actual_vector=labeltest, predict_vector=predtest)
    
    print(cm)

测试的结果:
1.总的准确率:
在这里插入图片描述
测试集的混淆矩阵:
Predict 0 1 2 3 4 5 6 7 8
Actual
0 127 0 0 0 0 0 2 1 0

1 0 0 2 0 0 1 0 0 0

2 0 0 36 0 0 3 0 0 0

3 0 0 3 0 1 1 0 0 0

4 1 0 0 0 24 0 0 10 0

5 0 0 0 0 1 90 0 1 0

6 3 0 0 0 3 0 98 1 0

7 1 0 0 0 5 2 0 54 1

8 3 0 0 0 0 0 0 0 90

AUC(Area under the ROC curve) 0类 0.97927
1类 0.5
2类 0.95679
3类 0.5
4类 0.83342
5类 0.98173
6类 0.96449
7类 0.91562
8类 0.98281
F1score:
F1(F1 score - harmonic mean of precision and sensitivity)
0类 0.95849
1类 0.0
2类 0.9
3类 0.0
4类 0.69565
5类 0.95238
6类 0.9561
7类 0.83077
8类 0.97826

全部评论 (0)

还没有任何评论哟~