大模型都在用的GQA是什么
论文题目:从多头检查点训练广义多查询Transformer模型
更详细内容直接看原文!!!
摘要
Single-head multi-query attention mechanism significantly enhances the decoding speed. However, this approach may incur performance losses and is not specifically designed to improve inference efficiency by training a single model. We propose a solution.
将现有的多头语言模型检查点升级成MQA模型,
采用分组注意力机制GQA作为一种扩展形式来提升性能,并在多个注意力头之间插入中间层以优化计算流程
我们表明,向上训练的GQA以接近MQA的速度达到接近多头注意力的质量。
导言
自回归解码器推理在Transformer架构中存在显著性能瓶颈问题。具体而言,在计算过程中为解码器层加载权值矩阵及对应注意力键值对会产生较大的内存带宽消耗。为此提出了一种基于multi-query attention的新方法。该方法允许多个查询头并采用单个键值头进行计算。然而该方案存在潜在性能损失风险同时可能导致训练过程不稳定的问题。现有研究表明部分语言模型采用了multi-query attention机制如PaLM但仍有大量主流语言模型未采纳包括公开可访问的T5及LLaMA等关键模型类型
这项工作包含了对使用大型语言模型进行更快的推理的两个贡献。
首先表明研究表明采用多头注意力机制(MHA)的预训练语言模型架构能够被向上微调;这种微调方法仅占用少量原始训练计算资源并可实现MQA;这种方法提供了快速多查询与高质量MHA架构相结合的有效解决方案
其次,在研究中开发了一种新的关注机制称为分组查询注意(GQA)。这种机制能够有效地介于传统的多头注意力模型与现代的多轮对话注意力模型之间,并通过引入单一键编码与值编码的方式实现了两者的融合。研究表明,在实际应用中向上训练的GQA模型不仅能够达到接近真实场景下的对话质量水平,并且其计算效率也能够与当前主流的大规模对话系统相媲美。
Uptraining
从multi-head model生成multi-query model分两个步骤进行:
首先,转换检查点,
其次,进行额外的预训练,以使模型适应其新的结构。
图1详细描述了将multi-head checkpoint转换为multi-query checkpoint这一技术流程。通过平均池化技术将key 和 value 头的投影矩阵统一为一个单一的投影矩阵,并经过实验对比发现该方法显著优于单独采用单一键-值对或随机初始化新的key-value 头。
接着,在基于相同的基础预训练模型上进行处理后生成的校准点上实施基础训练流程,并在以小比例α进行的预训练阶段中对该校准点进行进一步优化。

该过程分析了如何将来自各个头中的Keys和Values进行投影并平均合并。
Grouped-query attention
分组查询注意力将每个查询头划分为G个子组,在每子中均共享一个键投影和值投影。其中GQA-G类指在各子中执行分组式计算的方式。当子数量为1时(即GQA-1),则仅存在单一键投影与值投影,在这种情况下等效于标准多对齐机制MQA。而当子数量达到H层时(即GQA-H),其等效于总共有H个独立的多对齐机制并行工作。图2直观对比了分组式多对齐机制与传统多对齐及多查询方式之间的性能差异。值得注意的是,在实际应用中若需将一个多对齐层转换为对应的分组式结构,则需要对该层中的所有原始键值特征进行平均池化处理后才能生成各子层的键与值特征向量
由一组中间数量决定的一个插值模型表现出色于MQA却又快于MHA这已被我们后续内容所证实并展示了这种平衡.该插值模型通过将每个H键及对应的值头缩减至单一键及其相应的值头从而降低了键值缓存所需的存储空间.然而在较大的模型架构中通常会增加多头的数量这使得multi-head attention机制在内存使用方面更为激进.GQA则允许我们随著模型规模的增长而维持带宽与容量的比例.
此外,在较大模型中内存带宽消耗较较少。这是因为kv缓存规模会随着模型维度的增长而扩大;然而,在这种情况下FLOPs与参数数量均按照模型维度的平方比例增长。值得注意的是,在这种设置下编码器表示仍然是并行计算的基础架构;因此内存带宽通常不会成为主要瓶颈。我们期待GQA能够在更大规模的应用中展现出卓越的效果

总结
MHA(Multi-head Attention)
MHA(Multi-head Attention)是由Google团队于1997年首次提出的经典NLP模型,在文献《Attention Is All You Need》中首次系统阐述了这一机制,并将其发展成为经典的多头注意力机制。该方法通过同时处理多个查询、键和值矩阵来实现信息捕捉与表示学习。
具体而言,MHA由若干个并列的自注意力模块构成,每个模块都能聚焦输入的不同部分.每个注意力头都拥有独立的感知区域(parameter sets),能够独自学习输入的各种特性.随后整合各头信息,经过线性转换过程,获得最终输出.
该方法的特点在于能够有效识别输入数据中的多种不同特性。具体而言,在处理词序列时,每个‘头’都可以聚焦于不同的维度。
MQA(Multi-Query Attention)
MQA是由Google团队于2019年提出的,在论文Fast Transformer Decoding: One Write-Head is All You Need中提出。它也是一种MHA的变体形式,并被用作自回归解码中的注意力机制。相较于MHA的不同之处在于,MQA通过让各个解码器头共享一份统一的Key和Value矩阵,而每个解码器头仅保留了一份独特的Query向量。从而显著减少了Key和Value矩阵所需的参数数量,以此来提高推理速度的同时也带来了精度上的损失。
GQA(Grouped-Query Attention)
此外,在论文GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints中,Google提出了另一种MHA变体——GQA(Grouped Query Attention)。这种机制将查询头划分为G组,并对每个Query单独保留参数矩阵;而对于每个组则共享一个共同的Key矩阵和一个共同的Value矩阵。值得注意的是,在这种机制中,默认情况下G=1的情况与传统的多头注意力机制(TMA)相同。
中间组的数量提升了插值模型的性能优于MQA。但相较于MHA而言运行速度更快。在转换过程中将H键及对应的值头缩减为单一的键及相应的值。这样会导致每次查询所需的内存数据量是原始H倍数。
