-
Notifications
You must be signed in to change notification settings - Fork 1
/
server.py
131 lines (106 loc) · 3.95 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
from dotenv import load_dotenv
from clients.azure_blob import AzureBlob
from clients.gcp_blob import GCPBlob
from clients.utils import blob_pre_model_name, server_local_pre_model_name, blob_post_model_name, server_local_post_model_name
class NN(nn.Module):
    """Three-layer MLP for loan prediction: 11 input features -> 2 output logits."""

    def __init__(self, input_features=11, layer1=20, layer2=20, out_features=2):
        """Create the three fully connected layers.

        Attribute names (fc1/fc2/out) are load-bearing: they form the
        state_dict keys exchanged with the federated clients.
        """
        super().__init__()
        self.fc1 = nn.Linear(input_features, layer1)
        self.fc2 = nn.Linear(layer1, layer2)
        self.out = nn.Linear(layer2, out_features)

    def forward(self, x):
        """Apply two ReLU-activated hidden layers, then return raw logits."""
        hidden = F.relu(self.fc1(x))
        hidden = F.relu(self.fc2(hidden))
        return self.out(hidden)
# Load cloud credentials / connection strings from the local .env file.
load_dotenv()

# Current federated round; incremented after each successful aggregation.
epoch = 0
# Total number of federated training rounds to run.
EPOCH_COUNT = 10

# Registry of participating clients. Each entry maps a company name to:
#   cloud     - blob-storage helper class (instantiated per use)
#   container - the blob container shared with that client
#   env_key   - NOTE(review): currently a placeholder ("something"); confirm
#               whether any code actually reads this field.
federated_list = [
    {
        "company_a": {
            "cloud": AzureBlob,
            "container": "loans-a",
            "env_key": "something"
        }},
    {
        "company_b": {
            "cloud": GCPBlob,
            "container": "loans-b",
            "env_key": "something"
        }},
    # {
    # "company_c": {
    # "cloud": AzureBlob,
    # "container": "loans-c",
    # "env_key": "something"
    # }},
    # {
    # "company_d": {
    # "cloud": GCPBlob,
    # "container": "loans-d",
    # "env_key": "something"
    # }},
]
def publish_state_dict(model_state_dict, epoch):
    """Save the model state_dict locally, then push it to every client's container.

    Also prunes the previous epoch's local file and asks each cloud helper to
    delete the previous epoch's blob when uploading the new one.

    Args:
        model_state_dict: state_dict of the global model to publish.
        epoch: current federated round; used to name the model files.
    """
    torch.save(model_state_dict, server_local_pre_model_name(epoch))
    # Remove last round's local copy; at epoch 0 (or if already cleaned up)
    # the file simply does not exist, which is fine.
    try:
        os.remove(server_local_pre_model_name(epoch - 1))
    except FileNotFoundError:
        pass
    for entity in federated_list:
        config = next(iter(entity.values()))
        cloud_helper = config['cloud']()
        container = config['container']
        cloud_helper.upload_to_blob_storage(
            server_local_pre_model_name(epoch),
            container,
            blob_pre_model_name(epoch),
            delete_blob_name=blob_pre_model_name(epoch - 1),
        )
def poll_clients(epoch, remaining):
    """Download this epoch's post-model from each client that has published one.

    Bug fix: the original called remaining.remove(entity) while iterating
    `remaining`, which skips the element after every removal (mutating a list
    during iteration). We now build a fresh list of still-pending clients.

    Args:
        epoch: current federated round.
        remaining: list of client entries not yet collected this round.

    Returns:
        The entries whose post-model has not appeared yet.
    """
    still_pending = []
    for entity in remaining:
        config = next(iter(entity.values()))
        cloud_helper = config['cloud']()
        container = config['container']
        if cloud_helper.check_for_file(container, blob_post_model_name(epoch)):
            # Fetch the client's post-model and drop our previous-epoch copy.
            cloud_helper.download_from_blob_storage(
                server_local_post_model_name(epoch, container),
                container,
                blob_post_model_name(epoch),
                delete_local_name=server_local_post_model_name(epoch - 1, container),
            )
        else:
            still_pending.append(entity)
    return still_pending
def federated_averaging(epoch):
    """Average the state_dicts downloaded from all clients for this epoch.

    Bug fix: the original used `if not result`, which tests dict truthiness;
    the sentinel here is None, so we compare with `is None` (an empty loaded
    state_dict would otherwise be silently re-loaded). Also removes the
    redundant `continue` before an `else` branch.

    Args:
        epoch: federated round whose post-models should be averaged.

    Returns:
        A state_dict whose parameters are the element-wise mean across clients.
    """
    result = None
    for entity in federated_list:
        container = next(iter(entity.values()))['container']
        state_dict = torch.load(server_local_post_model_name(epoch, container))
        if result is None:
            # First client: use its state_dict as the running sum.
            result = state_dict
        else:
            # Accumulate every parameter tensor from the other clients.
            for param in result:
                result[param] = result[param] + state_dict[param]
    # Divide the sums by the number of participants to get the mean.
    for param in result:
        result[param] = result[param] / len(federated_list)
    return result
# --- Federated training driver ---------------------------------------------
# Clients still owed a post-model this round.
remaining = federated_list.copy()

# Fixed seed so the initial global model is reproducible.
torch.manual_seed(0)
model = NN()

# Publish pre_model_0 to every client's container to kick off round 0.
publish_state_dict(model.state_dict(), epoch)

while epoch < EPOCH_COUNT:
    # Collect any post-models that clients have uploaded for this round.
    remaining = poll_clients(epoch, remaining)
    if remaining:
        # At least one client hasn't published yet — wait and poll again.
        print('ab to sleep')
        time.sleep(5)
        continue
    # Everyone reported in: average, advance the round, and publish the
    # new global model for the next epoch.
    updated_model = federated_averaging(epoch)
    epoch += 1
    publish_state_dict(updated_model, epoch)
    remaining = federated_list.copy()