4、基于pytorch的分布式包的单机多核CPU实现横向的联邦学习

一、概述

首先感谢代码大佬的分享,我也是在大佬的基础上进行学习。此文既是笔记,也是分享,请结合代码一起食用。代码下载在第三节中给出链接。项目中使用多核CPU开启多进程模拟多机之间互相进行联邦学习的。

这是博主写的一些准备:https://blog.csdn.net/zzxxxaa1/article/details/121421075
里面介绍了python分布式的一些基础知识,如果对分布式和多线程不了解的话,这里十分重要。

25.jpg

二、系统介绍

(1)总体实现方式

使用多进程的方式来模拟server和worker之间的关系。其中主进程既要有server的功能,又要有client的功能。代码中的world_size就是进程数,即包括server在内共有多少个节点。

服务端 rank = 0

横向联邦学习的服务端的主要功能是将被选择的客户端上传的本地模型进行模型聚合。但这里需要特别注意的是,事实上,对于一个功能完善的联邦学习框架,比如我们FATE平台,服务端的功能要复杂得多,比如服务端需要对各个客户端节点进行网络监控、对失败节点发出重连信号等。此文由于是在本地模拟的,不涉及网络通信细节和失败故障等处理,因此不讨论这些功能细节,仅涉及模型聚合功能。

客户端 rank != 0

横向联邦学习的客户端主要功能是接收服务端的下发指令和全局模型,利用本地数据进行局部模型训练。与前一节一样,对于一个功能完善的联邦学习框架,客户端的功能同样相当复杂,比如需要考虑本地的资源(CPU、内存等)是否满足训练需要、当前的网络中断、当前的训练由于受到外界因素影响而中断等。读者如果对这些设计细节感兴趣,可以查看当前流行的联邦学习框架源代码和文档,比如FATE,获取更多的实现细节。此文我们仅考虑客户端本地的模型训练细节。我们首先定义客户端类Client,类中的主要函数包括以下两种。

  • 定义构造函数。在客户端构造函数中,客户端的主要工作包括:首先,将配置信息拷贝到客户端中;然后,按照配置中的模型信息获取模型,通常由服务端将模型参数传递给客户端,客户端将该全局模型覆盖掉本地模型;最后,配置本地训练数据,在本案例中,我们通过torchvision 的datasets 模块获取cifar10 数据集后按客户端ID切分,不同的客户端拥有不同的子数据集,相互之间没有交集。
  • 定义模型本地训练函数。本例是一个图像分类的例子,因此,我们使用交叉熵作为本地模型的损失函数,利用梯度下降来求解并更新参数值,实现细节如下面代码块所示。
(2)系统流程
  • 获取配置参数(里面包括了模型选取,学习率,batch size,epoch等)
  • 构建神经网络,获取权重参数(引入先前训练好的参数,帮助神经网络快速拟合。)
  • 创建多线程,分别模拟服务端(server)和工作节点(worker)
  • 进行全局迭代,每一轮迭代都要更新参数。
  • 迭代结束,关闭进程组

关于系统整体流程,我做成了流程图看得更清晰。

26.png

(3)文件目录介绍

.extra_util:

  • distributed_utils.py:用于构建进程和进程间的通信。
  • model.py:关于模型构建的类定义。
  • my_dataset.py:用于获取数据集长度、初始化、获取图片等。
  • utils.py:用于数据集划分和图片显示,将图片和分类保存成json格式。

Federal_defense

  • class_mean.py:网络构建,制作数据集,做了一个小的demo来查看分布式的计算结果
  • client_mean.py:这也是一个小demo,主要区别在local train里面。

根目录下的文件:

  • client.py: 构建客户端函数,并实现初始化函数、训练函数和测试函数。
  • distributed_minist.py: 这是一个minist数据集的一个测试demo。
  • get_weight_matrix.py: 不同进程之间的类平均值的通信。
  • main.py: 主函数,完整的流程。
  • my_data_loader.py: 划分数据集和数据增强。

! 这里面的文件结构非常杂乱有些引用都没有用到,但是可以看得出作者的一点一点的研究过程,每个文件都可以去看看,或许会给你很多启发。

三、环境搭建和代码运行

(1)环境搭建

Python、pytorch

(2)代码运行

代码下载地址:https://gitee.com/wang-qifan-hust/pytorch_federated_learning?_from=gitee_search

在根目录中执行下面命令:python main.py


Figure_1.png
Figure_2.png
Figure_3.png

四、部分代码分析

(1)分布式实现代码( 伪代码,不可运行 )
import os
import copy
import socket
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.multiprocessing import Process
from torch.nn.parallel import DistributedDataParallel as DDP
from client import *
from Federal_defense.class_mean import Net


