Advertisement

知识蒸馏: Distilling the Knowledge in a Neural Network(中)

阅读量:

下一篇已深入探讨了论文的第一二章内容,其中对知识蒸馏的核心机制进行了系统化地阐述。

本文将介绍论文后续图表部分的内容。具体而言,在MNIST数据集和语音识别数据集上分别展示了知识蒸馏算法的实验结果。具体比较了通用模型与专家模型之间的差异性,并分析了多专家系统的优势。

实验部分

MNIST数据集

之所以被称为Preliminary/预实验的原因是因为MNIST数据集相对简单,并且所使用的模型规模较小。通常会将其作为初步探索使用。

该研究在MNIST数据集上采用了相对简单的架构。具体来说是包含两个隐藏层的神经网络结构。每个隐藏层均配置了1200个带有ReLU激活函数的神经元。为了防止过拟合问题,在网络设计中采用了DropOut技术。为了探索知识蒸馏的效果,在论文中我们采用了单个大型神经网络作为教师模型,并将其应用于所有6万条训练样本。通过采用DropOut技术和权重约束(如文献[5]所述)实现了强大的正则化效果。 Dropout技术可被视为训练大量共享参数模型的一种方法

该模型在测试集上的表现出现了显著问题。论文还设计了一个较为紧凑的小型网络架构(即每个隐藏层仅包含800个ReLU单元且未采用正则化方法),该小型网络在训练过程中出现了146个错误。然而,在大网生成的学生目标分布上引入软目标任务(即当温度参数设置为20时),通过匹配这些软目标输出来调节小型网络的行为,则使该小型网络仅出现了74个错误。这一发现表明软目标能够有效地将知识从教师网络转移到学生网络中(即使转移集合中并不包含任何平移操作)。

接下来论文继续降低网络复杂度
随后,在经过压缩后的网络中存在至少300个神经元时,在其两个隐藏层中取得一致的效果。
当这一数量被大幅减少至每层仅包含30个神经元时,
温度范围在2.5至4之间表现出了显著的优势,
这明显优于更高或更低温度的情况。

  • 每个隐藏层仅约300个神经元时,在温度值大于8的情况下都能带来相似的结果。
  • 如果将每个隐藏层缩减至仅30个神经元(即非常小规模的网络),研究发现2.5至4℃的最佳温度范围(超出或低于该范围时效果欠佳),但并未提供具体达到何种效果的数据。

随后,在进行知识蒸馏的过程中

进一步的实验:
大部分错误源于模型对3类学习到的偏差过低。当我们将该偏差增加至3.5(该值在测试集整体性能优化中表现最佳),精简后的模型产生了109个错误,在这些错误中涉及数字3的有14个。因此,在适当调整偏差后,即使在整个训练过程中从未见过数字3,精简后的模型仍能准确识别出约98.6%的测试用例中的数字3。

If the transfer set consists solely of digits seven and eight sourced exclusively from the training set, then achieving an error rate of forty-seven point three percent is observed in terms of test results; however, upon minimizing bias adjustments for both seven and eight by an increment of seven point six with an aim to enhance testing efficiency, this figure significantly reduces down to thirteen point two percent as measured in tests.

Automatic Speech Recognition ASR数据集

In this section, we examine how ensemble learning techniques applied to Deep Neural Network (DNN) acoustic models impact their performance in Automatic Speech Recognition (ASR). Our results demonstrate that the proposed knowledge distillation strategy successfully compresses an ensemble of DNN models into a single, more efficient model, which outperforms standalone models trained directly on the same dataset. Subsequently, in our research, we employ this method to validate its effectiveness in the domain of automatic speech recognition. We utilize datasets such as MNIST to test our approach and evaluate its performance. Due to MNIST's simplicity, we opt for additional datasets with different modalities and larger scales to further validate our findings regarding knowledge transfer and model compression.

现有的语音识别系统主要采用深度神经网络(DNN)来将波形特征提取出的简短时域上下文映射到隐式马尔可夫模型(HMM)离散状态的概率分布上[4]。具体而言,在每个时间点上,DNN会生成一个关于三音节状态集群的概率分布,并且解码器会确定一条通过HMM状态的最佳路径,这条路径在平衡使用高概率状态以及生成符合语言模型的转录之间达到了最佳折中效果。
尽管在理论上可以通过综合考虑解码器(以及语言模型)并通过对所有可能路径进行积分的方式来训练DNN以达到最佳效果,但通常的做法是通过在局部范围内最小化交叉熵损失来进行帧级分类训练,DNN在这种情况下预测当前状态的概率分布与经过强制对齐后的实际状态标签之间的差异最小化

where θ represents the set of parameters in our acoustic model P, which corresponds to the probability distribution over HMM states given an observation at time t and frame st. This distribution reflects the likelihood that frame st corresponds to state ht when forced alignment with the correct word sequence is applied. The model utilizes a distributed stochastic gradient descent approach for training.

