FlexiFed: Personalized Federated Learning for Edge Clients with Heterogeneous Model Architectures笔记
FlexiFed: Tailored Federated learning for Edge Clients supporting Diverse Model Architectures笔记
会议: WWW '23: Proceedings of the ACM Web Conference 2023
目录
FlexiFed: A Customized Federated Learning Approach Tailored for Heterogeneous Model Architectures in Edge Clients Notes
前言
一、PRELIMINARIES
1.目标函数
2.架构异质性
二、FLEXIFED设计
1.Basic-Common Strategy
a.原理
b.算法
2.Clustered-Common策略
a.启发
b.原理
c.伪代码
3.Max-Common Strategy
a.原理
b.伪代码
三、实验
1.三个ML任务,四个数据集
2.四个ML模型
3.Baselines
4.实验结果
a.准确度
b.收敛性
c.Max-commn下模型数量的影响
d.客户端数量K的影响
e.公共层数的影响
f.全连接层是否考虑
四、代码复现
2.策略
3.模型--VGG
前言
为了适应体系结构的异质性(模型异质性):支持运行不同体系结构的ML模型客户端以适应体系结构异质性问题。本文提出了一个名为FlexiFed的新方案:通过促进客户端公共基础层之间的协作培训来实现对体系结构异质性的适应。
一、PRELIMINARIES
1.目标函数

其中k-th客户端的局部损失函数为:

其中,在每次迭代中使用(x_i,y_i)表示训练数据集中的实例,并通过L(𝜔;x_i,y_i)这一交叉熵损失度量评估预测结果与真实标签之间的差异程度。
2.架构异质性
客户的模型共享相同的公共层

,但可能在其个人层中有所不同

。设

表示FL系统中客户端𝑢的本地模型。它包含公共基础层

和个人层

。通过聚合客户端的公共基础层,参数服务器可以生成全局公共基础层

。
二、FLEXIFED设计
1.Basic-Common Strategy
a.原理
迁移学习的主要目标是将底层神经网络所学到的关键特征转移到另一个神经网络中。这一原理基于深度神经元网络在低层能有效传递关键特征的能力,这是因为它们主要致力于提取共性和通用模式,而高层则专注于捕捉具体细节和领域特性。
此构成了FlexiFed的核心理念——协同训练中的共享基础架构模块。该方案旨在使绝大多数体系结构异构的应用程序能够让各端设备从彼此中获取共性知识以提升性能,并加速收敛速度。核心在于参数服务器应采取适当的方法来整合各端设备的知识点,并识别共享模块及其在不同端设备间的整合方式。


b.算法

、
只聚合公共层
2.Clustered-Common策略
a.启发
聚类

b.原理
除了如Basic-Common这类专注于客户聚合的公共基础层之外,在FlexiFed框架下实施Clustered-Common的具体训练流程主要包括四个关键阶段。

c.伪代码

公共层面与base保持一致性。随后将具备个体一致性的一群个体归并成一个群体,并将其划分为C个子群体。接着对各子群体内部的所有客户个体进行个人层面的一致性聚类,并生成基于各子群体的全球模型

。
时间复杂度是指在FL系统中处理K个客户端。在最坏情况下(即所有个人层都不相同时),聚类过程的时间复杂度为O(K²)。在最佳情况下(即K−1个个人层完全相同),聚类过程的时间复杂度为O(K).
缺点明显存在
3.Max-Common Strategy
a.原理
通过客户的共同基础层和共同个人层实现知识共享最大化。

b.伪代码

