-
Notifications
You must be signed in to change notification settings - Fork 8
/
server.py
104 lines (89 loc) · 3.03 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# -*- coding:utf-8 -*-
"""
@Time: 2022/03/02 11:22
@Author: KI
@File: server.py
@Motto: Hungry And Humble
"""
import copy
import random
import sys
import numpy as np
import torch
from tqdm import tqdm
sys.path.append('../')
from model import ANN
from get_data import device, clients_wind
from client import train, test
# Implementation for Scaffold server.
class Scaffold:
    """Server side of the SCAFFOLD federated-learning procedure.

    The server owns a global model (``self.nn``) and one local model per
    client (``self.nns``).  Every model carries three dicts keyed by
    parameter name:

    * ``control``       -- the control variate (``c`` on the server,
      ``c_i`` on a client),
    * ``delta_control`` -- the change of the control variate produced by
      the latest local training,
    * ``delta_y``       -- the change of the model parameters produced by
      the latest local training.

    The two ``delta_*`` dicts are read by :meth:`aggregation`; they are
    presumably filled in by ``client.train`` -- confirm there.
    """

    def __init__(self, options):
        """Create the global model and ``K`` client replicas.

        Args:
            options: Mapping providing the keys ``'C'``, ``'E'``, ``'B'``,
                ``'K'``, ``'r'``, ``'input_dim'``, ``'lr'`` and
                ``'clients'``.

                * ``C`` -- fraction of clients sampled each round.
                * ``K`` -- number of client models to create.
                * ``r`` -- number of communication rounds.
                * ``clients`` -- sequence of client names; the first
                  ``K`` entries are used.
                * ``E``, ``B``, ``lr``, ``input_dim`` -- forwarded to
                  ``ANN`` (presumably local epochs, batch size, learning
                  rate and feature count -- confirm in ``model.ANN``).
        """
        self.C = options['C']
        self.E = options['E']
        self.B = options['B']
        self.K = options['K']
        self.r = options['r']
        self.input_dim = options['input_dim']
        self.lr = options['lr']
        self.clients = options['clients']

        self.nn = ANN(
            input_dim=self.input_dim,
            name='server',
            B=self.B,
            E=self.E,
            lr=self.lr,
        ).to(device)

        # Zero-initialise the server's per-parameter SCAFFOLD state.
        for name, param in self.nn.named_parameters():
            self.nn.control[name] = torch.zeros_like(param.data)
            self.nn.delta_control[name] = torch.zeros_like(param.data)
            self.nn.delta_y[name] = torch.zeros_like(param.data)

        # One independent replica of the global model per client.
        self.nns = []
        for i in range(self.K):
            replica = copy.deepcopy(self.nn)
            replica.name = self.clients[i]
            # Copied explicitly so that no client shares state tensors
            # with the server.
            replica.control = copy.deepcopy(self.nn.control)  # c_i
            replica.delta_control = copy.deepcopy(self.nn.delta_control)
            replica.delta_y = copy.deepcopy(self.nn.delta_y)
            self.nns.append(replica)

    def server(self):
        """Run ``self.r`` communication rounds and return the global model.

        Each round samples ``max(int(C * K), 1)`` distinct clients, sends
        them the current global parameters, trains them locally and folds
        their updates back into the global model.

        Returns:
            The global model ``self.nn`` after the last round.
        """
        # Builtin max keeps the sample size a plain int; at least one
        # client always participates.
        num_sampled = max(int(self.C * self.K), 1)
        for t in tqdm(range(self.r)):
            print('round', t + 1, ':')
            # sampling (without replacement)
            index = random.sample(range(self.K), num_sampled)
            # dispatch
            self.dispatch(index)
            # local updating
            self.client_update(index)
            # aggregation
            self.aggregation(index)

        return self.nn

    def aggregation(self, index):
        """Fold the sampled clients' updates into the global model.

        With ``S = index`` and ``N = self.K`` this applies, per parameter::

            x <- x + (1 / |S|) * sum_{i in S} delta_y_i        (global lr = 1)
            c <- c + (|S| / N) * (1 / |S|) * sum_{i in S} delta_control_i

        Args:
            index: Indices into ``self.nns`` of the clients that took part
                in this round.  Must be non-empty.
        """
        num_sampled = len(index)

        with torch.no_grad():
            # Running means of the client deltas, one tensor per parameter.
            mean_delta_y = {}
            mean_delta_c = {}
            for name, param in self.nn.named_parameters():
                mean_delta_y[name] = torch.zeros_like(param)
                mean_delta_c[name] = torch.zeros_like(param)

            for j in index:
                client = self.nns[j]
                for name in mean_delta_y:
                    mean_delta_y[name] += client.delta_y[name] / num_sampled
                    mean_delta_c[name] += (
                        client.delta_control[name] / num_sampled
                    )

            # Only a |S| / N share of the clients contributed, so the
            # control-variate step is scaled down accordingly.
            participation = num_sampled / self.K
            for name, param in self.nn.named_parameters():
                param.add_(mean_delta_y[name])  # global lr = 1
                self.nn.control[name].add_(mean_delta_c[name] * participation)

    def dispatch(self, index):
        """Copy the global parameters into each selected client model.

        Args:
            index: Indices into ``self.nns`` of the clients to refresh.
        """
        for j in index:
            pairs = zip(self.nns[j].parameters(), self.nn.parameters())
            for client_param, global_param in pairs:
                # clone() so the client never aliases the server's storage.
                client_param.data = global_param.data.clone()

    def client_update(self, index):  # update nn
        """Train every selected client and store the returned model.

        Args:
            index: Indices into ``self.nns`` of the clients to train.
        """
        for k in index:
            self.nns[k] = train(self.nns[k], self.nn)

    def global_test(self):
        """Evaluate the global model once per entry of ``clients_wind``.

        Side effects: ``self.nn`` is switched to eval mode and its ``name``
        is left set to the last entry of ``clients_wind``.
        """
        model = self.nn
        model.eval()
        for client in clients_wind:
            # NOTE(review): test() presumably picks the data set from
            # model.name -- confirm in client.test.
            model.name = client
            test(model)