Advertisement

【论文笔记】Personalized Federated Learning with Theoretical Guarantees: A MAML Approach

阅读量:

这项研究提出了一种名为Per-FedAvg的联邦个性化学习方法,受MAML启发,结合了FedAvg算法。其创新性在于通过找到一个共享的初始点,使每个用户只需执行少量梯度下降步骤即可适应本地数据,从而保留了联邦学习的结构优势并实现了个性化。研究贡献包括详细阐述了Per-FedAvg与FedAvg的联系、讨论了其收敛特性、提出了基于分布距离的异质性度量,并通过实验验证了其有效性。实验结果表明,Per-FedAvg在测试准确率上优于传统FedAvg及其变体,特别是在异构数据集上的性能表现尤为突出。

Personalized Version of Federated Learning with Theoretical Guarantees: A Model-Agnostic Meta-Learning Framework(Personalized Federated Learning)

  • 创新特性分析
  • 本文的主要贡献体现在算法创新性、适用性和实验有效性等方面
  • 详细阐述了Per-FedAvg算法的具体步骤
  • 针对评估指标,通过一系列实验得出了相应的结果
  • 评估指标
    • 实验数据集的选取
    • 采用多种对比方法进行分析
    • 通过实验分析,得出了各对比方法的性能评估结论

Guarantees: A Model-Agnostic Meta-Learning
Approach(联邦个性化元学习))

创新性

传统联邦学习的主要目标是跨越多个计算单元(用户)训练模型,这些用户与一个公共的中央服务器进行通信,而无需直接交换各自的原始数据样本。通过整合所有用户的计算能力,联邦学习方法在更大规模的数据集上进行训练,从而积累更丰富的模型经验。这种机制仅提供一个统一的输出,而没有针对每个用户进行个性化适应。特别地,在用户的基础数据分布存在显著差异的异构环境中,通过最小化平均损失获得的全局模型一旦应用于每个用户的本地数据集,可能会导致性能下降。本文研究了一种新型的个性化联邦学习方法,受到MAML的启发,将其与传统的FedAvg方法相结合,提出了Per-FedAvg方案。其核心目标是确定一个在所有用户之间共有的初始参数点,使得当前或新用户的本地数据集经过一到几个梯度下降步骤即可快速适应。通过这种方法,Per-FedAvg不仅保留了联邦学习体系结构的所有优势,还为每个用户提供了高度个性化的模型参数。

本篇论文的贡献

深入探讨了Per-FedAvg算法的个性化变体,旨在解决提出的个性化联邦学习问题。该研究主要包含以下三个关键内容:首先,详细阐述了Per-FedAvg与原始FedAvg算法之间的联系,并探讨了实施该算法时需要考虑的关键因素。其次,深入分析了Per-FedAvg算法在非凸损失函数下的收敛特性。最后,详细探讨了基于分布距离的衡量方法,分析了其对Per-FedAvg算法性能的影响。

Per-FedAvg算法步骤

①FedAvg服务器优化目标:

在这里插入图片描述

其中,f_{i}表示用户的本地损失函数:

在这里插入图片描述

Per-FedAvg:假设每个用户从自身数据集中初始化模型参数,并通过单步梯度下降更新其损失函数,服务器的优化目标相应地转变为:

在这里插入图片描述

其中,定义为与用户i相关联的变量函数F_{i}(w)表示为:

在这里插入图片描述

第一步:在每一轮中,服务器会随机选择一定比例(比例系数r满足0 < r ≤ 1)的用户,并将这些用户当前训练模型参数w_{k}传递给他们。
第二步:通过本地训练集对元模型进行一次参数更新。其中,β表示本地优化的学习率,w^i_{k+1,t}表示用户i在第k+1轮的本地迭代结果,其中t为本地迭代次数。经过一次优化后,得到更新后的参数结果。

在这里插入图片描述

上式中需要对①中原函数F_{i}(w)进行求导:

在这里插入图片描述

在每一轮训练中,基于用户提供的全部训练数据计算梯度 ∇f_{i}(w),这一过程通常会带来较高的计算成本。因此,为了获得无偏估计,我们从分布 p_{i}中选取一批代表性的样本来构造数据集 D_{i}

在这里插入图片描述

对于每一次估计,均采用独立批次的训练数据,从而能够获得原函数F_{i}(w)的导数表达式。

在这里插入图片描述

第三步:通过一组测试数据集对元模型进行参数更新,计算该损失对w^i_{k+1,t}的梯度,并基于此梯度更新服务器模型的参数。

{w}_{k+1, t}^{i}=w_{k+1, t-1}^{i}-\alpha \tilde{\nabla} f_{i}\left(w_{k+1, t-1}^{i}, \mathcal{D}_{t}^{i}\right)

其中\alpha为服务器模型的学习率。
算法如下:

在这里插入图片描述

评估指标和实验结果

评估指标

测试准确率

实验数据集

MNIST和CIFAR-10

对比方法

FedAvg 代表一种基础的联邦平均算法。其中,②和③分别代表Per-FedAvg的近似计算方式,FO和HF分别对应不同的计算策略。

实验结果

在这里插入图片描述

全部评论 (0)

还没有任何评论哟~