Advertisement

深度学习-半监督学习的图片分类

阅读量:

相较于有监督的进行修改即可

food_Dataset

复制代码
    class food_Dataset(Dataset):
        """Food image dataset for labeled (train/valid) and unlabeled (semi) data.

        mode='train' / 'valid': reads class sub-folders "00".."10" under `path`
        and returns (transformed image, label) pairs.
        mode='semi': reads a flat folder of unlabeled images and returns
        (transformed image, raw image) so the raw array can be pseudo-labeled later.
        """

        def __init__(self, path, mode='train'):
            self.mode = mode
            if self.mode == 'semi':
                # Unlabeled data: only X, no labels.
                self.X = self.read_file(path)
            else:
                self.X, self.Y = self.read_file(path)
                self.Y = torch.LongTensor(self.Y)
            # BUG FIX: 'semi' previously fell through to train_forms (random
            # augmentation), but pseudo-label inference should be deterministic,
            # matching the validation transform as the write-up describes.
            if mode == 'train':
                self.transform = train_forms
            else:
                self.transform = valid_forms

        def read_file(self, path):
            """Load images under `path` into uint8 arrays of shape (N, HW, HW, 3).

            Returns X for 'semi' mode, (X, y) otherwise.
            """
            if self.mode == 'semi':
                imgs_name = os.listdir(path)
                X = np.zeros((len(imgs_name), HW, HW, 3), dtype=np.uint8)
                for j, img_name in enumerate(imgs_name):
                    img = Image.open(os.path.join(path, img_name))
                    img = img.resize((HW, HW))
                    X[j, ...] = img
                print(f"读到了{len(X)}个训练数据")
                return X

            # Labeled data: one sub-folder per class, named "00" .. "10".
            X_parts, y_parts = [], []
            for i in tqdm(range(11)):
                # BUG FIX: was os.path.join(path, f"{path}\ {i:02d}"), which
                # duplicated the parent directory in the joined path; join the
                # class-folder name only.
                file_path = os.path.join(path, f"{i:02d}")
                imgs_name = os.listdir(file_path)
                Xi = np.zeros((len(imgs_name), HW, HW, 3), dtype=np.uint8)
                # Every image in this folder belongs to class i.
                yi = np.full((len(imgs_name),), i, dtype=np.uint8)
                for j, img_name in enumerate(imgs_name):
                    img = Image.open(os.path.join(file_path, img_name))
                    img = img.resize((HW, HW))
                    Xi[j, ...] = img
                X_parts.append(Xi)
                y_parts.append(yi)
            X = np.concatenate(X_parts, axis=0)
            y = np.concatenate(y_parts, axis=0)
            print(f"读到了{len(y)}个训练数据")
            return X, y

        def __getitem__(self, item):
            if self.mode == 'semi':
                # Also return the raw image so it can be re-added to the
                # labeled pool once it receives a pseudo label.
                return self.transform(self.X[item]), self.X[item]
            return self.transform(self.X[item]), self.Y[item]

        def __len__(self):
            # semi mode has no Y; len(X) is valid for every mode.
            return len(self.X)
    
    
    python
    
    
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-07-13/lAhiUWNBIT9K5MQa13p8nvu20REO.png)
  1. 在 __init__ 中添加了
    if self.mode == 'semi':
    self.X = self.read_file(path)
    来判断是否以无监督方式加载数据:无监督(semi)模式之中只加载 X,并没有 y。

  2. 将read_file函数添加到类当中,以方便调用。其中也需要判断mode的类型。

  3. __getitem__中,当mode为semi时,只需要返回经过变形的X和没有变形的X,其中因为在init中已经设置过了,因此transform的操作和valid是一样的。之所以要返回没有变形的X是因为后续需要将原始的x添加到原始数据集当中,统一操作。

  4. __len__:因为 semi 模式之中没有 y,为了统一操作,长度统一用 X 来计算。

semiDataset