一般用于语音识别的深度神经网络(DDN)都采用将语音格式(waveform)中的时序上下文特征(temporal context of featrurue)映射至离散隐马尔可夫模型(HMM)的概率分布的方式进行建模。这种HMM的核心思想在于当前状态仅受前一状态的影响,在语言处理领域即为基于n-gram的语言模型,在此框架下前面一个词会用来预测后面可能出现的一个词的概率分布。通过这种方式构建完整的语义序列达到语音转文字的目的。
在上文中,“find a path through the HMM states”我的理解是指搜索并找到一条最符合语法规律的最佳路径。
大体上是这样子的吧?关于语音处理的部分我还不是很清楚。
训练这个DNN的一般方法是以帧为单位进行分类任务来进行学习过程的理解——即将连续语音信号中的每一帧数据抽象成独立的任务进行分类工作。该分类任务采用交叉熵损失函数作为优化目标,并基于上述公式对模型参数θ进行迭代优化以实现目标函数最小化。

We employ an architecture comprising eight hidden layers, each equipped with 2560 rectified linear units, culminating in a final softmax layer comprising 14,000 labels (HMM targets ht). The input comprises 26 frames derived from Mel-scaled filter bank coefficients, each contributing 40 features with a temporal step size of 1. While predicting the HMM state for the subsequent frame after advancing by a time interval of approximately ten milliseconds. The total parameter count approximates to around eighty-five million. This rendition represents an older version utilized within Android voice search systems and serves as a formidable baseline. To train our DNN acoustic model, we utilize approximately two thousand hours of spoken English data, generating over seven hundred million training instances. Our system achieves a frame accuracy rate of fifty-eight point nine percent and attains a word error rate (WER) of ten point nine percent on our validation set.

  • 基于Android系统的语音搜索模型,在该领域属于较早时期使用的基准模型。
  • 该模型包含8个隐藏层(即各层配置了2560个ReLU神经元)。
  • 在识别任务期间涉及14, 也就是说涉及了多少类别?
  • 输入的数据类型尚未完全理解。
  • 模型参数规模为8.5 million(即约8.5百万),具体来说就是包含了大约3亿多条数据样本。
  • 使用了7亿多条训练数据,并耗时两天多的时间完成训练过程。
  • 其准确率达到58.9%,单词错误率维持在10.9%左右。
在这里插入图片描述

下面就比较了单个模型,集成模型和知识蒸馏的效果。

集成模型。我们训练了10个不同的模型来预测P(ht|st; θ),采用了相同的架构和训练流程(与基准方法相同)。这些模型被随机初始化为不同的初始参数值,并发现这足以使集成的平均预测显著优于单个模型的预测效果。我们还尝试通过让每个模型看到不同的数据集来增加多样性[1](相当于训练了十个互不关联但具有多样性的模型),但发现这种方法对结果没有明显改善作用[2]。因此我们选择了更为简单的方法——仅使用统一的数据集进行训练

Knowledge distillation技术如表所示,在经过该过程后并未显著降低性能表现

最后一段提到文献8也进行了相关研究。然而,在温度设置为1的前提下利用大量未标记数据进行知识蒸馏的过程中,并未取得显著成效。具体而言,在两者均采用硬标签训练的情况下,在大模型与小模型之间存在的误差差距上只减少了约28%。

通用+专用模型(Training ensembles of specialists on very big datasets)

其实在一定程度上与知识蒸馏关联不大;尽管可能存在某种训练策略上的变化。该论文提出了一种基于通用框架与专用优化器相结合的集成方案。

Creating a group of models represents a straightforward approach for leveraging parallel computing capabilities. The common concern regarding this method - that extensive computations would be needed during testing - can effectively be addressed through distillation techniques. However, there's another significant issue with such ensembles: if each individual model constitutes a sizable neural network and the dataset itself may also be vast, the computational burden during training becomes significant despite being relatively straightforward to distribute computations.

In this section, we provide an illustration of such a dataset and demonstrate how specialist models, each dedicated to a distinct confusable class subset, can minimize the computational resources needed to train an ensemble. The primary challenge with specialists that concentrate on distinguishing fine details is that they tend to overfit easily, and we explain strategies to mitigate this issue through the use of soft targets.

论文指出,在训练模型时存在多种并行策略。其中一种方法是采用集成架构,在主训练过程中各自主训练各自的数据集以实现协同优化。当进入推理环节中时,我们可以将这些集成网络作为教师节点,并通过知识蒸馏技术构建一个学生网络。这个学生网络可发展成为一个轻量级的学生网络用于推理与部署。

在这种情况下,并行通常遵循专家系统模式来进行任务分配。具体来说就是需要将大规模的数据集划分为多个细分领域,并由各个模型分别管理一个细分领域。例如汽车领域可能分为不同品牌或功能类型。

论文探讨了soft target如何缓解这种专家系统中的过拟合问题,并解释道这是因为这些大模型在小数据集上进行领域特定训练(即仅针对特定领域数据进行训练,并且没有接触其他类型的数据),因此容易发生过拟合。

JFT数据集

在本研究中,我们引入了一个新的数据集来评估该问题的性能——JFT。该数据集由谷歌内部提供,并包含1亿张标注图片和1.5万种标签。在我们开展这项研究时……