def run(rank, size):
    """ 这里是伪代码,简单描述系统流程 """
    model = Net() #构建神经网络
    model.load_state_dict(weights, strict=False)  # 载入权重
    client = Client(args, rank)
    dist.barrier()  # 确保所有进程所对应的客户端client类实例化完成。

    for i in epochs:
        # 复制模型本地训练
        local_model = copy.deepcopy(model)
        diff = client.local_train(local_model)
        # 更新所有进程参数
        dist.all_reduce(value, op=dist.ReduceOp.SUM)
        # 等待所有进程完成此操作
        dist.barrier()
        # 更新model的参数
        model

    #阻塞进程,等待所有进程都运行到这里
    dist.barrier()
    dist.destroy_process_group()  # 释放进程组


def init_processes(rank, size, fn, backend='gloo'):
    """ 初始化分布式环境 """
    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)  # 无线局域网适配器WLAN所对应的IPV4地址
    s.connect(("8.8.8.8", 80))
    os.environ['MASTER_ADDR'] = s.getsockname()[0]
    os.environ['MASTER_PORT'] = '29500'
    # 等待所有进程准备好
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size)


if __name__ == "__main__":
    size = 3  # 这个就是world_size,也就是进程数,目前最大可调整到5
    mp.spawn(init_processes,  # 主运行函数,初始化进程
             args=(size, run),  # 主函数的参数,绑定run函数
             nprocs=size,  # 当前节点的进程数
             join=True)  # 加入同一进程池

(1)客户端类
# client.py
import os
import sys
import math
import copy
import torch
import numpy as np
import torch.nn.functional
import torch.utils.data
import torch.optim.lr_scheduler as lr_scheduler
sys.path.append('extra_utils')
from extra_utils.model import resnet34, resnet101, mnist_Net
from extra_utils.distributed_utils import init_distributed_mode, cleanup, is_main_process
import torch.distributed as dist
from tqdm import tqdm