复制代码
    class semiDataset(Dataset):
        """Pseudo-labeled dataset built from unlabeled data.

        Runs `model` over `semi_dataloader`, keeping every sample whose maximum
        softmax probability exceeds `thresh`; the argmax class becomes its label.
        `self.flag` tells callers whether any sample passed the threshold.
        """

        def __init__(self, semi_dataloader, model, device, thresh):
            super().__init__()
            self.x, self.y = self.get_label(semi_dataloader, model, device, thresh)
            if not self.x:
                # Nothing was confident enough; X/Y/transform stay unset.
                self.flag = False
            else:
                self.flag = True
                self.X = np.array(self.x)
                self.Y = torch.LongTensor(self.y)
                self.transform = train_forms

        def get_label(self, semi_dataloader, model, device, thresh):
            """Return (raw images, pseudo labels) for confident predictions."""
            model = model.to(device)
            # BUG FIX: switch to eval mode so BatchNorm/Dropout behave
            # deterministically while generating pseudo labels.
            model.eval()
            soft = nn.Softmax(dim=1)
            pred_prob = []   # max softmax probability per sample
            labels = []      # predicted (argmax) class per sample
            x, y = [], []
            with torch.no_grad():
                for batch_x, _ in semi_dataloader:
                    batch_x = batch_x.to(device)
                    pred_soft = soft(model(batch_x))
                    # max(1) returns (values, indices): confidence and class.
                    prob_max, cls_idx = pred_soft.max(1)
                    pred_prob.extend(prob_max.cpu().numpy().tolist())
                    labels.extend(cls_idx.cpu().numpy().tolist())
                for index, prob in enumerate(pred_prob):
                    if prob > thresh:
                        # dataset[index][1] is the raw (untransformed) image;
                        # valid because the semi loader is not shuffled.
                        x.append(semi_dataloader.dataset[index][1])
                        y.append(labels[index])
            # Restore training mode for the caller's training loop.
            model.train()
            return x, y

        def __getitem__(self, item):
            return self.transform(self.X[item]), self.Y[item]

        def __len__(self):
            return len(self.X)
    
    
    python
    
    
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-07-13/c05OZr1Ha8htuMl4FGiCdRYyozsw.png)

这个类的作用:
其主要功能是依据半监督学习中的阈值策略,从给定的半监督数据加载器(semi_dataloader)里筛选出预测概率超过指定阈值(thresh)的样本,进而构建一个新的数据集。

  1. get_label:将没有打标签的数据,通过model模型预测出一个y,当prob预测值大于阈值时,将x、y加入到列表当中。

  2. pred_soft = soft(pred) 将模型预测的pred,通过软分类获得每个类别的可能性
    pred_max, pred_value = pred_soft.max(1) 选取软分类当中最大的可能性,和类别
    pred_prob.extend(pred_max.cpu().numpy().tolist()) 将最大的可能性加入到列表当中
    labels.extend(pred_value.cpu().numpy().tolist()) 将最有可能的类别加入到列表当中。

  3. init :将get_label函数中得到数据放进来,设置一个flag表示当前是否有数据。

get_semiloader

复制代码
    def get_semiloader(no_label_dataloader, model, device, thresh):
        """Build a DataLoader of pseudo-labeled samples, or None if none qualify.

        Filters the unlabeled data through `model`, keeping predictions whose
        confidence exceeds `thresh` (see semiDataset).
        """
        semi_set = semiDataset(no_label_dataloader, model, device, thresh)
        if not semi_set.flag:
            return None
        # BUG FIX: the original returned semi_set (the Dataset) instead of the
        # DataLoader it had just built, so the training loop iterated single
        # samples rather than batches.
        return DataLoader(semi_set, batch_size=16, shuffle=False)
        
    
    
    python
    
    

get_semiloader 函数的主要目的是根据传入的无标签数据加载器(no_label_dataloader)、模型(model)、设备(device)和阈值(thresh),创建一个半监督数据集(semiDataset 类的实例),并根据该数据集创建一个新的数据加载器。如果筛选后的数据集中没有符合条件的样本,则返回 None。

train_val

