DOSFL:Distilled One-Shot Federated Learning

CoRR2020,这篇是结合了OneShot Federated Learning(CoRR2019)和Data Distill(CoRR2018),client端训练出一组合成数据,使得模型每次在真实数据的更新与在合成数据上的更新尽量一致,这组合成数据传输到server端训练全局模型。我主要是想看下如何在FL上做数据蒸馏。
论文地址:arxiv
code: 没找到
贡献
提出蒸馏一次性联邦学习(DOSFL),将通信成本降低多达3个数量级。
DOSFL 仅需要服务器与其客户端之间的一轮通信。每个客户端提取数据并将学习到的合成数据、标签和学习率上传到服务器,而不是传输大量的梯度或权重。即使包含数千个示例的大型数据集也可以压缩为仅几个虚构的示例。
算法流程
(1) 服务器初始化一个模型并广播给所有客户端。 (2) 每个客户端提取其私有数据集,(3) 将合成数据、标签和学习率传输到服务器。 (4) 服务器根据提取的数据拟合其模型,(5) 将最终模型分发给所有客户端。

流程不复杂,和FL不同之处在于,上传的是合成数据、标签和学习率,且只传输一次(我不明白为啥要做成one-shot,是重复训练没有提升吗?)

初始化:server端模型随机初始化,client端初始合成数据服从正态分布,学习率初始设定为已知值,初始标签为one-hot(分类任务)或正态分布向量(回归任务),初始模型参数与server端一致。
client更新:先用本地真实数据更新模型参数θ,然后根据更新前后的θ差更新三元组 合成数据、标签和学习率(每次是和前一次的θ算差)。总共有Sd个三元组,每组更新Ed次(就是用合成数据来模拟真实数据更新)。soft label就是知识蒸馏中的软标签,对应hard label就是分类中的one-hot编码。

server更新:这个就是所有客户端的三元组序列传上来,交叉一下,然后正常训练模型。
文中说在多个客户端的时候会有一个问题,比如两个客户端传输的三元序列组{x1,x2,x3,x4…}和{y1,y2,y3,y4…},组成的训练server模型的三元序列组{x1,y1,x2,y2,…},用这组训练server端的模型时,假设数据x1和y1都是将网络从θ0更新成θ1,那这样的话用{x1,y1,…}这个三元组更新的模型和所预想的不一样。文中提出soft resets 和 random mask 两种解决方案。
soft resets: “hard resets”就是 data distill 那篇文章,重新初始化θ,soft就是根据高斯分布再采样。(这地方没看懂,本来就是高斯分布采样了,这是二次采样还是怎么回事?)
random mask: 在client端训练阶段合成数据随机变成随机张量(一个图像某一部分变张量),迭代完成后再复原
疑问
有几个地方不太明白,而且没有代码。。
每个client数据不一致导致它们收敛所需的训练次数也不一致,那么迭代次数Ed是定值吗? 如果是一直训练到收敛的话,模型在真实数据上收敛的时候,在合成数据上不一定是收敛的。此时用合成数据训练收敛得到的模型和在真实数据上训练收敛的模型就不一致了。
每个client数据不一致导致它们收敛的梯度方向也不一致,如何得知这些更新不会相互干扰?或者说为什么能断定这是必定收敛的?(data distill 那篇文章没有用到多个合成数据集,只在当个合成数据集上和真实数据集做比较)
文中自己也说了多个client端时会出现梯度更新不合理的情况(文中直接默认了所有客户端更新梯度方向一致?),但是提出的soft resets方法没看懂啥意思,而random mask这种操作真的能提升鲁棒性吗?
在client端合成数据的更新方式,最后合成数据是拟合了最后一次梯度更新的真实数据,但这最后一次更新的合成数据还能拟合第一次梯度更新的真实数据吗??
这样说有点拗口,我们最后一次更新三元组得到的合成数据,经过最后一次网络θn计算损失得到的梯度返回,和真实数据经过最后一次网络θn计算损失得到的梯度返回,这二者保持一致。同理,我们第一次更新三元组得到的合成数据,经过第一次网络θ0计算损失得到的梯度返回,和真实数据经过第一次网络θ0计算损失得到的梯度返回,这二者保持一致。但是真实数据从第一次到最后一次都没有变过,而合成数据每次都在更新变化,最后我们返回的三元组序列或许能模拟网络最后一次更新,但估计是满足不了前面几次的更新。