JFT属于谷歌内部的一个专门的数据集合。该集合涵盖了多达1.5万个不同的分类标签,并积累了超过一亿份标注数据。基于该数据集合进行训练的Google基准模型需要耗时六个月才能完成训练过程。

然后论文介绍了训练这个baseline模型的几种并行方式:

在训练过程中,并行计算系统采用了将神经网络分配到不同核心组并处理不同mini-batch的方法。每个replica在当前mini-batch上计算平均梯度,并将其传递给分片参数服务器以获取更新值。这些更新值反映了自上次参数发送以来服务器收到的所有梯度。

Second, each replica is spread over multiple cores by allocating different subsets of the neurons to each core.

Ensemble training is also a form of parallelism that can encapsulate the other two types, though it requires several more additional cores. Leaving us with no option but to look for a faster way to improve the baseline model, we needed an optimized solution that could scale effectively. Below, we will introduce the specific optimization methods tailored for this type of parallelism, focusing on enhancing computational efficiency through dedicated model architectures.

Specialist Models

Whenever there exists a large number of categories, it becomes sensible to structure such a complex architecture as an ensemble composed primarily of one general-purpose model trained comprehensively alongside numerous specialized models. Each specialized model focuses exclusively on datasets characterized by a high concentration of examples from categories that are highly ambiguous or difficult to distinguish (e.g., various varieties of mushrooms). By aggregating all categories that such a specialist finds unimportant into a single 'dustbin' category, its softmax distribution can be significantly reduced.

在论文中描述的一种情况是,在这一类数据集具有丰富的分类需求的情况下(例如JFT),其包含超过15,000个不同的种类(类别)。由于单个模型难以全面覆盖所有分类任务(即无法做到对每一个类别都进行准确识别),因此需要采用多组模型协同工作以分别处理不同的子类问题。

因此,在论文所提出的方法中存在一种被称为通用模型(generalized\ model)以及多个专门针对特定领域的小专家型模型(specialized\ models)。这些小专家型模型的作用在于识别通用模型难以准确分类的细分数据类型(confusable\ data\ subsets)。

在这些专业领域模型中,在其生成的概率分布中概率较低的类别(即logit值较小)可以被视为该专业领域模型无法提供高置信度的对象。这意味着该专业领域模型对无法准确识别的对象表现出较低的自信。然而,在这种情况下,此类对象通常会被归类到一个通用类别中:'垃圾类'(dustbin class)。

为了减少过拟合并促进学习低层特征检测器的努力共享,每个专门模型将从一般模型中继承权重。

为了减小过拟合现象的发生,在构建专门模型时采用了通用模型的权重设置。这些专门模型在构建过程中不仅模仿了通用模型对颜色、角和曲线等基础图像特征的识别能力。

These weights are adjusted following specialized training, where half their examples originate from a dedicated subset and are complementarily sourced through random sampling from other available samples. Once trained, we can correct for potential biases introduced during this process by incrementing a bias correction term associated with a designated 'dustbin' class. This correction involves adding to this term a logarithmic value representing how much more prevalent or overrepresented our primary interest group has been relative to an assumed baseline distribution.

  • 在训练这些specialist model的数据集合时, 其中一半是由该领域专门收集的数据构成, 半数以上则来自整个原始数据集中的随机采样. 为了防止过拟合, 这两种方法通常会配合使用.
    • 其后一段阐述道: 其中一半来自dustbin数据集涵盖了大部分类别, 比如这个模型专为汽车设计, 汽车细分领域占有一半. 剩下的15000个类别中有14000个也被覆盖到了. 因此在单个样本上的分布就会显得稀疏了, 为了减小过拟合风险, 将这些dustbin数据中的类别偏差(bias)提升至与主模型一致水平, 这样一来其分类结果(logit)也会相应得到提升, 最终就能有效降低过拟合风险.

如何划分数据集( Assigning classes to specialists)

论文的主要目的是针对通用模型容易产生混淆的数据类别来进行专用模型数据集的分类工作然后基于这些类别的混淆矩阵来进行聚类分析(这里指的是将通用模型输出各类别的混淆矩阵中的每一个项对应的数据作为一个独立类别从而形成稳定的聚类群组)然而该论文并未采用上述方法而是采用了另一种聚类方法。

该论文提出了一种新的聚类方法即基于通用模型输出各类别的混淆矩阵将其中每一个项所对应的作为独立的一个类别从而形成稳定的群组这种思路类似于传统的层次聚类方法但与现有的某些聚类算法有所不同。

Especially, we utilize a clustering algorithm on the covariance matrix derived from our generalist model's predictions. This approach identifies sets of classes S^m, which are frequently predicted together, and employs these as target classes for one of our specialist models, denoted by m. To achieve this, we implemented an online variant of the K-means algorithm targeting specific columns within the covariance matrix. As demonstrated in Table 2, this method yielded reasonable clusters. Various alternative clustering techniques were tested, yielding comparable outcomes.

该采用经典的K-means方法来进行聚类分析。研究中发现采用其他聚类方法效果大体相似。

最终预测出来的分类情况就如表所示:

在这里插入图片描述

全部评论 (0)

还没有任何评论哟~