步骤3:参数server将客户端按组进行划分,在同一分组中的客户端需确保其剩余层次内至少共享一个共同的基础结构;而不同分组中的客户端则需避免在各自剩余层次内有任何共同的基础结构,并对各个分组分别应用Basic-Common架构;随后参数服务器会对每个分组递归执行上述操作直至所有客户机的个人层级都已彻底处理完毕(第12至13行)。
三、实验
1.三个ML任务,四个数据集
- 图像分类: 使用CIFAR-10与CINIC-10数据集:其中CIFAR-10数据集包含5万张训练样本以及每个类别对应的一万张测试样本。而CINIC-10则包含了来自十个类别共计27万张图像的数据集。
这些模型基于AG News数据集进行了训练[65]。该数据集包含了四种新闻类别,并且每一大类中都拥有3万条训练样本以及1千9百条测试样本的数量分布。
语音识别技术
该联邦学习系统包含40个客户端。将训练集划分为均匀分布的40个区域,在独立同分布(IID)环境下,每个客户端将随机获得一个区域;而在非独立同分布(non-IID)情况下,则会根据类别分配到多个区域。
2.四个ML模型
- 图像分类:基于不同版本的VGG网络(包括VGG-11、VGG-13、VGG-16及VGG-19)以及ResNet系列网络(包括ResNet-20、ResNet-32、ResNet-44及ResNet-56),我们对其进行了系统的优化与改进,并分别应用于CIFAR-10数据集与CINIC-10数据集上的实验研究。
- 文本分类:CharCNN系列模型(涵盖长度为3到6维的CharCNN)以及VDCNN网络(包含9层、17层、29层和49层的设计),经改进后成功应用于AG News语料库的数据集上。
- 语音识别:通过深度学习算法对经典的VGG模型与ResNet模型进行了参数优化与结构改进,在标准的语音命令识别测试集上取得了显著性能提升。
3.Baselines
- Standalone:各个客户独立在本地构建并训练各自的模 型;无需与其他客户端进行交互或知识共享。
- Clusted-FL:各个 客户端独立构建并训练本机上的局部模 型;并将各 客户端的更新信息上传至 参数 服务器进 行汇总; 参数 服务器则依据 各 客户端架构特征将其划分为若干 组别,并采用FedAvg方 法生成相应的 全局基准模 型。
4.实验结果
a.准确度

b.收敛性

表4展示了BasicCommon、Clustered-Common以及Max-Common相对于Clustered-FL的详细比较结果。经对比分析可知,在加速效果方面Max-Common表现最为突出

c.Max-commn下模型数量的影响

例如,在图6(a)的情景下
当N值增大时(即随着N的增长),收敛速度减慢但精度提升)。在计算结果等于1的情况下(即当输出结果为1时),十个客户端均运行同一模型,并可在每层中分享该模型所包含的所有知识。这使得初始阶段模型准确率迅速提升。然而,在知识逐渐被客户训练模型超越后(即随着客户的本地训练逐步增强其本地模型的能力),各客户端所分享的知识有限
d.客户端数量K的影响

当K值不断增加时,在各种情况下系统的模型精度均呈上升趋势。当客户数量增多时(即系统中的客户数目不断增加),基于Max-Common的方法能够实现更高的系统性能。研究结果表明,在异构架构模型中通过Max-Common方法能够有效提取并实现各参与方的知识共享。
e.公共层数的影响

当𝐿值不断增大时(如𝐿趋近于无穷大),模型不仅在精度上表现出了显著提升,在收敛速度方面也展现出了更快的速率。研究表明基于最大的公共基础层能够有效促进客户间的知识共享效率
f.全连接层是否考虑

在引入全连接层到模型聚合过程中后,在测试集上的准确率有所降低;特别是在针对AG新闻数据集中的VDCNN模型时这一现象尤为明显。这反映出这些模型中的完全连接层并未有效促进知识共享。
g.Non-IID

虽然精度比聚类要高,但是整体准确度太低,有可提升的空间。
四、代码复现
提供的代码不全面。复现cifar10的iid的VGG情况,效果不如文章。

