Advertisement

半监督学习

阅读量:

1. 半监督学习的基本概念

1.1 什么是半监督学习?

半监督学习(Semi-Supervised Learning, SSL)是一种机器学习方法,结合了少量标注数据 (labeled data)和大量未标注数据 (unlabeled data)来训练模型。它介于监督学习(所有数据都有标签)和无监督学习(所有数据无标签)之间,旨在通过未标注数据的潜在结构信息提升模型的泛化能力。

  • 监督学习 :数据集 D={(xi,yi)}i=1nD = {(x_i, y_i)}_{i=1}^n,每个样本 xix_i 都有标签 yiy_i。
  • 无监督学习 :数据集 D={xi}i=1nD = {x_i}_{i=1}^n,没有标签,通常用于聚类或降维。
  • 半监督学习 :数据集 D=Dl∪DuD = D_l \cup D_u,其中:
    • 标注数据集 Dl={(xi,yi)}i=1lD_l = {(x_i, y_i)}_{i=1}^l,数量少(ll 小)。
    • 未标注数据集 Du={xj}j=l+1l+uD_u = {x_j}_{j=l+1}^{l+u},数量多(u≫lu \gg l)。

目标 :学习一个模型 f:X→Yf: X \to Y,利用 DlD_l 和 DuD_u 最小化测试误差。

1.2 为什么需要半监督学习?

  • 标注成本高 :标注数据需要大量人力、时间和专业知识。例如,医学影像标注需要医生,文本情感分析需要语言专家。
  • 未标注数据丰富 :未标注数据易于获取,如网络图片、用户日志、传感器数据。
  • 性能提升 :半监督学习能在标注数据稀缺时显著提高模型性能,接近甚至超越全监督学习。

1.3 半监督学习的基本假设

半监督学习的有效性依赖于以下假设:

  1. 平滑性假设 :如果 x1,x2x_1, x_2 在特征空间中接近(∣∣x1−x2∣∣||x_1 - x_2|| 小),则标签 y1≈y2y_1 \approx y_2。
  2. 簇假设 :数据点形成簇,同一簇内的点通常属于同一类别。
  3. 流形假设 :高维数据位于低维流形上,未标注数据帮助学习流形结构。
  4. 低密度分离假设 :决策边界应通过低密度区域,避免穿过高密度区域。

这些假设确保未标注数据能提供有用的信息。


2. 半监督学习的数学框架

半监督学习的数学框架通过联合优化监督损失和无监督损失,利用标注数据和未标注数据的特性来训练模型。以下是详细的数学形式化描述和推导。

2.1 数据定义

  • 标注数据 :Dl={(xi,yi)}i=1lD_l = {(x_i, y_i)}_{i=1}^l,其中 xi∈Rdx_i \in \mathbb{R}^d 是特征向量,yi∈Yy_i \in Y 是标签(分类任务中 Y={1,2,…,C}Y = {1, 2, \dots, C},回归任务中 Y⊆RY \subseteq \mathbb{R})。
  • 未标注数据 :Du={xj}j=l+1l+uD_u = {x_j}_{j=l+1}^{l+u},其中 xj∈Rdx_j \in \mathbb{R}^d。
  • 总数据集:D=Dl∪DuD = D_l \cup D_u,通常 u≫lu \gg l。
  • 测试数据:Dtest={(xk,yk)}k=1mD_{\text{test}} = {(x_k, y_k)}_{k=1}^m,用于评估模型。

2.2 目标

学习一个映射 f:X→Yf: X \to Y,参数化为 fθf_\theta,通过优化以下目标函数最小化泛化误差:

min⁡θE(x,y)∼p(x,y)[ℓ(fθ(x),y)]\min_\theta \mathbb{E}{(x, y) \sim p(x, y)} [ \ell(f\theta(x), y) ]

其中,ℓ\ell 是损失函数(如交叉熵或均方误差),p(x,y)p(x, y) 是真实数据分布。由于 p(x,y)p(x, y) 未知,半监督学习通过 DlD_l 和 DuD_u 近似优化。

2.3 损失函数

半监督学习的损失函数通常由两部分组成:

L(θ)=Ll(θ)+λLu(θ)L(\theta) = L_l(\theta) + \lambda L_u(\theta)

  • 监督损失 LlL_l:基于标注数据的损失,衡量模型在 DlD_l 上的预测误差。
  • 无监督损失 LuL_u:基于未标注数据的损失,捕捉数据分布或结构信息。
  • 超参数 λ\lambda:平衡监督和无监督损失的权重。