复制代码
    def train_val(model, train_loader, valid_loader, no_label_dataloader, device,
                  epochs, optimizer, loss, thresh, save_path):
        """Train `model` with supervised data plus pseudo-labeled data, then plot.

        Per epoch: (1) train on `train_loader`; (2) if a pseudo-labeled loader was
        built in an earlier epoch, train on it too; (3) validate; (4) once
        validation accuracy exceeds 0.6, (re)build the pseudo-labeled loader from
        `no_label_dataloader` via get_semiloader; (5) checkpoint the best model by
        validation accuracy. Finally plots loss/accuracy curves.

        NOTE(review): relies on a module-level `writer` (SummaryWriter) existing —
        confirm before calling.
        """
        model = model.to(device)
        semi_loader = None
        plt_train_loss, plt_val_loss = [], []
        plt_train_accu, plt_val_accu = [], []
        max_accu = 0.0
        for epoch in range(epochs):
            startTime = time.time()
            train_loss = semi_loss = valid_loss = 0.0
            train_acc = semi_acc = valid_acc = 0.0
            print(f"--------正在进行第{epoch}次训练----------")
            model.train()
            for batch_x, batch_y in train_loader:
                x = batch_x.to(device)
                target = batch_y.to(device)
                pred = model(x)
                train_bat_loss = loss(pred, target)
                # Conventional order: clear grads, backprop, step.
                optimizer.zero_grad()
                train_bat_loss.backward()
                optimizer.step()
                train_loss += train_bat_loss.cpu().item()
                train_acc += (pred.argmax(1) == target).sum().item()

            # Extra pass over pseudo-labeled data produced in an earlier epoch.
            if semi_loader is not None:
                for batch_x, batch_y in semi_loader:
                    x = batch_x.to(device)
                    target = batch_y.to(device)
                    pred = model(x)
                    semi_bat_loss = loss(pred, target)
                    optimizer.zero_grad()
                    semi_bat_loss.backward()
                    optimizer.step()
                    semi_loss += semi_bat_loss.cpu().item()
                    semi_acc += (pred.argmax(1) == target).sum().item()
                print(f'半监督学习准确率:{semi_acc/len(semi_loader.dataset)}')

            plt_train_loss.append(train_loss / len(train_loader))
            plt_train_accu.append(train_acc / len(train_loader.dataset))
            writer.add_scalar('loss', train_loss / len(train_loader), epoch)

            model.eval()
            with torch.no_grad():
                for batch_x, batch_y in valid_loader:
                    x = batch_x.to(device)
                    target = batch_y.to(device)
                    pred = model(x)
                    valid_bat_loss = loss(pred, target)
                    valid_loss += valid_bat_loss.cpu().item()
                    valid_acc += (pred.argmax(1) == target).sum().item()
            plt_val_loss.append(valid_loss / len(valid_loader))
            valid_acc = valid_acc / len(valid_loader.dataset)
            plt_val_accu.append(valid_acc)

            # CONSISTENCY FIX: this snippet used 0.01 while the full version and
            # the write-up use 0.6 — only start pseudo-labeling once the model
            # is reasonably accurate.
            if plt_val_accu[-1] > 0.6:
                semi_loader = get_semiloader(no_label_dataloader, model, device, thresh)

            # Keep the checkpoint with the best validation accuracy.
            if valid_acc > max_accu:
                max_accu = valid_acc
                torch.save(model, save_path)
            print(f"[{epoch}/{epochs}]:{time.time()-startTime} sec(s) trainloss:{plt_train_loss[-1]} validloss:{plt_val_loss[-1]}")
            print(f"[{epoch}/{epochs}]:{time.time()-startTime} sec(s) trainAccu:{plt_train_accu[-1]} validaccu:{plt_val_accu[-1]}")

        plt.plot(plt_train_loss)
        plt.plot(plt_val_loss)
        plt.legend(['train', 'valid'])
        plt.title("loss")
        plt.show()

        plt.plot(plt_train_accu)
        plt.plot(plt_val_accu)
        plt.legend(['train', 'valid'])
        plt.title("accu")
        plt.show()
    
    
    python
    
    
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-07-13/vI28NeEixkGHKDYdRnqPsQ3zTUWy.png)
复制代码
    if semi_loader != None:

     for batch_x, batch_y in semi_loader:
         x = batch_x.to(device)
         target = batch_y.to(device)
         pred = model(x)
         train_bat_loss = loss(pred,target)
         train_bat_loss.backward()
         optimizer.step()
         optimizer.zero_grad()
         semi_loss += train_bat_loss.cpu().item()
         semi_epoch_accu = (pred.argmax(1) == target).sum().item()
         semi_acc += semi_epoch_accu
     print(f'半监督学习准确率:{semi_acc/len(semi_loader.dataset)}')
     
      先判断semi_loader是否为空,如果有数据的话,就利用刚才得到的数据进行训练。
    
    
    
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-07-13/5XB7efQUZgOldArnRju6EYHyChwt.png)
复制代码
    if plt_val_accu[-1] > 0.6:

     semi_loader = get_semiloader(no_label_dataloader,model,device,thresh)
     判断当前模型的预测能力,如果当前模型预测性能大于0.6的话,就调用get_semiloader,获得新数据。
    
    

