Advertisement

【论文笔记】TinyBERT: Distilling BERT for Natural Language Understanding

阅读量:

To enhance inference speed and minimize model size while preserving accuracy, we initially introduce a specialized Transformer distillation technique tailored for knowledge distillation (KD) of Transformer-based models. Subsequently, we present a refined two-phase learning architecture for TinyBERT, implementing Transformer distillation across both pretraining and task-specific learning stages.


The complexity of Transformers escalates with increased model depth (greater parameter counts), leading to enhanced redundancy. However, structural simplification results in significant performance impacts. This paper's core contribution lies in devising a unique Transformer-specific distillation methodology.


There have been many model compression techniques (Han et al., 2016) proposed to accelerate deep model inference and reduce model size while maintaining accuracy. The most commonly used techniques include quantization (Gong et al., 2014),weights pruning (Han et al., 2015), and knowledge distillation (KD) (Romero et al., 2014).


  • 模型量化:经过量化算法对数值进行压缩和解压缩,从而达到减小模型大小和加速运算的目的。几乎所有量化方法都能实现压缩,但是并不是所有量化方法都能实现加速。量化实现加速的两个重要条件,首先量化算法要简单不引入过多额外计算开支,其次硬件方面适用运算库进行运算加速。因此量化在实用中比较难,尤其对于我这种不太懂硬件的。
  • 模型剪枝:根据我短时间的了解,剪枝方法分为结构化剪枝和非结构化剪枝。结构化剪枝即在大粒度上对模型进行结构级修剪,如剪掉卷积层中多余的卷积核(这似乎是我查阅资料中见过的唯一用法);非结构化剪枝即粒度更低的参数级修剪,根据L1范数或L2范数等对全体参数进行评估,按照比例将最不重要一部分参数置零,最终得到稀疏模型(最终模型的结构不发生变化,参数量也不会发生变化,仅仅是变得稀疏),通过一些稀疏矩阵分解的方法能够达到压缩的目的,但是只有在特定的支持稀疏矩阵运算的硬件上才能达到加速,因此实用性相对结构化剪枝更差。
  • 知识蒸馏:即训练学生教师模型的方法。值得一提的是知识蒸馏和结构化剪枝具有一定的相似性,二者都追求在模型结构方面进行缩减从而得到一个去冗余的小模型。剪枝的优势在于可以逐层剪枝并且在训练过程中剪枝,蒸馏的优势在于可以获取更多的泛化信息。二者或许可以相辅相成,也可能殊途同归。

双阶段预训练与微调体系主要包含两个关键环节:前期采用大规模无监督文本语料库进行基础模型构建;后续则聚焦于特定任务数据进行优化。这一过程显著提升了目标模型在微调阶段的表现难度。因此,在整个训练流程中都需要设计有效的知识蒸馏(KD)策略。


预训练模型的训练主要包含两个关键环节:前期采用无监督学习方式进行基础模型构建;后期则基于特定任务数据进行优化。值得注意的是,在蒸馏策略的设计上需要特别关注的是这两个不同阶段的具体实现差异。


我们设计了三种类型的损失函数以适应BERT层的不同表示形式。
具体来说:

  1. 来自嵌入层的输出;
  2. 自Transformer层推导出的状态和注意力矩阵;
  3. 由预测层产生的logits。

初始阶段的学生旨在通过蒸馏方法复制教师模型的行为来理解输入与输出之间的关系。

尽管随着模型复杂程度的提升,

从输入到输出的变化呈现出显著跨度,

因此该研究建议专注于训练若干中间层神经网络以获取更丰富的特征。

这些改进不仅有助于丰富学生模型的知识储备,

同时也带来了实现上的挑战。


we introduce an innovative two-stage learning paradigm that incorporates both the general distillation mechanism and the task-specific distillation approach, as demonstrated in Figure 1.

在这里插入图片描述

假设学生模型包含M个Transformer层(此处M<N),而教师模型包含N个Transformer层。我们从教师模型中选择M个(其中M<N)Transformer层来进行知识蒸馏。定义函数n = g(m),其中g是从学生层到教师层的索引映射函数。第m个学生层从第g(m)个教师层中获取信息。
.
形式上,在蒸馏过程中学生通过最小化以下目标函数来实现知识获取:

在这里插入图片描述

其中Llayer代表某个特定模型层(如Transformer层或嵌入层)所使用的损失函数;fm(x)代表第m个层次所诱导的行为函数;λm是一个超参数变量,用于衡量第m个层次蒸馏过程的重要性程度。由于涉及到模型层数的缩减以及选择若干关键中间节点进行学习训练的问题存在,在这种情况下就必然涉及到学生层与教师层之间的对应关系问题;至于哪种对应关系能够达到最佳效果,则需通过实验来验证得出结论;最终实现的蒸馏损失函数则是各中间层次贡献而成;在文章所提供的实验部分中所采用的方法是均匀分配策略:将原本12个层次压缩至4个蒸馏节点。


