PyTorch实现FedAvg算法的过程是什么,有哪些要点
Admin 2022-09-09 群英技术资讯 263 次浏览
在之前的一篇博客联邦学习基本算法FedAvg的代码实现中利用numpy手搭神经网络实现了FedAvg,手搭的神经网络效果已经很好了,不过这还是属于自己造轮子,建议优先使用PyTorch来实现。
联邦学习中存在多个客户端,每个客户端都有自己的数据集,这个数据集他们是不愿意共享的。
本文选用的数据集为中国北方某城市十个区/县从2016年到2019年三年的真实用电负荷数据,采集时间间隔为1小时,即每一天都有24个负荷值。
我们假设这10个地区的电力部门不愿意共享自己的数据,但是他们又想得到一个由所有数据统一训练得到的全局模型。
除了电力负荷数据以外,还有一个备选数据集:风功率数据集。两个数据集通过参数type指定:type == 'load’表示负荷数据,'wind’表示风功率数据。
用某一时刻前24个时刻的负荷值以及该时刻的相关气象数据(如温度、湿度、压强等)来预测该时刻的负荷值。
对于风功率数据,同样使用某一时刻前24个时刻的风功率值以及该时刻的相关气象数据来预测该时刻的风功率值。
各个地区应该就如何制定特征集达成一致意见,本文使用的各个地区上的数据的特征是一致的,可以直接使用。
原始论文中提出的FedAvg的框架为:
客户端模型采用PyTorch搭建:
class ANN(nn.Module): def __init__(self, input_dim, name, B, E, type, lr): super(ANN, self).__init__() self.name = name self.B = B self.E = E self.len = 0 self.type = type self.lr = lr self.loss = 0 self.fc1 = nn.Linear(input_dim, 20) self.relu = nn.ReLU() self.sigmoid = nn.Sigmoid() self.dropout = nn.Dropout() self.fc2 = nn.Linear(20, 20) self.fc3 = nn.Linear(20, 20) self.fc4 = nn.Linear(20, 1) def forward(self, data): x = self.fc1(data) x = self.sigmoid(x) x = self.fc2(x) x = self.sigmoid(x) x = self.fc3(x) x = self.sigmoid(x) x = self.fc4(x) x = self.sigmoid(x) return x
服务器端执行以下步骤:
简单来说,每一轮通信时都只是选择部分客户端,这些客户端利用本地的数据进行参数更新,然后将更新后的参数传给服务器,服务器汇总客户端更新后的参数形成最新的全局参数。下一轮通信时,服务器端将最新的参数分发给被选中的客户端,进行下一轮更新。
客户端没什么可说的,就是利用本地数据对神经网络模型的参数进行更新。
class FedAvg: def __init__(self, options): self.C = options['C'] self.E = options['E'] self.B = options['B'] self.K = options['K'] self.r = options['r'] self.input_dim = options['input_dim'] self.type = options['type'] self.lr = options['lr'] self.clients = options['clients'] self.nn = ANN(input_dim=self.input_dim, name='server', B=B, E=E, type=self.type, lr=self.lr).to(device) self.nns = [] for i in range(K): temp = copy.deepcopy(self.nn) temp.name = self.clients[i] self.nns.append(temp)
参数:
服务器端代码如下:
def server(self): for t in range(self.r): print('第', t + 1, '轮通信:') m = np.max([int(self.C * self.K), 1]) # sampling index = random.sample(range(0, self.K), m) # dispatch self.dispatch(index) # local updating self.client_update(index) # aggregation self.aggregation(index) # return global model return self.nn
其中client_update(index):
def client_update(self, index): # update nn for k in index: self.nns[k] = train(self.nns[k])
aggregation(index):
def aggregation(self, index): s = 0 for j in index: # normal s += self.nns[j].len params = {} with torch.no_grad(): for k, v in self.nns[0].named_parameters(): params[k] = copy.deepcopy(v) params[k].zero_() for j in index: with torch.no_grad(): for k, v in self.nns[j].named_parameters(): params[k] += v * (self.nns[j].len / s) with torch.no_grad(): for k, v in self.nn.named_parameters(): v.copy_(params[k])
dispatch(index):
def dispatch(self, index): params = {} with torch.no_grad(): for k, v in self.nn.named_parameters(): params[k] = copy.deepcopy(v) for j in index: with torch.no_grad(): for k, v in self.nns[j].named_parameters(): v.copy_(params[k])
下面对重要代码进行分析:
客户端的选择
m = np.max([int(self.C * self.K), 1]) index = random.sample(range(0, self.K), m)
index中存储中m个0~10间的整数,表示被选中客户端的序号。
客户端的更新
for k in index: self.client_update(self.nns[k])
服务器端汇总客户端模型的参数
关于模型汇总方式,可以参考一下我的另一篇文章:对FedAvg中模型聚合过程的理解。
当然,这只是一种很简单的汇总方式,还有一些其他类型的汇总方式。
论文Electricity Consumer Characteristics Identification: A Federated Learning Approach中总结了三种汇总方式:
normal:原始论文中的方式,即根据样本数量来决定客户端参数在最终组合时所占比例。
LA:根据客户端模型的损失占所有客户端损失和的比重来决定最终组合时参数所占比例。
LS:根据损失与样本数量的乘积所占的比重来决定。 将更新后的参数分发给被选中的客户端
def dispatch(self, index): params = {} with torch.no_grad(): for k, v in self.nn.named_parameters(): params[k] = copy.deepcopy(v) for j in index: with torch.no_grad(): for k, v in self.nns[j].named_parameters(): v.copy_(params[k])
客户端只需要利用本地数据来进行更新就行了:
def client_update(self, index): # update nn for k in index: self.nns[k] = train(self.nns[k])
其中train():
def train(ann): ann.train() # print(p) if ann.type == 'load': Dtr, Dte = nn_seq(ann.name, ann.B, ann.type) else: Dtr, Dte = nn_seq_wind(ann.named, ann.B, ann.type) ann.len = len(Dtr) # print(len(Dtr)) loss_function = nn.MSELoss().to(device) loss = 0 optimizer = torch.optim.Adam(ann.parameters(), lr=ann.lr) for epoch in range(ann.E): cnt = 0 for (seq, label) in Dtr: cnt += 1 seq = seq.to(device) label = label.to(device) y_pred = ann(seq) loss = loss_function(y_pred, label) optimizer.zero_grad() loss.backward() optimizer.step() print('epoch', epoch, ':', loss.item()) return ann
def global_test(self): model = self.nn model.eval() c = clients if self.type == 'load' else clients_wind for client in c: model.name = client test(model)
本次实验的参数选择为:
K | C | E | B | r |
---|---|---|---|---|
10 | 0.5 | 50 | 50 | 5 |
if __name__ == '__main__': K, C, E, B, r = 10, 0.5, 50, 50, 5 type = 'load' input_dim = 30 if type == 'load' else 28 _client = clients if type == 'load' else clients_wind lr = 0.08 options = {'K': K, 'C': C, 'E': E, 'B': B, 'r': r, 'type': type, 'clients': _client, 'input_dim': input_dim, 'lr': lr} fedavg = FedAvg(options) fedavg.server() fedavg.global_test()
各个客户端单独训练(训练50轮,batch大小为50)后在本地的测试集上的表现为:
客户端编号 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 |
---|---|---|---|---|---|---|---|---|---|---|
MAPE / % | 5.33 | 4.11 | 3.03 | 4.20 | 3.02 | 2.70 | 2.94 | 2.99 | 2.30 | 4.10 |
可以看到,由于各个客户端的数据都十分充足,所以每个客户端自己训练的本地模型的预测精度已经很高了。
服务器与客户端通信5轮后,服务器上的全局模型在10个客户端测试集上的表现如下所示:
客户端编号 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 |
---|---|---|---|---|---|---|---|---|---|---|
MAPE / % | 6.84 | 4.54 | 3.56 | 5.11 | 3.75 | 4.47 | 4.30 | 3.90 | 3.15 | 4.58 |
可以看到,经过联邦学习框架得到全局模型在各个客户端上表现同样很好ÿ0c;这是因为十个地区上的数据分布类似。
给出numpy和PyTorch的对比:
客户端编号 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 |
---|---|---|---|---|---|---|---|---|---|---|
本地 | 5.33 | 4.11 | 3.03 | 4.20 | 3.02 | 2.70 | 2.94 | 2.99 | 2.30 | 4.10 |
numpy | 6.58 | 4.19 | 3.17 | 5.13 | 3.58 | 4.69 | 4.71 | 3.75 | 2.94 | 4.77 |
PyTorch | 6.84 | 4.54 | 3.56 | 5.11 | 3.75 | 4.47 | 4.30 | 3.90 | 3.15 | 4.58 |
同样本地模型的效果是最好的,PyTorch搭建的网络和numpy搭建的网络效果差不多,但推荐使用PyTorch,不要造轮子。
我把数据和代码放在了GitHub上:源码及数据,原创不易,下载时请随手给个follow和star,感谢!
免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:mmqy2019@163.com进行举报,并提供相关证据,查实之后,将立刻删除涉嫌侵权内容。
猜你喜欢
我们需要使用到图片素材的场景很多,但是很多素材都有水印,而一张张去除水印是工作量大。对此,这篇文章小编就给大家分享如何用python实现图片批量去水印的方法,下面我们一起来看看是怎样做的吧。
这篇文章主要介绍了Python利用matplotlib画出漂亮的分析图表,文章首先引入数据集展开详情,需要的朋友可以参考一下
这篇文章主要介绍Matplotlib绘制子图的方式,常用的方式有通过plt的subplot、通过figure的add_subplot和通过plt的subplots,下面我们就来看看怎样绘制子图吧,感兴趣的朋友可以参考。
深度学习已经成为机器学习中最受欢迎和发展最快的领域。深度学习的常见应用包括语音识别、图像识别、自然语言处理、推荐系统等等。本文将通过一些示例代码,带你详细了解深入学习
这篇文章主要为大家详细介绍了如何利用Python语言绘制好看的数据动态图,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起动手尝试一下
成为群英会员,开启智能安全云计算之旅
立即注册Copyright © QY Network Company Ltd. All Rights Reserved. 2003-2020 群英 版权所有
增值电信经营许可证 : B1.B2-20140078 粤ICP备09006778号 域名注册商资质 粤 D3.1-20240008