2.3.1 监督损失

对于分类任务,监督损失通常是交叉熵损失:

Ll(θ)=1l∑i=1lCE(fθ(xi),yi)L_l(\theta) = \frac{1}{l} \sum_{i=1}^l \text{CE}(f_\theta(x_i), y_i)

其中,交叉熵定义为:

CE(fθ(xi),yi)=−∑c=1C1(yi=c)log⁡pθ(yi=c∣xi)\text{CE}(f_\theta(x_i), y_i) = - \sum_{c=1}^C \mathbb{1}(y_i = c) \log p_\theta(y_i = c | x_i)

其中,pθ(yi=c∣xi)=fθ(xi)cp_\theta(y_i = c | x_i) = f_\theta(x_i)_c 是模型预测的类别概率,1\mathbb{1} 是指示函数。

对于回归任务,监督损失可能是均方误差:

Ll(θ)=1l∑i=1l(fθ(xi)−yi)2L_l(\theta) = \frac{1}{l} \sum_{i=1}^l (f_\theta(x_i) - y_i)^2

2.3.2 无监督损失

无监督损失的设计依赖于算法和假设,常见形式包括:

一致性正则化
要求模型对未标注数据的不同扰动版本(例如通过数据增强生成的样本)产生一致的预测结果,从而增强模型的鲁棒性和泛化能力。这种方法假设模型对同一数据的不同增强版本应具有相似的输出,常用于图像、文本等数据。损失函数通过度量增强样本预测之间的差异来实现约束。

Lu(θ)=1u∑j=l+1l+u∣∣fθ(xj)−fθ(x~j)∣∣2L_u(\theta) = \frac{1}{u} \sum_{j=l+1}^{l+u} || f_\theta(x_j) - f_\theta(\tilde{x}_j) ||^2

其中,x~j=Augment(xj)\tilde{x}_j = \text{Augment}(x_j) 是 xjx_j 的增强版本(如图像的旋转、翻转、颜色抖动),∣∣⋅∣∣2|| \cdot ||^2 是均方误差或其他距离度量(如 KL 散度)。

示例 :在图像分类任务中,对一张未标注的猫咪图片 xjx_j,生成其旋转 90° 的版本 x~j\tilde{x}_j。模型对这两张图片的预测概率分布应尽可能接近,损失函数计算两者预测的均方误差以鼓励一致性。

熵最小化
鼓励模型对未标注数据的预测具有高置信度,即输出概率分布应尽量尖锐(低熵)。这种方法基于假设:模型对未标注数据的预测应倾向于某一个类别,而不是均匀分布,从而减少决策边界的不确定性。熵最小化常用于半监督学习中以提高模型的分类确定性。

Lu(θ)=1u∑j=l+1l+uH(fθ(xj))L_u(\theta) = \frac{1}{u} \sum_{j=l+1}^{l+u} H(f_\theta(x_j))

其中,H(p)=−∑cp(c)log⁡p(c)H(p) = - \sum_c p(c) \log p(c) 是熵,fθ(xj)f_\theta(x_j) 是预测概率分布。

示例 :在一组未标注的手写数字图片中,模型对某张图片预测的概率分布为 [0.9, 0.05, 0.05](高置信度),熵较低;若预测为 [0.4, 0.3, 0.3](低置信度),熵较高。损失函数鼓励前者以提高模型的确定性。

伪标签损失
为未标注数据生成伪标签,并将其作为监督信号进行训练。伪标签通过模型当前预测的最大概率类别生成,但通常只对高置信度预测生成伪标签,以避免错误标签的负面影响。这种方法利用了模型自身的预测能力,逐步提高未标注数据的利用率。

Lu(θ)=1u∑j=l+1l+u1(max⁡cfθ(xj)c>τ)⋅CE(fθ(xj),y^j)L_u(\theta) = \frac{1}{u} \sum_{j=l+1}^{l+u} \mathbb{1}(\max_c f_\theta(x_j)c > \tau) \cdot \text{CE}(f\theta(x_j), \hat{y}_j)