The innovative transformer-based layer-wise distillation method comprises both attention mechanisms-based and hidden states-based approaches, as illustrated in Figure 2.

这里是引用

The student acquires the matrices of multi-head attention within a teacher network, and the objective is formulated as:

在这里插入图片描述

where h represents the number of attention heads, Ai ∈ Rl×l denotes the attention matrix associated with the i-th head of teacher or student, l stands for the input text length, and MSE() refers to the mean squared error loss function.
*
*
*
在文章中反复强调的一点是Bert的注意力矩阵包含了大量丰富的语义信息,并且非常重要。因此将其作为蒸馏过程中的一个关键节点之一,并采用平均方法计算损失函数。另外值得注意的是,在此过程中无需对矩阵进行Softmax操作。

Besides attention-driven knowledge distillation, we additionally employ the Transformer layer's output to extract such information. The goal of this approach is outlined below:

在这里插入图片描述

隐层状态蒸馏的一项主要挑战在于:教师模型与学生模型所处的隐层维度可能并不相同(Differ)。为此作者提出了一种创新方法:通过引入可学习矩阵W实现了这一目标。


改写说明

在这里插入图片描述

where ES和ET分别表示学生网络和教师网络的嵌入,在本文中它们与隐藏状态矩阵具有相同的维度。We矩阵是一个线性变换,在功能上类似于Wh. 在Bert架构中,嵌入层的维度与隐藏状态层具有相同的维度。因此此蒸馏节点采用的方式与隐状态蒸馏一致。


除了模仿中间层的行为之外,在Hinton等人的(2015)研究中我们还采用知识蒸馏的方式拟合教师模型的预测结果。
具体而言,在学生网络的输出与教师网络输出之间计算了软交叉熵损失,并对其进行最小化。
如图1所示,在MNIST数据集上测试发现该方法表现优异。

在这里插入图片描述

其中zS和zT分别表示学生模型和教师模型预测得到的logits向量。CE代表交叉熵损失函数。t表示温度参数,在实验中我们发现当t=1时表现优异。

  • 这里就是蒸馏分类器的logits, 是最标准的蒸馏方法. 文中提到了软交叉熵, 没太明白什么意思.

这里是引用

在微调阶段实施蒸馏过程时,作者采用了数据增强技术。具体而言,在这一过程中,作者主要基于BERT模型和GloVe向量对词语进行同义词替换操作。


实验结果

在这里插入图片描述

The introduced two-step TinyBERT learning framework includes three core components: General Distillation, Task-specific Distillation, and Data Augmentation. Performance results from eliminating each specific training phase are examined and displayed in Table 2.

在这里插入图片描述

The empirical results demonstrate that all three procedures are essential to the success of our proposed method. Specifically, both Transfer Learning (TL) and Domain Adaptation (DA) exhibit comparable performance across four distinct tasks. Notably, task-specific approaches, particularly TL and DA, outperform generic pre-training methods like Grand Canonical Discrimination (GCD) across all evaluated tasks.


Investigating the impact of distillation objectives on TinyBERT learning, this study explores various baseline approaches, specifically excluding Transformer-layer distillation (w/o Trm), embedding-layer distillation (w/o Emb), or prediction-layer distillation (w/o Pred). Namely, these methods have been systematically evaluated to assess their effectiveness. The experimental outcomes are comprehensively demonstrated in Table 3, revealing that all tested strategies demonstrate notable utility.

在这里插入图片描述

可以看出,在蒸馏方法中存在多种有效途径。其中最为关键的是中间结果的提取环节(即解题过程),因为这一阶段对于最终解答的价值远高于最终的答案本身。


We also examine the impacts of various mapping functions n = g(m) on the TinyBERT learning process. Originally, our TinyBERT implementation, as outlined in Section 4.2, employs a uniform strategy. In this study, we contrast this approach with two alternative strategies: one characterized by a top-down adjustment (wherein g(m) is defined as g(m) = m + N − M with 0 < m ≤ M), and another adopting a bottom-up approach (whereas for the bottom-strategy, defined as g(m) = m with 0 < m ≤ M). The experimental outcomes are summarized in Table 4.

在这里插入图片描述

Our experiments show that top-strategy achieves superior performance compared to bottom-strategy on MNLI, while performing worse in MRPC and CoLA, which corroborates the understanding that different tasks rely on knowledge from distinct BERT layers. Uniform strategy integrates knowledge across all layers of BERTBASE and outperforms other baseline methods across every task.

全部评论 (0)

还没有任何评论哟~