Skip to main content

FEDER: 通信高效的 Byzantine-robust 联邦学习

初识

这篇文章是关于一种新的联邦学习方法,叫做FedER,它既能提高通信效率,又能抵抗恶意客户端的攻击。文章介绍了FedER的主要思想和技术细节,包括 Mutual Masking余弦相似度Occasional NormalizationGlobal Model Updating。文章还展示了FedER在四个真实数据集上的实验结果,证明了它能够在减少通信开销的同时,达到与最先进的鲁棒方法相当甚至更好的鲁棒性。

FedER的核心思想是通过对模型更新进行剪枝,减少通信开销,同时利用服务器的小数据集和余弦相似度来增强对恶意客户端的鲁棒性。

相知

FedER 的算法流程

  1. 服务器收集一小部分干净的数据集,分为训练集和验证集,同时维护一个服务器模型。
  2. 在每一轮联邦学习中,服务器根据验证集计算一个剪枝比例p,并根据训练集计算一个服务器模型更新。
  3. 服务器向客户端发送当前的全局模型,并要求客户端在本地数据上训练并返回模型更新。
  4. 客户端在返回模型更新之前,根据剪枝比例p对模型更新进行剪枝,只保留最大的1-p比例的参数。
  5. 服务器收到客户端的剪枝后的模型更新后,用自己的剪枝后的模型更新与之进行 Mutual Masking,得到 Mutual Masking 后的模型更新。
  6. 服务器计算 Mutual Masking 后的模型更新之间的余弦相似度,并用ReLU函数进行截断,得到每个客户端的权重。
  7. 服务器偶尔对 Mutual Masking 后的模型更新进行归一化,以限制攻击者的影响。
  8. 服务器根据客户端的权重和归一化后的 Mutual Masking 后的模型更新计算全局模型更新,并用它来更新自己和所有客户端的本地模型。

Mutual Masking

Mutual masking 是一种在联邦学习中使用的技术,它可以在不共享数据的情况下,协同训练深度学习模型。它可以解决联邦学习中存在的三种异质性问题:数据的非独立同分布性、设备的计算能力差异和通信带宽限制。Mutual masking 可以通过选择最重要或最相关的参数来减少通信开销,并通过过滤掉恶意更新中可能携带后门触发器的参数来提高模型的鲁棒性。Mutual masking 还可以实现个性化的联邦学习,使每个客户端根据自己的数据特征和偏好来更新模型。

FEDER 通过选择更新中值最大的一部分参数来过滤掉不重要或被污染的参数。具体来说,服务器和客户端都会计算一个二进制掩码,用于标记更新中保留的参数位置。然后,服务器和客户端之间的掩码进行逐元素乘法,得到一个相互掩码,用于限制全局更新的范围。

ReLU-Clipped Cosine Similarity

ReLU-Clipped Cosine Similarity 是一种用于衡量两个向量之间的夹角的方法,也可以用于表示服务器模型更新和客户端模型更新是否具有相似的方向。它是在余弦相似度的基础上,对负值进行截断,使其为零。这样可以过滤掉一些恶意的或者不一致的模型更新,提高联邦学习中的鲁棒性。

ReLU-Clipped Cosine Similarity 的公式可如下描述:

Cmi=max(0,Gsmmi,GcimmiGsmmi2Gcimmi2)C_m^i = \max(0, \frac{\langle G_s \odot m_m^i, G_c^i \odot m_m^i\rangle}{|G_s \odot m_m^i|_2 |G_c^i \odot m_m^i|_2})

其中,CmiC_m^i 是服务器模型更新 GsG_s 和客户端模型更新 GciG_c^iReLU-Clipped Cosine Similarity,\langle \cdot,\cdot \rangle 表示两个向量的内积,2|\cdot|_2 表示向量的二范数,\odot 表示向量的逐元素乘积,mmim_m^i 是服务器和客户端模型更新的相互掩码。

Occasional Normalization

Occasional Normalization 是一种在联邦学习中用来限制恶意客户端更新影响的方法。 它的具体步骤是:

  1. 服务器根据自己的小数据集计算一个服务器模型更新,并对其进行剪枝,只保留最重要的参数。
  2. 服务器收集客户端的模型更新,并对每个客户端的更新进行剪枝和掩码,只保留和服务器更新重叠的参数。
  3. 服务器根据每个客户端更新和服务器更新之间的余弦相似度,给每个客户端一个权重。
  4. 服务器在每个周期内随机选择一个客户端,对其更新进行归一化,使其和服务器更新的范数相等。
  5. 服务器根据客户端的权重和归一化后的更新,计算一个全局更新,并用它来更新服务器和客户端的模型。

Occasional Normalization 的目的是减少通信开销,提高鲁棒性,防止拒绝服务攻击和后门攻击。

Global Model Updating

Global Model Updating 是指服务器根据客户端的模型更新,计算一个全局更新,并用它来更新服务器和客户端的模型。

回顾

  • 文章的实验设置比较简单,只考虑了两种后门攻击的场景,没有对不同的攻击强度和攻击者比例进行分析。
  • 文章的更新剪枝方法可能会导致模型的收敛速度变慢,或者模型的性能下降,没有对这些影响进行评估。
  • 文章的 Mutual Masking 机制可能会过滤掉一些正常的更新,或者放过一些恶意的更新,没有对这些误差进行量化。
  • 文章的 Occasional Normalization 策略可能会引入一些噪声,或者改变更新的方向,没有对这些变化进行分析。

后续的文章对这篇文章的改进可能有以下几个方面:

  • 在更复杂和更真实的后门攻击的场景下,对FedER的鲁棒性和通信效率进行评估。
  • 探索不同的剪枝方法和剪枝比例,以平衡模型的性能和通信的开销。
  • 设计更精确和更灵活的掩码机制,以适应不同的攻击模式和数据分布。
  • 研究不同的归一化方法和归一化频率,以减少归一化的负面影响