3.3 用Python实现横向联邦图像分类

本节我们使用Python从零开始实现一个简单的横向联邦学习模型。具体来说,我们将用横向联邦来实现对cifar10图像数据集的分类,模型使用的是ResNet-18。我们将分别从服务端、客户端和配置文件三个角度详细讲解设计一个横向联邦所需要的基本操作。

需要注意的是,为了方便实现,本章没有采用网络通信的方式来模拟客户端和服务端的通信,而是在本地以循环的方式来模拟。在第10章中,我们将介绍利用Flask-SocketIO模拟客户端和服务端进行网络通信的实现。

3.3.1 配置信息

联邦学习在开发过程中会涉及大量的参数配置,其中比较常用的参数设置包括以下几个。

• 训练的客户端数量:每一轮的迭代,服务端会首先从所有的客户端中挑选部分客户端进行本地训练。每一次迭代只选取部分客户端参与,并不会影响全局收敛的效果,且能够提升训练的效率[200]

• 全局迭代次数:即服务端和客户端的通信次数。通常会设置一个最大的全局迭代次数,但在训练过程中,只要模型满足收敛的条件,那么训练也可以提前终止。

• 本地模型的迭代次数:即每一个客户端在进行本地模型训练时的迭代次数。每一个客户端的本地模型的迭代次数可以相同,也可以不同。

• 本地训练相关的算法配置:本地模型进行训练时的参数设置,如学习率(lr)、训练样本大小、使用的优化算法等。

• 模型信息:即当前任务我们使用的模型结构。在本案例中,我们使用ResNet-18图像分类模型[127]

• 数据信息:联邦学习训练的数据。在本案例中,我们将使用cifar10数据集。为了模拟横向建模,数据集将按样本维度,切分为多份不重叠的数据,每一份放置在每一个客户端中作为本地训练数据。

其他的配置信息,比如可能使用到的加密方案、是否使用差分隐私、模型是否需要检查点文件(checkpoint)、模型聚合的策略等,都可以根据实际需要自行添加或者修改。我们将上面的信息以json格式记录在配置文件中以便修改,如下所示。

联邦学习在模型训练之前,会将配置信息分别发送到服务端和客户端中保存,如果配置信息发生改变,也会同时对所有参与方进行同步,以保证各参与方的配置信息一致。

3.3.2 训练数据集

按照上述配置文件中的type字段信息,获取数据集。这里我们使用torchvision的datasets模块内置的cifar10数据集。如果要使用其他数据集,读者可以自行修改。

3.3.3 服务端

横向联邦学习的服务端的主要功能是将被选择的客户端上传的本地模型进行模型聚合。但这里需要特别注意的是,事实上,对于一个功能完善的联邦学习框架,比如我们将在后面介绍的FATE平台,服务端的功能要复杂得多,比如服务端需要对各个客户端节点进行网络监控、对失败节点发出重连信号等。本章由于是在本地模拟的,不涉及网络通信细节和失败故障等处理,因此不讨论这些功能细节,仅涉及模型聚合功能。

下面我们定义一个服务端类Server,类中的主要函数包括以下三种。

• 定义构造函数。在构造函数中,服务端的工作包括:第一,将配置信息拷贝到服务端中;第二,按照配置中的模型信息获取模型,这里我们使用torchvision的models模块内置的ResNet-18模型。torchvision内置了很多常见的模型(链接3-5)。模型下载后,令其作为全局初始模型。

• 定义模型聚合函数。前面我们提到服务端的主要功能是进行模型的聚合,因此定义构造函数后,我们需要在类中定义模型聚合函数,通过接收客户端上传的模型,使用聚合函数更新全局模型。聚合方案有很多种,本节我们采用经典的FedAvg算法[200]。FedAvg算法通过使用下面的公式来更新全局模型:

其中,Gt表示第t轮聚合之后的全局模型,表示第i个客户端在第t+1轮本地更新后的模型,Gt+1表示第t+1轮聚合之后的全局模型。算法代码如下所示。

• 定义模型评估函数。对当前的全局模型,利用评估数据评估当前的全局模型性能。通常情况下,服务端的评估函数主要对当前聚合后的全局模型进行分析,用于判断当前的模型训练是需要进行下一轮迭代、还是提前终止,或者模型是否出现发散退化的现象。根据不同的结果,服务端可以采取不同的措施策略。

3.3.4 客户端

横向联邦学习的客户端主要功能是接收服务端的下发指令和全局模型,并利用本地数据进行局部模型训练。

与前一节一样,对于一个功能完善的联邦学习框架,客户端的功能也相当复杂,比如需要考虑本地的资源(CPU、内存等)是否满足训练需要、当前的网络中断、当前的训练由于受到外界因素影响而中断等。读者如果对这些设计细节感兴趣,可以查看当前流行的联邦学习框架源代码和文档,比如FATE,获取更多的细节。

本节我们仅考虑客户端本地的模型训练细节。我们首先定义客户端类Client,类中的主要函数包括以下两种。

• 定义构造函数。在客户端构造函数中,客户端的主要工作包括:首先,将配置信息拷贝到客户端中;然后,按照配置中的模型信息获取模型,通常由服务端将模型参数传递给客户端,客户端将该全局模型覆盖掉本地模型;最后,配置本地训练数据,在本案例中,我们通过torchvision的datasets模块获取cifar10数据集后按客户端ID进行切分,不同的客户端拥有不同的子数据集,相互之间没有交集。

• 定义模型本地训练函数。本例是一个图像分类的例子,因此,我们使用交叉熵作为本地模型的损失函数,利用梯度下降来求解并更新参数值,实现细节如下面代码块所示。

3.3.5 整合

当配置文件、服务端类和客户端类都定义完毕后,我们将这些信息组合起来。首先,读取配置文件信息。

接下来,我们将分别定义一个服务端对象和多个客户端对象,用来模拟横向联邦训练场景。

每一轮的迭代,服务端会从当前的客户端集合中随机挑选一部分参与本轮迭代训练,被选中的客户端调用本地训练接口local_train进行本地训练,最后服务端调用模型聚合函数model_aggregate来更新全局模型,代码如下所示。

模型聚合完毕后,调用模型评估接口来评估每一轮更新后的全局模型效果。完整的代码请参见本书配套的GitHub网页。