完整代码

复制代码
    import torch
    import torch.nn as nn
    from torch.utils.data import Dataset,DataLoader
    import os
    import numpy as np
    from PIL import Image
    from tqdm import tqdm
    from torchvision import transforms
    import time
    import matplotlib.pyplot as plt
    from torch.utils.tensorboard import SummaryWriter
    import torchvision.models as models
    # Training-time augmentation: random crop + horizontal flip.
    # NOTE(review): RandomResizedCrop(244) while HW = 224 below — 244 looks like
    # a typo for 224 (train and valid tensors end up with different spatial
    # sizes); confirm the intended input resolution before changing.
    train_forms = transforms.Compose(
    [transforms.ToPILImage(),
     transforms.RandomResizedCrop(244),
     transforms.RandomHorizontalFlip(),
     transforms.ToTensor(),
     ]
    )
    
    # Deterministic transform for validation (no augmentation).
    valid_forms = transforms.Compose(
    [transforms.ToPILImage(),
     transforms.ToTensor(),
     ]
    )
    
    
    # Side length that images are resized to when read from disk.
    HW=224
    
    
    
    class food_Dataset(Dataset):
        """Food image dataset for labeled (train/valid) and unlabeled (semi) data.

        mode='train' / 'valid': reads class sub-folders "00".."10" under `path`
        and returns (transformed image, label) pairs.
        mode='semi': reads a flat folder of unlabeled images and returns
        (transformed image, raw image) so the raw array can be pseudo-labeled later.
        """

        def __init__(self, path, mode='train'):
            self.mode = mode
            if self.mode == 'semi':
                # Unlabeled data: only X, no labels.
                self.X = self.read_file(path)
            else:
                self.X, self.Y = self.read_file(path)
                self.Y = torch.LongTensor(self.Y)
            # BUG FIX: 'semi' previously fell through to train_forms (random
            # augmentation), but pseudo-label inference should be deterministic,
            # matching the validation transform as the write-up describes.
            if mode == 'train':
                self.transform = train_forms
            else:
                self.transform = valid_forms

        def read_file(self, path):
            """Load images under `path` into uint8 arrays of shape (N, HW, HW, 3).

            Returns X for 'semi' mode, (X, y) otherwise.
            """
            if self.mode == 'semi':
                imgs_name = os.listdir(path)
                X = np.zeros((len(imgs_name), HW, HW, 3), dtype=np.uint8)
                for j, img_name in enumerate(imgs_name):
                    img = Image.open(os.path.join(path, img_name))
                    img = img.resize((HW, HW))
                    X[j, ...] = img
                print(f"读到了{len(X)}个训练数据")
                return X

            # Labeled data: one sub-folder per class, named "00" .. "10".
            X_parts, y_parts = [], []
            for i in tqdm(range(11)):
                # BUG FIX: was os.path.join(path, f"{path}\ {i:02d}"), which
                # duplicated the parent directory in the joined path; join the
                # class-folder name only.
                file_path = os.path.join(path, f"{i:02d}")
                imgs_name = os.listdir(file_path)
                Xi = np.zeros((len(imgs_name), HW, HW, 3), dtype=np.uint8)
                # Every image in this folder belongs to class i.
                yi = np.full((len(imgs_name),), i, dtype=np.uint8)
                for j, img_name in enumerate(imgs_name):
                    img = Image.open(os.path.join(file_path, img_name))
                    img = img.resize((HW, HW))
                    Xi[j, ...] = img
                X_parts.append(Xi)
                y_parts.append(yi)
            X = np.concatenate(X_parts, axis=0)
            y = np.concatenate(y_parts, axis=0)
            print(f"读到了{len(y)}个训练数据")
            return X, y

        def __getitem__(self, item):
            if self.mode == 'semi':
                # Also return the raw image so it can be re-added to the
                # labeled pool once it receives a pseudo label.
                return self.transform(self.X[item]), self.X[item]
            return self.transform(self.X[item]), self.Y[item]

        def __len__(self):
            # semi mode has no Y; len(X) is valid for every mode.
            return len(self.X)
    
    class myModule(nn.Module):
        """VGG-style CNN: 3x224x224 -> 512x7x7 -> flatten -> `numclass` logits."""

        def __init__(self, numclass):
            super().__init__()
            # Stem: 3x224x224 -> 64x112x112
            self.cn1 = nn.Conv2d(3, 64, 3, 1, 1)
            self.bn1 = nn.BatchNorm2d(64)
            self.relu = nn.ReLU()
            self.p1 = nn.MaxPool2d(2)
            # Each block halves the spatial size and doubles the channels.
            self.layer1 = self._block(64, 128)    # -> 128x56x56
            self.layer2 = self._block(128, 256)   # -> 256x28x28
            self.layer3 = self._block(256, 512)   # -> 512x14x14
            # Head: pool to 512x7x7, flatten to 25088, classify.
            self.layer4 = nn.Sequential(
                nn.MaxPool2d(2),
                nn.Flatten(),
                nn.Linear(25088, 1000),
                nn.Linear(1000, numclass),
            )

        @staticmethod
        def _block(cin, cout):
            """Conv -> BatchNorm -> ReLU -> 2x2 max-pool."""
            return nn.Sequential(
                nn.Conv2d(cin, cout, 3, 1, 1),
                nn.BatchNorm2d(cout),
                nn.ReLU(),
                nn.MaxPool2d(2),
            )

        def forward(self, x):
            # Run the stages in order; identical to chaining them by hand.
            for stage in (self.cn1, self.bn1, self.relu, self.p1,
                          self.layer1, self.layer2, self.layer3, self.layer4):
                x = stage(x)
            return x
    
    
    class semiDataset(Dataset):
        """Pseudo-labeled dataset built from unlabeled data.

        Runs `model` over `semi_dataloader`, keeping every sample whose maximum
        softmax probability exceeds `thresh`; the argmax class becomes its label.
        `self.flag` tells callers whether any sample passed the threshold.
        """

        def __init__(self, semi_dataloader, model, device, thresh):
            super().__init__()
            self.x, self.y = self.get_label(semi_dataloader, model, device, thresh)
            if not self.x:
                # Nothing was confident enough; X/Y/transform stay unset.
                self.flag = False
            else:
                self.flag = True
                self.X = np.array(self.x)
                self.Y = torch.LongTensor(self.y)
                self.transform = train_forms

        def get_label(self, semi_dataloader, model, device, thresh):
            """Return (raw images, pseudo labels) for confident predictions."""
            model = model.to(device)
            # BUG FIX: switch to eval mode so BatchNorm/Dropout behave
            # deterministically while generating pseudo labels.
            model.eval()
            soft = nn.Softmax(dim=1)
            pred_prob = []   # max softmax probability per sample
            labels = []      # predicted (argmax) class per sample
            x, y = [], []
            with torch.no_grad():
                for batch_x, _ in semi_dataloader:
                    batch_x = batch_x.to(device)
                    pred_soft = soft(model(batch_x))
                    # max(1) returns (values, indices): confidence and class.
                    prob_max, cls_idx = pred_soft.max(1)
                    pred_prob.extend(prob_max.cpu().numpy().tolist())
                    labels.extend(cls_idx.cpu().numpy().tolist())
                for index, prob in enumerate(pred_prob):
                    if prob > thresh:
                        # dataset[index][1] is the raw (untransformed) image;
                        # valid because the semi loader is not shuffled.
                        x.append(semi_dataloader.dataset[index][1])
                        y.append(labels[index])
            # Restore training mode for the caller's training loop.
            model.train()
            return x, y

        def __getitem__(self, item):
            return self.transform(self.X[item]), self.Y[item]

        def __len__(self):
            return len(self.X)
    
    
    def get_semiloader(no_label_dataloader, model, device, thresh):
        """Build a DataLoader of pseudo-labeled samples, or None if none qualify.

        Filters the unlabeled data through `model`, keeping predictions whose
        confidence exceeds `thresh` (see semiDataset).
        """
        semi_set = semiDataset(no_label_dataloader, model, device, thresh)
        if not semi_set.flag:
            return None
        # BUG FIX: the original returned semi_set (the Dataset) instead of the
        # DataLoader it had just built, so the training loop iterated single
        # samples rather than batches.
        return DataLoader(semi_set, batch_size=16, shuffle=False)
    def train_val(model, train_loader, valid_loader, no_label_dataloader, device,
                  epochs, optimizer, loss, thresh, save_path):
        """Train `model` with supervised data plus pseudo-labeled data, then plot.

        Per epoch: (1) train on `train_loader`; (2) if a pseudo-labeled loader was
        built in an earlier epoch, train on it too; (3) validate; (4) once
        validation accuracy exceeds 0.6, (re)build the pseudo-labeled loader from
        `no_label_dataloader` via get_semiloader; (5) checkpoint the best model by
        validation accuracy. Finally plots loss/accuracy curves.

        NOTE(review): relies on a module-level `writer` (SummaryWriter) existing —
        confirm before calling.
        """
        model = model.to(device)
        semi_loader = None
        plt_train_loss, plt_val_loss = [], []
        plt_train_accu, plt_val_accu = [], []
        max_accu = 0.0
        for epoch in range(epochs):
            startTime = time.time()
            train_loss = semi_loss = valid_loss = 0.0
            train_acc = semi_acc = valid_acc = 0.0
            print(f"--------正在进行第{epoch}次训练----------")
            model.train()
            for batch_x, batch_y in train_loader:
                x = batch_x.to(device)
                target = batch_y.to(device)
                pred = model(x)
                train_bat_loss = loss(pred, target)
                # Conventional order: clear grads, backprop, step.
                optimizer.zero_grad()
                train_bat_loss.backward()
                optimizer.step()
                train_loss += train_bat_loss.cpu().item()
                train_acc += (pred.argmax(1) == target).sum().item()

            # Extra pass over pseudo-labeled data produced in an earlier epoch.
            if semi_loader is not None:
                for batch_x, batch_y in semi_loader:
                    x = batch_x.to(device)
                    target = batch_y.to(device)
                    pred = model(x)
                    semi_bat_loss = loss(pred, target)
                    optimizer.zero_grad()
                    semi_bat_loss.backward()
                    optimizer.step()
                    semi_loss += semi_bat_loss.cpu().item()
                    semi_acc += (pred.argmax(1) == target).sum().item()
                print(f'半监督学习准确率:{semi_acc/len(semi_loader.dataset)}')

            plt_train_loss.append(train_loss / len(train_loader))
            plt_train_accu.append(train_acc / len(train_loader.dataset))
            writer.add_scalar('loss', train_loss / len(train_loader), epoch)

            model.eval()
            with torch.no_grad():
                for batch_x, batch_y in valid_loader:
                    x = batch_x.to(device)
                    target = batch_y.to(device)
                    pred = model(x)
                    valid_bat_loss = loss(pred, target)
                    valid_loss += valid_bat_loss.cpu().item()
                    valid_acc += (pred.argmax(1) == target).sum().item()
            plt_val_loss.append(valid_loss / len(valid_loader))
            valid_acc = valid_acc / len(valid_loader.dataset)
            plt_val_accu.append(valid_acc)

            # Only start pseudo-labeling once the model is reasonably accurate,
            # otherwise the pseudo labels would be mostly noise.
            if plt_val_accu[-1] > 0.6:
                semi_loader = get_semiloader(no_label_dataloader, model, device, thresh)

            # Keep the checkpoint with the best validation accuracy.
            if valid_acc > max_accu:
                max_accu = valid_acc
                torch.save(model, save_path)
            print(f"[{epoch}/{epochs}]:{time.time()-startTime} sec(s) trainloss:{plt_train_loss[-1]} validloss:{plt_val_loss[-1]}")
            print(f"[{epoch}/{epochs}]:{time.time()-startTime} sec(s) trainAccu:{plt_train_accu[-1]} validaccu:{plt_val_accu[-1]}")

        plt.plot(plt_train_loss)
        plt.plot(plt_val_loss)
        plt.legend(['train', 'valid'])
        plt.title("loss")
        plt.show()

        plt.plot(plt_train_accu)
        plt.plot(plt_val_accu)
        plt.legend(['train', 'valid'])
        plt.title("accu")
        plt.show()
    
    
    # ----- Paths -----
    # NOTE(review): the "\ " sequences in these literals look like a scrape
    # artifact of the original path separators (possibly "\t" rendered by the
    # blog engine); confirm the real folder layout before running.
    now_path = os.getcwd()
    train_path_sample = 'food-11_sample\ training\ labeled'
    valid_path_sample = 'food-11_sample\ validation'
    no_label_path_sample = 'food-11_sample\ training\ unlabeled\ 00'
    
    # Full-dataset paths (unused below — the *_sample paths are joined instead).
    # NOTE(review): unlabeled_path points at the validation folder; presumably a
    # copy-paste slip — verify before switching to the full dataset.
    train_path = 'food-11\ training\ labeled'
    valid_path = 'food-11\ validation'
    unlabeled_path = 'food-11\ validation'
    
    train_path = os.path.join(now_path,train_path_sample)
    valid_path = os.path.join(now_path,valid_path_sample)
    no_label_path = os.path.join(now_path , no_label_path_sample)
    
    # Datasets: labeled train/valid plus the unlabeled pool for pseudo-labeling.
    train_dataset = food_Dataset(train_path,mode = "train")
    valid_dataset = food_Dataset(valid_path, mode= 'valid')
    no_label_dataset = food_Dataset(no_label_path, mode='semi')
    
    # The semi loader must NOT shuffle: semiDataset.get_label indexes
    # dataloader.dataset by running position.
    train_dataloader = DataLoader(train_dataset,batch_size=4,shuffle=True)
    valid_dataloader = DataLoader(valid_dataset,batch_size=4,shuffle=True)
    no_label_dataloader = DataLoader(no_label_dataset, batch_size=4, shuffle=False)
    
    # TensorBoard writer used inside train_val.
    writer = SummaryWriter("logs")
    
    # Transfer learning: ResNet-18 with its classifier head replaced for 11 classes.
    numclass = 11
    model =models.resnet18(pretrained=True)
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, numclass)
    
    print(model)
    # model = myModule(11)
    lr = 0.001
    loss_fun = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(),lr= lr ,weight_decay=1e-4)
    device = 'cuda' if torch.cuda.is_available() else "cpu"
    save_path = "best_module.pth"
    epochs = 3
    thresh = 0.9  # confidence threshold for accepting a pseudo label
    
    
    train_val(model,train_dataloader,valid_dataloader,no_label_dataloader,device,epochs,optimizer,loss_fun,thresh,save_path)
    
    
    
    #
    # for batch_x , batch_y in train_dataloader:
    #     pred = module(batch_x)
    #     print(pred)
    
    python
    
    
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-07-13/kldNvAj2OKiC86t3RhcqTSrZ5GpJ.png)

全部评论 (0)

还没有任何评论哟~