1.main.py:
if __name__ == '__main__':
device = torch.device("cpu")
# load dataset and user groups
train_dataset, test_dataset, user_groups, idx_test = get_dataset("cifar10")
# Training
train_loss, train_accuracy = [], []
val_acc_list, net_list = [], []
cv_loss, cv_acc = [], []
print_every = 2
val_loss_pre, counter = 0, 0
number_device = 8
idxs_users = [_id for _id in range(number_device)]
modelAccept = {_id: None for _id in range(number_device)}
for _id in range(number_device):
if _id < 2:
modelAccept[_id] = vgg11_bn()
elif _id >= 2 and _id < 4:
modelAccept[_id] = vgg13_bn()
elif _id >= 4 and _id < 6:
modelAccept[_id] = vgg16_bn()
else:
modelAccept[_id] = vgg19_bn()
localData_length = len(user_groups[0]) / 10
start = 0
local_acc = [[] for i in range(number_device)]
for epoch in range(50):
print(f'\n | Global Training Round : {epoch+1} |\n')
end = start + localData_length
for idx in idxs_users:
idx_train_all = list(user_groups[idx])
idx_train_batch = set(idx_train_all[int(start):int(end)])
if epoch == 0:
model = modelAccept[idx]
if epoch > 0:
if idx < 2:
model = vgg11_bn()
# model.load_state_dict(A)
elif idx >= 2 and idx < 4:
model = vgg13_bn()
# model.load_state_dict(B)
elif idx >= 4 and idx < 6:
model = vgg16_bn()
# model.load_state_dict(C)
else:
model = vgg19_bn()
# model.load_state_dict(D)
model.load_state_dict(modelAccept[idx])
acc = test_inference(model, test_dataset,
list(idx_test[idx]), device)
local_acc[idx].append(round(acc, 2))
if epoch % 10 == 0:
print(local_acc[idx])
Model = copy.deepcopy(model)
localModel = local_train(
Model, train_dataset, idx_train_batch, device)
modelAccept[idx] = copy.deepcopy(localModel)
start = end % 2500
modelAccept = common_max(modelAccept)
# modelAccept, _ = common_basic(modelAccept)
2.策略
def FedAvg(w):
w_avg = copy.deepcopy(w[0])
for k in w_avg.keys():
for i in range(1, len(w)):
w_avg[k] += w[i][k]
w_avg[k] = torch.div(w_avg[k], len(w))
return w_avg
def common_basic(w):
minIndex = 0
minLength = 10000000
for i in range(0, len(w)):
if len(w[i]) < minLength:
minIndex = i
minLength = len(w[i])
commonList = [s for s in w[minIndex].keys()]
for i in range(0, len(w)):
local_weights_names = [s for s in w[i].keys()]
for j in range(len(commonList)):
if commonList[j] == local_weights_names[j]:
continue
else:
del commonList[j:len(commonList)+1]
break
for k in commonList:
comWeight = copy.deepcopy(w[0][k])
for i in range(1, len(w)):
comWeight += w[i][k]
comWeight = comWeight / len(w)
for i in range(0, len(w)):
w[i][k] = comWeight
return w, commonList
def common_max(w):
w_copy = copy.deepcopy(w)
count = [[] for i in range(len(w))]
for i in range(len(w)):
local_weights_names = [s for s in w[i].keys()]
count[i] = [1 for m in range(len(local_weights_names))]
for i in range(0, len(w)):
local_weights_names1 = [s for s in w[i].keys()]
for j in range(i+1, len(w)):
if i == j:
continue
local_weights_names2 = [s for s in w[j].keys()]
for k in range(0, len(local_weights_names1)):
if local_weights_names2[k] == local_weights_names1[k]:
name = local_weights_names1[k]
w[i][name] += w_copy[j][name]
w[j][name] += w_copy[i][name]
count[i][k] += 1
count[j][k] += 1
else:
break
for c in range(0, len(w)):
local_weights_names = [s for s in w[c].keys()]
for k in range(0, len(local_weights_names)):
w[c][local_weights_names[k]] = w[c][local_weights_names[k]].cpu() / \
count[c][k]
return w
3.模型--VGG
__all__ = [
'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
'vgg19_bn', 'vgg19',
]
class VGG(nn.Module):
'''
VGG model
'''
def __init__(self, features):
super(VGG, self).__init__()
self.features = features
self.classifier = nn.Sequential(
nn.Dropout(),
nn.Linear(512, 4096),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(4096, 4096),
nn.ReLU(True),
nn.Linear(4096, 10),
)
#fc: 1024, 4096, 512, 96
# Initialize weights
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
m.bias.data.zero_()
def forward(self, x):
x = self.features(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x
def make_layers(cfg, batch_norm=False):
layers = []
in_channels = 3
for v in cfg:
if v == 'M':
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
else:
conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
if batch_norm:
layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
else:
layers += [conv2d, nn.ReLU(inplace=True)]
in_channels = v
return nn.Sequential(*layers)
cfg = {
'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M',
512, 512, 512, 512, 'M'],
}