# 构建客户端对象
class Client(object):

    def __init__(self, args, conf, train_loader, train_sampler, eval_loader, rank):

        self.conf = conf

        self.rank = rank

        self.train_loader = train_loader

        self.eval_loader = eval_loader

        self.train_sampler = train_sampler

    def local_train(self, model, args, global_epoch, cost_list, train_length, accuracy_list,
                    accuracy_1, accuracy_2, rank, cost_1, cost_2):
        # train_length是方便我们每隔一段时间画出损失函数点

        # for name, param in model.state_dict().items():  # 遍历模型参数
        #     self.local_model.state_dict()[name].copy_(param.clone())  # 将全局模型复制一份到本地训练的模型
        local_model = copy.deepcopy(model)  # 拷贝拷贝可变类型就是完完全全拷贝了一份,是完完全全的两个内存,互不干扰

        # 是否冻结权重  我们默认不冻结
        if args.freeze_layers:
            for name, para in local_model.named_parameters():
                # 除最后的全连接层外,其他权重全部冻结
                if "fc" not in name:
                    para.requires_grad_(False)  # 即只训练全连接层
            pg = [p for p in local_model.parameters() if p.requires_grad]  # 将我们需要训练的各层参数(全连接层)以 列表生成式的方式 生成列表。
            optimizer = torch.optim.Adam(pg, lr=self.conf['lr'])  # 注意此时更新的不是model模型,而是pg列表所对应的模型参数
        else:
            optimizer = torch.optim.Adam(local_model.parameters(), lr=self.conf['lr'])
            # 若冻结权重,则带有BN结构的网络不会被训练,而训练带有BN结构的网络时使用SyncBatchNorm才有意义
            if args.syncBN:
                # 使用SyncBatchNorm后训练会更耗时
                model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
                # 将模型中所有的BN层替换成具有同步功能的BN层

        lf = lambda x: ((1 + math.cos(x * math.pi / self.conf["local_epochs"])) / 2) * (1 - args.lrf) + args.lrf  # cosine
        # 利用lambda函数定义一个输入参数x和函数lf的关系。

        # 使用随机梯度下降算法作为优化算法
        # 需要训练的全连接层参数、初始学习率、动量、正则项
        criterion = torch.nn.CrossEntropyLoss()
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

        local_model.train()  # 进入模型训练模式
        for epoch in range(self.conf["local_epochs"]):

            self.train_sampler.set_epoch(epoch)  # 这种方法使得我们的各个设备在每一轮所获得的数据都是不一样的。
            if is_main_process():
                self.train_loader = tqdm(self.train_loader, file=sys.stdout)

            total_loss = 0
            running_loss = 0
            local_loss = 0

            for batch_id, batch in enumerate(self.train_loader):
                inputs, target = batch
                if rank == 2:
                    for j in range(len(target)):  # 标签翻转攻击
                        if target[j] == 1:
                            target[j] = 4

                optimizer.zero_grad()  # 梯度清零
                output = local_model(inputs)  # 前向传播算法
                loss = criterion(output, target)  # 使用交叉熵计算损失值
                loss.backward()  # 反向传播计算得到了梯度
                total_loss += loss.item()
                running_loss += loss.item()

                optimizer.step()  # 利用反向传播得到的梯度,利用优化算法更新网络参数(权重)
                scheduler.step()  # 更新学习率

                if (batch_id + 1) % (train_length / (self.conf["batch_size"]*5*self.conf["world_size"])) == 0:
                    # 每次本地训练共取出5个点来绘制。
                    cost_list.append(running_loss)
                    # print(running_loss)
                    # 使用命令行参数控制是否打印其他进程图线
                    if not args.only_0:
                        # 用于绘制除rank0之外的其他进程的图线的Loss代码。
                        loss_1 = torch.FloatTensor([running_loss])  # 方便发送接收rank1的准确率
                        loss_2 = torch.FloatTensor([running_loss])  # 方便发送接收rank2的准确率
                        if is_main_process():
                            dist.recv(loss_1, src=1)  # 接收来源于rank1的数据,并将其覆盖于loss_1
                            dist.recv(loss_2, src=2)  # 接收来源于rank2的数据,并将其覆盖于loss_2
                        elif rank == 1:
                            dist.send(loss_1, dst=0)  # 如果是rank1,发送数据acc1
                        else:
                            dist.send(loss_2, dst=0)  # 如果是rank2,发送数据acc2
                        dist.barrier()
                        loss_1 = loss_1.item()  # 取出tensor的data数据。避免最后plt.plot出错如下。
                        loss_2 = loss_2.item()  # VisibleDeprecationWarning: Creating an ndarray from ragged nested...
                        if is_main_process():
                            cost_1.append(loss_1)
                            cost_2.append(loss_2)
                    running_loss = 0

                    # 我们要对所有的进程执行此操作,否则会引起每个进程微小的误差。
                    # 此部分会极大地拉慢程序运行速度。
                    with torch.no_grad():
                        local_model.eval()  # 进入模型评估模式
                        correct = 0
                        dataset_size = 0
                        for batch_id, batch in enumerate(self.eval_loader):
                            inputs, target = batch
                            dataset_size += inputs.size()[0]
                            output = local_model(inputs)
                            local_loss += torch.nn.functional.cross_entropy(output, target, reduction='sum').item()
                            pred = output.data.max(1)[1]
                            correct += pred.eq(target.data.view_as(pred)).sum().item()
                        acc = 100.0 * (float(correct) / float(dataset_size))  # 准确率
                        accuracy_list.append(acc)
                        # print(acc)
                        # 使用命令行参数控制是否打印其他进程图线
                        if not args.only_0:
                            # 用于绘制除rank0之外的其他进程的图线的准确率代码。
                            acc1 = torch.FloatTensor([acc])  # 方便发送接收rank1的准确率
                            acc2 = torch.FloatTensor([acc])  # 方便发送接收rank2的准确率
                            if is_main_process():
                                dist.recv(acc1, src=1)  # 接收来源于rank1的数据,并将其覆盖于acc1
                                dist.recv(acc2, src=2)  # 接收来源于rank2的数据,并将其覆盖于acc2
                            elif rank == 1:
                                dist.send(acc1, dst=0)  # 如果是rank1,发送数据acc1
                            else:
                                dist.send(acc2, dst=0)  # 如果是rank2,发送数据acc2
                            dist.barrier()
                            acc1 = acc1.item()  # 取出tensor的data数据。避免最后plt.plot出错如下。
                            acc2 = acc2.item()  # VisibleDeprecationWarning: Creating an ndarray from ragged nested...
                            if is_main_process():
                                accuracy_1.append(acc1)
                                accuracy_2.append(acc2)
                                # print(acc)
                                # print(acc1)
                                # print(acc2)

                    local_loss = 0

            dist.barrier()  # 防止打印进度条的时候,会由于其他进程先训练完成而导致打印出来的字符串混乱掉进度条。

            if is_main_process():
                print("Rank %d, Global_epoch [%d/%d], Local Epoch [%d/%d] loss : %f."
                      % (self.rank, global_epoch + 1, self.conf["global_epochs"], epoch + 1,
                         self.conf["local_epochs"], total_loss))
            # 各进程loss大小差距不大,个人认为应该是因为迁移学习所造成的

        # 取出权重矩阵
        weight_matrix = []
        bias_matrix = []
        for name, parm in local_model.named_parameters():
            if (name == 'module.fc3.weight') | (name == 'fc3.weight'):  # 可以将最后一层全连接层的权重矩阵取出
                weight_matrix = parm.detach()
            elif (name == 'module.fc3.bias') | (name == 'fc3.bias'):
                bias_matrix = parm.detach()
                bias_matrix = bias_matrix.unsqueeze(1)
                if is_main_process():
                    print(bias_matrix.size())
                # tensor.detach()返回一个新的tensor,从当前计算图中分离下来的,但是仍指向原变量的存放位置,
                # 不同之处只是requires_grad为false,得到的这个tensor永远不需要计算其梯度,不具有grad。
        mean = torch.ones([weight_matrix.size()[1], 1])
        mean = mean / weight_matrix.size()[1]
        class_mean = weight_matrix.mm(mean)  # 利用矩阵乘法求得 此客户端下 该轮次的 类平均值torch.Size([10, 1])
        class_mean = class_mean + bias_matrix / weight_matrix.size()[1]

        diff = dict()  # 生成一个空的字典
        for name, data in local_model.state_dict().items():  # 遍历更新之后的各层模型参数。并返回每层对应的名字(name)和数据。
            # print(data != model.state_dict()[name])  # 用于打印出来是否参数相等
            diff[name] = (data - model.state_dict()[name])  # 将当前name和全局模型所对应name的数据进行相减,得到权重大小的变化量即权重差
        # print(diff[name])

        return diff, class_mean  # 返回网络参数的变化.value为tensor类型

    # 模型评估
    @torch.no_grad()  # 装饰器的方法实现with.no_grad()
    # 我们模型评估的时候使用全部测试集
    def model_eval(self, model):
        model.eval()  # 进入模型评估模式

        total_loss = 0.0
        correct = 0
        dataset_size = 0
        if is_main_process():
            self.eval_loader = tqdm(self.eval_loader, file=sys.stdout)

        for batch_id, batch in enumerate(self.eval_loader):  # batch_id就为enumerate()遍历集合所返回的批量序号
            inputs, target = batch  # 得到数据集和标签
            dataset_size += inputs.size()[0]  # data.size()=[batch,通道数,32,32]、target.size()=[batch]

            output = model(inputs)

            if self.conf["type"] == "mnist":
                total_loss += torch.nn.functional.cross_entropy(output, target, reduction='sum').item()
            elif self.conf["type"] == "flower":
                total_loss += torch.nn.functional.cross_entropy(output, target, reduction='sum').item()
            elif self.conf["type"] == "cifar":
                total_loss += torch.nn.functional.cross_entropy(output, target, reduction='sum').item()
            else:
                raise TypeError("Not find Appropriate mode.")
                # sum up batch loss
            # .data意即将变量的tensor取出来
            # 因为tensor包含data和grad,分别放置数据和计算的梯度
            pred = output.data.max(1)[1]  # get the index of the max log-probability
            # 按照从左往右的 第一维 取出最大值的索引 torch.max()
            correct += pred.eq(target.data.view_as(pred)).cpu().sum().item()
        # torch.view_as(tensor)即将调用函数的变量,转变为同参数tensor同样的形状
        # torch.eq()对两个张量tensor进行逐元素比较,如果相等则返回True,否则返回False。True和False作运算时可以作1、0使用
        # .cpu()这一步将预测结果放到cpu上,利用电脑内存存储列表值。从而避免测试过程中爆显存。
        # .sum()是将我们一个批量的预测值求和,便于累加到correct变量中。
        # .item()取出 单元素张量的元素值 并返回该值,保持原元素类型不变。

        acc = 100.0 * (float(correct) / float(dataset_size))  # 准确率
        aver_loss = total_loss / dataset_size  # 平均损失

        return acc, aver_loss

五、总结

总体来说,这位大佬分享的代码非常好,代码注释也十分多,给我十分大的帮助。只是说这个文件目录确实有些头疼。看了半天才看明白。不过也能够看得出大佬的研究历程。在这个项目里真正的体会到了分布式计算,模拟多个节点一起协作的场景。由理论学习转到了实战演示,帮助非常大。
这里面比较头疼的是各种数据的变换和多进程之间的协作,需要一步一步的去理解,有些头疼。

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 214,776评论 6 496
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 91,527评论 3 389
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 160,361评论 0 350
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,430评论 1 288
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,511评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,544评论 1 293
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,561评论 3 414
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,315评论 0 270
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,763评论 1 307
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,070评论 2 330
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,235评论 1 343
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,911评论 5 338
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,554评论 3 322
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,173评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,424评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,106评论 2 365
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,103评论 2 352

推荐阅读更多精彩内容