Advertisement

深度学习第三站——图片分类任务(扩展半监督学习)

阅读量:

1、半监督具体实行过程

在半监督学习过程中,首先通过有标签的数据集对模型进行训练,并使其在测试集上的分类准确率达到0.6。随后将该模型应用于无标签数据集,并根据预测结果筛选出置信度较高的样本。对于这些分类结果,在预测概率超过0.99的情况下(即置信度高于99%),将这些新增的样本重新加入到原来的有标签数据集中用于进一步训练。

个人难以实现分类准确率达到0.6这一目标,并在代码实现中使用了ResNet模型作为特征提取器完成该过程。

2、增添的代码块

复制代码
 class semiDataset(Dataset):  # 评估得到的标签可靠性

    
     def __init__(self, no_label_loader, model, device, thres):  # 传入无标签参数,模型,gpu设备,置信度
    
     X, Y = self.data_pred(no_label_loader, model, device, thres)
    
     if X == []:
    
         self.flag = False
    
     else:
    
         self.flag = True
    
         self.X = np.array(X)  # 转为矩阵
    
         self.Y = torch.LongTensor(Y)
    
         self.transform = train_transform
    
  
    
     def data_pred(self, no_label_loader, model, device, thres):  # 打标签函数
    
     model = model.to(device)
    
     soft = nn.Softmax(dim=1)
    
     pred_prob = []  # 记录预测值
    
     labels = []  # 记录Y,标签
    
     x = []
    
     y = []
    
     with torch.no_grad():  # 非训练过程,不用计算梯度
    
         for data in no_label_loader:
    
             data = data[0].to(device)  # 取增广后的数据
    
             pred = model(data)
    
             pred_soft = soft(pred)  # 预测值->概率值
    
             pred_max, pred_value = pred_soft.max(1)  # 返回最大值以及对应下标
    
             pred_prob.extend(pred_max.cpu().numpy().tolist())  # 把全部概率存入数组  #append放一个数,extend放一组
    
             labels.extend(pred_value.cpu().numpy().tolist())  # 把分类结果存入数组
    
         for index, prob in enumerate(pred_prob):  # 枚举、列举的意思,用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据的下标和数据
    
             if prob > thres:
    
                 x.append(no_label_loader.dataset[index][1])  # 把对应的图片数据放入X中(未增广的原数据)
    
                 y.append(labels[index])  # x对应的标签
    
         return x, y
    
  
    
     def __getitem__(self, item):
    
     return self.transform(self.X[item]), self.Y[item]
    
  
    
     def __len__(self):
    
     return len(self.Y)
    
    
    
    
    python
    
    
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-08-18/O3emA0IgZSPxtT8Wq2rLodi9UuGf.png)

本semiDataset类的核心功能是接收无监督学习的数据样本,并对其进行分类处理。通过预先设定的置信度阈值来判断每个样本是否需要标注。经过筛选后生成两个结果列表:一个是待标注的数据集X;另一个是与之对应的标签集合Y。

复制代码
 def get_semi_loader(no_label_loader, model, device, thres):

    
     semi_set = semiDataset(no_label_loader, model, device, thres)  # 经过半监督学习得到的训练集
    
     if semi_set.flag == False:
    
     return None
    
     semi_loader = DataLoader(semi_set, batch_size=16, shuffle=False)  # 一次取一批
    
     return semi_loader
    
    
    
    
    python

该get_semi_loader函数在半监督学习数据集不为空时会生成相应的semi_loader供器对象(每次迭代输出16组样本,并按固定顺序处理样本)

全部评论 (0)

还没有任何评论哟~