其中,τ\tau 是置信度阈值(如 0.95),仅使用高置信度的伪标签,y^j=arg⁡max⁡cfθ(xj)c\hat{y}j = \arg\max_c f\theta(x_j)_c 是伪标签,CE\text{CE} 是交叉熵损失。

示例 :在文本分类任务中,模型对一条未标注的评论预测类别概率为 [0.96, 0.03, 0.01],超过阈值 τ=0.95\tau=0.95,则生成伪标签“正面”,并用交叉熵损失训练模型以强化该预测。

图正则化
基于数据图的平滑性约束,假设在数据空间中相似的数据点(例如特征或几何距离接近的点)应具有相似的预测结果。通过构建数据点之间的相似性图,损失函数鼓励模型在图上的预测输出平滑一致,常用于结构化数据的半监督学习。

Lu(θ)=∑i,jwij∣∣fθ(xi)−fθ(xj)∣∣2L_u(\theta) = \sum_{i,j} w_{ij} || f_\theta(x_i) - f_\theta(x_j) ||^2

其中,wijw_{ij} 是数据点 xi,xjx_i, x_j 之间的相似性权重(例如基于高斯核的距离计算)。

示例 :在社交网络用户分类任务中,两个用户 xix_i 和 xjx_j 的特征(兴趣、交互记录)相似,相似性权重 wijw_{ij} 较高。图正则化损失鼓励模型对这两名用户的预测类别(如“活跃用户”)保持一致。

2.3.3 联合优化

联合优化旨在同时优化有监督损失和无监督损失,以充分利用标注数据和未标注数据的学习信号。其目标是将有监督的交叉熵损失(针对标注数据)与无监督损失(针对未标注数据)结合,通过权衡两者来提升模型的泛化性能。联合优化目标为:

min⁡θ[1l∑i=1lCE(fθ(xi),yi)+λ1u∑j=l+1l+uℓu(fθ(xj),fθ(x~j))]\min_\theta \left[ \frac{1}{l} \sum_{i=1}^l \text{CE}(f_\theta(x_i), y_i) + \lambda \frac{1}{u} \sum_{j=l+1}^{l+u} \ell_u(f_\theta(x_j), f_\theta(\tilde{x}_j)) \right]

其中,第一项 1l∑i=1lCE(fθ(xi),yi)\frac{1}{l} \sum_{i=1}^l \text{CE}(f_\theta(x_i), y_i) 是标注数据的交叉熵损失,衡量模型预测与真实标签的差异;第二项 1u∑j=l+1l+uℓu(fθ(xj),fθ(x~j))\frac{1}{u} \sum_{j=l+1}^{l+u} \ell_u(f_\theta(x_j), f_\theta(\tilde{x}_j)) 是无监督损失(如一致性正则化的均方误差),利用未标注数据的增强版本(如旋转、翻转等)来增强模型鲁棒性。超参数 λ\lambda 控制无监督损失的权重,平衡两者的贡献。合适的 λ\lambda 值通常通过实验调参确定,以避免过分偏向任一损失。

优化通常通过梯度下降或其变体(如 Adam 优化器)完成。联合损失的梯度为有监督损失梯度和无监督损失梯度的加权和:

∇θL=∇θLl+λ∇θLu\nabla_\theta L = \nabla_\theta L_l + \lambda \nabla_\theta L_u

其中,∇θLl\nabla_\theta L_l 是标注数据的有监督损失梯度,∇θLu\nabla_\theta L_u 是未标注数据的无监督损失梯度。梯度下降通过迭代更新参数 θ\theta,使联合损失逐步减小。在实践中,为稳定训练,常采用学习率调度或梯度裁剪等技术。此外,无监督损失的权重 λ\lambda 可设计为动态变化(如随训练轮数逐渐增加),以在训练初期更多依赖有监督信号,后期逐步引入未标注数据的贡献。

示例 :在图像分类任务中,假设有 100 张标注图片和 1000 张未标注图片。联合优化目标结合标注图片的交叉熵损失(预测类别与真实标签的差异)和未标注图片的一致性损失(原图与旋转版本预测的均方误差)。通过设置 λ=0.5\lambda=0.5 并使用 Adam 优化器,模型在每次迭代中计算两部分损失的梯度和,加权后更新参数,从而同时学习标注和未标注数据的信息。

2.4 数学推导示例:一致性正则化

以一致性正则化为例,假设无监督损失为:

Lu(θ)=1u∑j=l+1l+u∣∣fθ(xj)−fθ(x~j)∣∣2L_u(\theta) = \frac{1}{u} \sum_{j=l+1}^{l+u} || f_\theta(x_j) - f_\theta(\tilde{x}_j) ||^2

推导梯度:

对于单个未标注样本 xjx_j,损失为:

ℓu=∣∣fθ(xj)−fθ(xj)∣∣2=∑c=1C(fθ(xj)c−fθ(xj)c)2\ell_u = || f_\theta(x_j) - f_\theta(\tilde{x}j) ||^2 = \sum{c=1}^C (f_\theta(x_j)c - f\theta(\tilde{x}_j)_c)^2

对参数 θ\theta 求梯度:

∂ℓu∂θ=∑c=1C2(fθ(xj)c−fθ(xj)c)(∂fθ(xj)c∂θ−∂fθ(xj)c∂θ)\frac{\partial \ell_u}{\partial \theta} = \sum_{c=1}^C 2 (f_\theta(x_j)c - f\theta(\tilde{x}_j)c) \left( \frac{\partial f\theta(x_j)c}{\partial \theta} - \frac{\partial f\theta(\tilde{x}_j)_c}{\partial \theta} \right)

使用链式法则,计算 ∂fθ(xj)c∂θ\frac{\partial f_\theta(x_j)_c}{\partial \theta},这通常通过深度学习框架的自动求导完成。

通过这种方式,未标注数据通过一致性约束影响模型参数更新。

2.5 正则化与稳定性

  • 正则化 :无监督损失(如一致性正则化或熵最小化)起到正则化作用,防止模型过拟合少量标注数据。
  • 稳定性 :未标注数据的噪声可能导致优化不稳定,因此需要:
    • 置信度阈值(如伪标签中的 τ\tau)。
    • 平滑技术(如教师模型的指数移动平均)。
    • 权衡参数 λ\lambda 的动态调整。

3. 半监督学习的主要方法

以下是半监督学习的主要方法,包含详细的算法描述、伪代码、优缺点分析和适用场景。

3.1 自训练(Self-Training)

思想 :用标注数据训练初始模型,为未标注数据生成伪标签,然后将伪标签数据加入训练集,迭代优化。

算法步骤

  1. 初始化:用 DlD_l 训练模型 fθf_\theta。
  2. 伪标签生成:对 DuD_u 中的每个 xjx_j,预测 y^j=arg⁡max⁡cfθ(xj)c\hat{y}j = \arg\max_c f\theta(x_j)_c。
  3. 选择高置信度伪标签:若 max⁡cfθ(xj)c>τ\max_c f_\theta(x_j)_c > \tau,将 (xj,y^j)(x_j, \hat{y}_j) 加入 DlD_l。
  4. 重新训练:用更新后的 DlD_l 训练 fθf_\theta。
  5. 重复 2-4 直到收敛或达到最大迭代次数。

伪代码

复制代码
    def self_training(D_l, D_u, model, threshold, max_iterations):
    for t in range(max_iterations):
        # 训练模型
        model.fit(D_l)
        # 为未标注数据生成伪标签
        pseudo_labels = []
        for x in D_u:
            pred = model.predict(x)
            prob = max(softmax(pred))
            if prob > threshold:
                pseudo_labels.append((x, argmax(pred)))
        # 更新训练集
        D_l.extend(pseudo_labels)
        if len(pseudo_labels) == 0:
            break
    return model
    
    
    python
    
    
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-08-15/FxVSDlqwQEk4tju0XWnds3T8ByGP.png)

数学表达

  • 监督损失:Ll=1l∑i=1lCE(fθ(xi),yi)L_l = \frac{1}{l} \sum_{i=1}^l \text{CE}(f_\theta(x_i), y_i)。
  • 伪标签损失:Lu=1u′∑(xj,yj)∈Du′CE(fθ(xj),yj)L_u = \frac{1}{u'} \sum_{(x_j, \hat{y}j) \in D_u'} \text{CE}(f\theta(x_j), \hat{y}_j),其中 Du′D_u' 是高置信度伪标签集。

优点

  • 简单易实现,适用于大多数监督学习模型。
  • 逐步扩充训练集,适合小规模标注数据。

缺点

  • 伪标签噪声累积,可能导致模型偏差。
  • 对初始模型性能敏感,初始误差会传播。

