论文阅读-混合专家模型MOE-DAMEX:Dataset-aware Mixture-of-Experts for visual understanding of mixture-of-dataset
目录
Abstract
1. Introduction
contributions
3. 传统的MOE
3.1 Routing of tokens
3.2 负载均衡损失
1)专家的重要性损失(Importance Loss)
2)专家的负载损失(Load Loss)
4. 方法
4.1 setup
4.2 DAMEX: Dataset-aware Mixture-of-Experts
题目:DAMEX: Dataset-aware Mixture-of-Experts for visual understanding of mixture-of-datasets
数据集感知的专家混合模型,用于混合数据集的视觉理解
Abstract
通用普通的detector的构建提出了一个关键问题:我们如何才能在大型混合数据集上有效地训练模型。答案在于,在单个模型中,学习特定于数据集的特征并集成他们的知识。以前的方法通过在一个共同的主干上使用单独的检测头来实现这一点,但这会导致参数显着增加。
在这项工作中,我们提出了专家混合作为解决方案,突出了 MoE 比可扩展性工具多得多。我们提出了 Dataset-Aware Mixture-of-Experts, DAMEX,我们通过学习将每个数据集标记路由到其映射的专家来训练专家成为数据集的“专家”。在通用对象检测基准上的实验表明,我们比最先进的平均高出 +10.2 AP 分数,并且比我们的非 MoE 基线提高了平均 +2.0 AP 分数。我们还观察到在将数据集与 (1) 有限可用性 、(2) 不同领域 和 (3) 不同标签集混合 时的持续收益。此外,我们定性地表明 DAMEX 对专家表示崩溃 具有鲁棒性。
1. Introduction
建议专家混合不仅仅是可扩展的学习者,而是为数据集混合构建通用模型的有效和高效的解决方案。在 vanilla MoE 的基础上,我们引入了一种新的数据集感知混合专家模型 DAMEX,该模型学习解开 MoE 层的特定数据集特征,同时汇集非 MoE 层的信息。DAPEX学会将令牌路由到相应的专家,以便在推理过程中,它通过网络自动选择测试图像的最佳路径。我们使用基于DINO[42]的检测体系结构来开发我们的方法
contributions
1. 数据集感知的专家混合层( DAMEX ) :
(1). 提出了一种新的数据集感知的专家混合( Mixture-of-Experts, MoE )层,称为 DAMEX 。
(2). 该层能够在 MoE 层内 分离出数据集特定的特征 , 有效处理异构数据集的混合 。
(3). 通过实验表明, DAMEX 能够促进更优秀的专家利用率,并避免了传统 MoE 训练中常见的表示崩溃问题。
2. 无需测试时数据集标签 :
(1). 与基线相比, DAMEX 不需要在测试时提供数据集标签 。
(2). 模型在训练期间就学会了将适当的输入数据集路由到对应的专家,这使得模型在测试时更加健壮。
3. MoE 作为有效的学习者 :
(1). 据作者所知,这是 第一个探索 MoE 不仅仅是一个可扩展性工具,而是作为混合数据集 的有效学习者的工作。
(2). MoE 作为模型架构内的一种知识整合策略,与密集架构相比,参数数量的增加是边际的。
4. 在 UODB 数据集上建立新的最高标准 :
(1). DAMEX 在 UODB 数据集上的表现超过了之前报道的基线,平均 提高了 10.2 个百分点 的 AP 得分。
(2). 与非 MoE 基线相比,平均提高了 2.0 个百分点的 AP 得分。
(3). 在各种数据集混合场景中观察到一致的性能提升,如 1 )有限数据可用性, 2 )不同领域,以及 3 )同一领域内不同的标签集
3. 传统的MOE
3.1 Routing of tokens
令牌和专家表示: * 输入令牌表示为 x∈RDx∈RD,其中 DD 是令牌的维度。
* 专家集合表示为 {ei}i=1∣E∣{ei}i=1∣E∣,其中 ∣E∣∣E∣ 是专家的总数。
* 路由器变量表示为 Wr∈RE×DWr∈RE×D,这是一个权重矩阵,用于确定每个专家的选择概率。
2.
计算专家选择概率:
(1) 首先,计算每个专家的选择得分 gx:gx=Wr⋅x 这是路由器变量Wr 与输入令牌 x 的点积。
(2) 然后,计算每个专家 eiei 的选择概率 pi(x): pi(x)=∑j=1∣E∣exp(gxj)exp(gxi) 这里,exp(gxi) 是指指数函数 egxiegxi,分母是所有专家得分指数的总和。这个公式实现了一个softmax函数,它将每个专家的得分转换为一个概率分布,其中每个概率表示选择对应专家处理输入令牌的可能性。
令牌路由:
(1) 接下来,使用top-k策略来选择概率最高的k个专家来处理令牌。在论文中,kk 被设置为1,意味着只选择概率最高的一个专家。
(2)计算输出 yy 作为选定专家处理过的令牌的加权组合: y=∑i∈top-kpi(x)ei(x)y=∑i∈top-kpi(x)ei(x) 这里,ei(x)ei(x) 是被选中的专家对输入令牌 xx 进行处理后的输出。输出 yy 是根据每个专家被选中的概率 pi(x)pi(x) 加权后的结果。
通过这种方式,MoE模型能够将输入数据分配给最擅长处理该数据的专家,从而提高整个模型的效率和性能。这种方法也有助于提高模型的可扩展性和处理不同类型数据的能力。
3.2 负载均衡损失
Load balancing among the experts,如何在MoE模型中的专家之间进行负载平衡。以下是公式的解释:
1)专家的重要性损失(Importance Loss)
对于每个专家 e_i ,计算其重要性 I_i : 这里, \mathcal{M} 是一批输入令牌的集合, p_i(x) 是选择专家 处理令牌 x 的概率。重要性 表示专家 被选中的总次数。然后,计算重要性损失 \L_{\text{importance}} ): 这里, \text{Var}(I) 是重要性 I 的方差, \text{Mean}(I) 是 的均值。这个损失函数旨在最小化专家重要性的方差,从而确保所有专家的使用相对均衡。
这里给一个直观的计算方式:
假设我们有一个混合专家(MoE)模型,它有5个专家(Expert 1, Expert 2, Expert 3, Expert 4, Expert 5),并且我们有20个输入令牌。我们的目标是确保这些令牌在五个专家之间均匀分配,以实现负载均衡。
初始分配
假设初始分配如下:
- Expert 1: 5个token
- Expert 2: 3个token
- Expert 3: 2个token
- Expert 4: 5个token
- Expert 5: 5个token
这种分配方式导致了负载不均衡,因为Expert 1和Expert 4处理了更多的令牌,而Expert 3处理的令牌较少。
计算重要性损失
为了计算重要性损失,我们需要先计算每个专家的 I_i,即每个专家被分配到的令牌数量:
I_1 = 5
I_2 = 3
I_3 = 2
I_4 = 5
I_5 = 5
然后,我们计算所有专家的 I 的平均值(Mean)和方差(Var):
- 平均值 \text{Mean}(I) = \frac{5 + 3 + 2 + 5 + 5}{5} = \frac{20}{5} = 4
- 方差 \text{Var}(I) = \frac{(5-4)^2 + (3-4)^2 + (2-4)^2 + (5-4)^2 + (5-4)^2}{5} = \frac{1 + 1 + 4 + 1 + 1}{5} = \frac{8}{5} = 1.6
最后,我们使用这些值来计算重要性损失 \mathcal{L}_{\text{importance}}:
- \mathcal{L}_{\text{importance}} = \frac{\text{Var}(I)}{\text{Mean}(I)^2} = \frac{1.6}{4^2} = \frac{1.6}{16} = 0.1
这个重要性损失值表示了专家之间负载分配的不均衡程度。值越小,表示专家之间的负载越均衡。在我们的示例中,重要性损失为0.1,这表明专家之间的负载分配存在一定的不均衡。
结论:通过计算重要性损失,我们可以量化专家之间的负载不均衡程度,并采取相应的措施来重新分配令牌,以实现更好的负载均衡。这种方法有助于确保所有专家都能得到充分利用,避免资源浪费,并加速模型的收敛。
2)专家的负载损失(Load Loss)
对于每个专家 ,计算其负载 L_i : 这里, \Phi 是正态分布 N(0, \sigma^2 I) 的累积分布函数(CDF),其中 \sigma = \frac{\text{gate noise}}{|E|} 。负载 表示专家 被分配的令牌数量。然后,计算负载损失 L_{\text{load}} : 这里, \text{Var}(L) 是负载 L 的方差,\text{Mean}(L) 是 的均值。这个损失函数旨在最小化专家负载的方差,从而确保所有专家的负载相对均衡。
3)负载平衡的辅助损失(Load Balancing Auxiliary Loss):
最后,计算负载平衡的辅助损失 L_{\text{load-balancing}} :这个损失函数结合了重要性损失和负载损失,以确保在训练过程中专家的使用和负载都保持均衡。
通过引入这个负载平衡的辅助损失,MoE模型可以更有效地利用所有专家,避免某些专家过载而其他专家闲置,从而提高模型的整体性能和稳定性。
4. 方法
4.1 setup
我们从DINO架构的解码器中替换了交替的非MoE(Mixture-of-Experts)变换模块为MoE变换模块。在对象检测中混合数据集的一个常见问题是为每个数据集设置单独的检测头,这会在两阶段检测流程中显著增加参数数量。我们利用DINO的变换架构通过增加最后分类层的维度来处理类别数量,与总模型参数相比,参数数量的增加是边际的(约10,000),而总模型参数为46M。在所有实验中,我们使用一个专家对应一个GPU,这使得我们的参数数量与非MoE方法相同(除了增加了一个微小的路由器线性层)。
在传统的MoE中,与图像分类问题不同,每个视觉标记都用于损失计算,我们仅在检测任务中的前景标记上应用负载平衡损失。我们发现,仅前景损失平衡可以带来更好的梯度更新,并减轻检测任务中专家表示崩溃的问题。
4.2 DAMEX: Dataset-aware Mixture-of-Experts
DAMEX(Dataset-aware Mixture-of-Experts)的Loss计算方法涉及到一个辅助的交叉熵损失函数。这个损失函数的设计是为了训练MoE(Mixture-of-Experts)路由器,以便根据输入令牌的数据集来源将它们路由到相应的专家。
在DAMEX中,每个数据集被分配给一个特定的专家。这意味着,不是所有来自相同数据集的令牌都有相同的标签,而是它们被分配给特定的专家进行处理。这里的“标签”实际上是指目标专家的索引,而不是传统意义上图像分类任务中的类别标签。
举例来说,假设我们有三个数据集:COCO、DOTA和ImageNet。每个数据集都分配给了一个不同的专家。如果一个输入令牌来自COCO数据集,那么根据映射函数 h ,它将被分配给负责COCO数据集的专家,假设这个专家是专家1( e_1 )。因此,这个令牌的目标标签就是1,表示它应该被路由到专家1。路由器的任务是预测这个令牌应该被路由到哪个专家,预测的概率由给出,其中 i 是专家的索引,是输入令牌。
辅助损失 \mathcal{L}_{\text{DAMEX}} 的计算公式为:

在这个公式中, 1\left(h(d_m)=i\right) 是一个指示函数,当 h(d_m) 等于时,它的值为1,否则为0。这意味着只有当预测的专家索引与目标专家索引相匹配时,才会计算损失。是模型预测的第个专家被选中的概率。
通过这种方式,DAMEX训练路由器将所有来自特定数据集的视觉令牌发送到其对应的专家,从而确保MoE的有效利用,并避免表示崩溃。
4.3 Implementation details
- 使用预训练的ImageNet ResNet-50作为主干网络,输入所有方法的4个尺度特征。
- 对于超参数,使用6层变换器编码器和6层变换器解码器,隐藏特征维度为256。
- 使用TUTEL库修改DINO的解码器,交替使用MoE层。
- 使用容量因子f为1.25,辅助专家平衡损失权重为0.1,选择顶部1个专家(topk=1)。
- 每个GPU上有一个专家,使用8个RTX6000 GPU进行训练,每个GPU的批量大小为2,除非另有说明。
- 注意,模型的参数数量保持不变(除了增加了边际路由器参数),因为每个专家都位于单独的GPU上,替换了现有的前馈层。