适用场景

  • 数据标注成本高,但未标注数据丰富。
  • 模型预测置信度较高(如图像分类中的CNN)。

例子

  • 在CIFAR-10数据集上,用1000张标注图像训练初始模型,为剩余未标注图像生成伪标签,迭代训练后分类准确率从70%提升到85%。

3.2 协同训练(Co-Training)

思想 :假设数据有多个独立视图(如图像的颜色和纹理),每个视图足以进行分类。训练多个模型,互相为未标注数据生成伪标签。

算法步骤

  1. 将特征分为 KK 个视图:x=[x(1),x(2),…,x(K)]x = [x^{(1)}, x^{(2)}, \dots, x^{(K)}]。
  2. 对每个视图 kk,用 Dl(k)={(xi(k),yi)}i=1lD_l^{(k)} = {(x_i^{(k)}, y_i)}_{i=1}^l 训练模型 f(k)f^{(k)}。
  3. 对未标注数据 xjx_j,每个模型 f(k)f^{(k)} 预测伪标签 yj(k)\hat{y}_j{(k)}。
  4. 选择高置信度伪标签,将 (xj,y^j(k))(x_j, \hat{y}_j^{(k)}) 加入其他模型的训练集。
  5. 重复 3-4 直到收敛。

伪代码

复制代码
    def co_training(D_l, D_u, K, models, threshold, max_iterations):
    for t in range(max_iterations):
        for k in range(K):
            # 训练第k个视图的模型
            models[k].fit(D_l[k])
            # 为未标注数据生成伪标签
            pseudo_labels = []
            for x in D_u:
                pred = models[k].predict(x[k])
                prob = max(softmax(pred))
                if prob > threshold:
                    pseudo_labels.append((x, argmax(pred)))
            # 更新其他模型的训练集
            for m in range(K):
                if m != k:
                    D_l[m].extend(pseudo_labels)
        if not pseudo_labels:
            break
    return models
    
    
    python
    
    
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-08-15/fqk7UHu4hRmw1ACcLsyVKvazEO5l.png)

数学表达

  • 每个视图的监督损失:Ll(k)=1l∑i=1lCE(f(k)(xi(k)),yi)L_l^{(k)} = \frac{1}{l} \sum_{i=1}^l \text{CE}(f{(k)}(x_i{(k)}), y_i)。
  • 联合优化:最小化所有视图的损失,并通过伪标签共享信息。

优点

  • 利用多视图信息,减少单一模型的偏差。
  • 适合多模态数据(如图像+文本)。

缺点

  • 要求视图独立且互补,现实中难以满足。
  • 多个模型的训练和协调增加计算成本。

适用场景

  • 数据具有天然多视图结构(如网页的文本和链接)。
  • 多模态任务(如视频分类中的图像和音频)。

例子

  • 在网页分类中,视图1是网页文本,视图2是超链接结构。两个模型互相为未标注网页生成伪标签,最终分类F1分数提升10%。

3.3 图-based 方法(Graph-Based Methods)

思想 :将数据表示为图,节点是数据点,边表示相似性,通过标签传播为未标注数据分配标签。

算法步骤

构建图 G=(V,E)G = (V, E),节点 V={x1,…,xl+u}V = {x_1, \dots, x_{l+u}},边权重 wij=exp⁡(−∣∣xi−xj∣∣2/σ2)w_{ij} = \exp(-||x_i - x_j||^2 / \sigma^2)。

定义目标函数,平衡标注数据损失和平滑性:

min⁡f∑i=1l(f(xi)−yi)2+λ∑i,jwij(f(xi)−f(xj))2\min_f \sum_{i=1}^l (f(x_i) - y_i)^2 + \lambda \sum_{i,j} w_{ij} (f(x_i) - f(x_j))^2

优化目标函数,得到未标注数据的标签 f(xj)f(x_j)。

伪代码

复制代码
    def graph_based_ssl(D_l, D_u, sigma, lambda_reg):
    # 构建相似性图
    W = compute_similarity_matrix(D_l + D_u, sigma)
    # 初始化标签
    f = zeros(l + u)
    for (x_i, y_i) in D_l:
        f[i] = y_i
    # 标签传播
    for t in range(max_iterations):
        f = (1 - lambda_reg) * f + lambda_reg * W @ f
        # 固定标注数据的标签
        for (x_i, y_i) in D_l:
            f[i] = y_i
    return f[l+1:]
    
    
    python
    
    
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-08-15/QJCyGmXKFfWv0ac1AxNIe4z56L7U.png)

数学推导

目标函数为:

J(f)=∑i=1l(fi−yi)2+λ∑i,jwij(fi−fj)2J(f) = \sum_{i=1}^l (f_i - y_i)^2 + \lambda \sum_{i,j} w_{ij} (f_i - f_j)^2

矩阵形式:

J(f)=(f−y)T(f−y)+λfTLfJ(f) = (f - y)^T (f - y) + \lambda f^T L f

其中,L=D−WL = D - W 是图拉普拉斯矩阵,DD 是度矩阵。

优化解:

f=(I+λL)−1yf = (I + \lambda L)^{-1} y

其中 yy 是部分已知标签。

优点

  • 充分利用数据间的几何结构,适合流形假设。
  • 标签传播直观,理论支持强。

缺点

  • 图构建和优化计算复杂度高(O((l+u)2)O((l+u)^2) 或更高)。
  • 对相似性度量(如 σ\sigma)敏感。

适用场景

  • 数据具有明确几何结构(如社交网络)。
  • 小规模数据集(大规模数据计算成本高)。

例子

  • 在社交网络中,节点是用户,边是好友关系。用少量标注用户(如“正向情感”)通过标签传播预测未标注用户的态度。

3.4 一致性正则化(Consistency Regularization)

思想 :模型对未标注数据的扰动版本(如数据增强、噪声)应有相似的预测。

算法步骤

对未标注数据 xjx_j,生成扰动版本 x~j=Augment(xj)\tilde{x}_j = \text{Augment}(x_j)。

计算一致性损失:

Lu=1u∑j=l+1l+u∣∣fθ(xj)−fθ(x~j)∣∣2L_u = \frac{1}{u} \sum_{j=l+1}^{l+u} || f_\theta(x_j) - f_\theta(\tilde{x}_j) ||^2

联合优化:

L=Ll+λLuL = L_l + \lambda L_u

代表算法

Π-Model

对同一未标注数据,模型在两次随机扰动下的预测一致。

损失:

Lu=1u∑j∣∣fθ(xj;ξ1)−fθ(xj;ξ2)∣∣2L_u = \frac{1}{u} \sum_{j} || f_\theta(x_j; \xi_1) - f_\theta(x_j; \xi_2) ||^2

其中 ξ1,ξ2\xi_1, \xi_2 是随机扰动。

Mean Teacher

维护教师模型(参数为学生模型的指数移动平均):

θt′=αθt′+(1−α)θt\theta_t' = \alpha \theta_t' + (1 - \alpha) \theta_t

学生模型预测与教师模型一致:

Lu=1u∑j∣∣fθ(xj)−fθ′(xj)∣∣2L_u = \frac{1}{u} \sum_{j} || f_{\theta}(x_j) - f_{\theta'}(x_j) ||^2

FixMatch

对未标注数据进行弱增强(如平移)生成伪标签,若置信度 max⁡cfθ(xjweak)c>τ\max_c f_\theta(x_j^{\text{weak}})_c > \tau,则用强增强(如颜色抖动)计算一致性损失:

Lu=1u∑j1(max⁡cfθ(xjweak)c>τ)⋅CE(fθ(xjstrong),y^j)L_u = \frac{1}{u} \sum_{j} \mathbb{1}(\max_c f_\theta(x_j^{\text{weak}})c > \tau) \cdot \text{CE}(f\theta(x_j^{\text{strong}}), \hat{y}_j)

伪代码(FixMatch)

复制代码
    def fixmatch(D_l, D_u, model, tau, lambda_u, max_iterations):
    for t in range(max_iterations):
        # 监督损失
        L_l = 0
        for (x, y) in D_l:
            L_l += cross_entropy(model(x), y)
        # 无监督损失
        L_u = 0
        for x in D_u:
            x_weak = weak_augment(x)
            prob = softmax(model(x_weak))
            if max(prob) > tau:
                pseudo_label = argmax(prob)
                x_strong = strong_augment(x)
                L_u += cross_entropy(model(x_strong), pseudo_label)
        # 联合优化
        loss = L_l + lambda_u * L_u
        optimize(model, loss)
    return model
    
    
    python
    
    
![](https://ad.itadn.com/c/weblog/blog-img/images/2025-08-15/Q6E1d259gH0zRc7oZjB84CxluyAe.png)

优点

  • 与深度学习高度兼容,适合高维数据(如图像、文本)。
  • 一致性约束简单且有效,易于扩展。

缺点

  • 需要精心设计增强策略(如弱增强 vs 强增强)。
  • 计算成本高(多次前向传播)。

适用场景

  • 图像分类、文本分类等深度学习任务。
  • 未标注数据丰富,模型对扰动敏感。

例子

  • 在ImageNet上,FixMatch用10%标注数据(约13万张图像)训练ResNet,准确率接近全监督学习的90%。

3.5 生成模型方法(Generative Models)

思想 :利用生成模型(如变分自编码器 VAE 或生成对抗网络 GAN)学习数据和标签的联合分布 p(x,y)p(x, y),为未标注数据推断标签。

算法步骤

  1. 训练生成模型,最大化 p(x,y)p(x, y) 或其近似(如变分下界)。
  2. 对未标注数据 xjx_j,推断标签 y^j=arg⁡max⁡yp(y∣xj)\hat{y}_j = \arg\max_y p(y | x_j)。
  3. 用推断的标签联合训练分类器。

数学表达(以VAE为例)

目标:最大化对数似然 log⁡p(x,y)\log p(x, y)。

引入隐变量 zz,变分下界为:

L=Eq(z∣x,y)[log⁡p(x,y∣z)]−KL(q(z∣x,y)∣∣p(z))\mathcal{L} = \mathbb{E}_{q(z|x,y)}[\log p(x,y|z)] - \text{KL}(q(z|x,y) || p(z))

未标注数据的损失:近似 ∑yp(y)L(x,y)\sum_y p(y) \mathcal{L}(x, y)。

优点

  • 能捕捉复杂数据分布,适合生成和分类联合任务。
  • 理论上优雅,适合复杂数据。

缺点

  • 训练生成模型复杂,计算成本高。
  • 对数据分布假设强,稳定性较低。

适用场景

  • 数据分布复杂(如语音、医学影像)。
  • 同时需要生成和分类(如生成标注样本)。

例子

  • 在语音识别中,VAE学习语音特征和标签的联合分布,为未标注语音推断标签,提升转录准确率。

4. 半监督学习的典型应用

半监督学习在多个领域有广泛应用,以下是详细的场景描述、方法选择和实际案例。

4.1 图像分类

场景

  • 数据:少量标注图像(如CIFAR-10的4000张标注图像)和大量未标注图像(如网络爬取的图片)。
  • 任务:将图像分类为特定类别(如“猫”“狗”)。

方法

  • FixMatch :结合弱增强和强增强,生成伪标签并强制一致性。
  • MixMatch :用MixUp混合标注和未标注数据,生成伪标签。
  • UDA :使用强数据增强(如RandAugment)提升一致性。

案例

  • 数据集 :CIFAR-10(5万张图像,10类)。
  • 设置 :4000张标注图像(每类400张),其余未标注。
  • 结果 :FixMatch在ResNet-50上达到95%准确率,接近全监督学习的96%。
  • 细节 :弱增强为随机平移,强增强包括颜色抖动、裁剪和Cutout,置信度阈值 τ=0.95\tau = 0.95,无监督损失权重 λ=1\lambda = 1。

挑战与解决方案

  • 挑战:增强策略对性能影响大。
  • 解决方案:自动搜索增强策略(如AutoAugment)。

4.2 自然语言处理(NLP)

场景

  • 数据:少量标注文本(如IMDB的情感分析数据集,2500条标注评论)和大量未标注文本(如网络论坛帖子)。
  • 任务:情感分类(如“正面”“负面”)或语义分析。

方法

  • UDA :对未标注文本应用回译(back-translation)增强,强制一致性。
  • 伪标签 :用BERT为未标注文本生成伪标签,迭代训练。
  • 协同训练 :结合文本和元数据(如用户评分)作为多视图。

案例

  • 数据集 :IMDB(5万条评论,2类)。
  • 设置 :2500条标注评论,剩余未标注,额外爬取10万条未标注评论。
  • 结果 :UDA结合BERT达到88%准确率,接近全监督的90%。
  • 细节 :回译增强将英文评论翻译为法语再译回英文,伪标签置信度阈值 τ=0.9\tau = 0.9。

挑战与解决方案

  • 挑战:文本增强可能改变语义。
  • 解决方案:使用语义保持增强(如同义词替换)。

4.3 语音识别

场景

  • 数据:少量标注语音(如LibriSpeech的100小时转录音频)和大量未标注语音(如YouTube音频)。
  • 任务:语音转文字。

方法

  • 自训练 :用标注数据训练初始模型,为未标注语音生成伪转录。
  • Mean Teacher :用教师模型强制一致性,处理语音噪声。
  • 生成模型 :用VAE建模语音特征和文本的联合分布。

案例

  • 数据集 :LibriSpeech(960小时语音)。
  • 设置 :100小时标注语音,860小时未标注。
  • 结果 :Mean Teacher结合Wav2Vec 2.0降低字错误率(WER)从15%到10%。
  • 细节 :数据增强包括添加背景噪声和速度扰动,教师模型平滑参数 α=0.99\alpha = 0.99。

挑战与解决方案

  • 挑战:语音噪声影响伪标签质量。
  • 解决方案:使用噪声鲁棒的增强和置信度过滤。

4.4 医疗影像分析

场景

  • 数据:少量标注医学图像(如1000张标注的肺部CT)和大量未标注图像(如医院数据库)。
  • 任务:检测病变(如肺结节)。

方法

  • 图-based方法 :构建图像相似性图,传播病变标签。
  • FixMatch :用弱增强(如平移)和强增强(如对比度调整)生成伪标签。
  • Mean Teacher :用教师模型处理未标注图像的变异性。

案例

  • 数据集 :LUNA16(肺结节检测数据集)。
  • 设置 :1000张标注CT,5000张未标注CT。
  • 结果 :FixMatch结合3D CNN提升F1分数从0.75到0.85。
  • 细节 :弱增强为随机平移,强增强包括旋转和噪声,置信度阈值 τ=0.95\tau = 0.95。

挑战与解决方案

  • 挑战:医学图像标注稀缺,伪标签错误代价高。
  • 解决方案:结合领域知识(如解剖学约束)过滤伪标签。

4.5 推荐系统

场景

  • 数据:少量用户评分(如电影评分)和大量未标注用户行为(如浏览记录)。
  • 任务:推荐个性化内容。

方法

  • 协同训练 :结合用户行为(浏览)和元数据(电影类型)作为多视图。
  • 图-based方法 :构建用户-物品图,传播评分标签。
  • 伪标签 :为未评分物品生成伪评分。

案例

  • 数据集 :MovieLens(100万条评分)。
  • 设置 :10万条标注评分,1000万条未标注行为数据。
  • 结果 :图-based方法提升推荐准确率从80%到85%。
  • 细节 :图边权重基于用户行为相似性,标签传播迭代10次。

挑战与解决方案

  • 挑战:用户行为数据稀疏。
  • 解决方案:结合矩阵分解和半监督学习。

5. 总结与学习建议

5.1 总结

  • 数学框架 :半监督学习通过联合优化监督损失和无监督损失,利用标注和未标注数据。核心是设计合适的无监督损失(如一致性正则化、伪标签、图正则化),并通过梯度下降优化。
  • 主要方法 :自训练、协同训练、图-based方法、一致性正则化、生成模型各有优劣,适用于不同场景。FixMatch等现代方法在深度学习中表现突出。
  • 典型应用 :图像分类、NLP、语音识别、医疗影像、推荐系统等领域,半监督学习显著降低标注成本并提升性能。

5.2 学习建议

  1. 理论学习

    • 深入理解损失函数设计和优化,推荐书籍《Semi-Supervised Learning》(Chapelle et al.)。
    • 学习矩阵分解和图论,理解图-based方法。
  2. 论文阅读

    • FixMatch(Sohn et al., 2020)
    • Mean Teacher(Tarvainen & Valpola, 2017)
    • UDA(Xie et al., 2020)
  3. 实践项目

    • 在CIFAR-10上实现FixMatch,比较不同增强策略。
    • 用Hugging Face的BERT实现文本分类的伪标签方法。
    • 尝试医疗影像数据集(如CheXpert)的小规模实验。
  4. 工具与资源

    • 框架 :PyTorch、TensorFlow、Hugging Face。
    • 数据集 :CIFAR-10、IMDB、LibriSpeech、LUNA16。

全部评论 (0)

还没有任何评论哟~