helper_classes.py
import torch.nn as nn
import torch
import torch.nn.functional as F


class StateEncoding(nn.Module):
    """Encodes an action vector and a performance vector into a joint state embedding."""

    def __init__(self, action_space, perf_space, output_layer):
        super(StateEncoding, self).__init__()
        self.action_encoding_layer = nn.Linear(action_space, output_layer)
        self.perf_encoding_layer = nn.Linear(perf_space, output_layer)

    def forward(self, action, perf):
        # Project both inputs to the same dimension, sum them, and squash with tanh.
        action_output = self.action_encoding_layer(action)
        perf_output = self.perf_encoding_layer(perf)
        out = action_output + perf_output
        out = torch.tanh(out)
        return out
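
# Example (a sketch; the sizes here are hypothetical): StateEncoding(10, 4, 64)
# maps an action batch of shape (B, 10) and a performance batch of shape (B, 4)
# to a shared state embedding of shape (B, 64), with values squashed into (-1, 1).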


class IntermediateClassifier(nn.Module):
    """Runs an input through a list of previously built layers, then a classification head."""

    def __init__(self, prev_layers, type_last_layer, classification_layer):
        super(IntermediateClassifier, self).__init__()
        self.all_layers = nn.ModuleList()
        self.output_layer = nn.ModuleList()
        for layer in prev_layers:
            self.all_layers.append(layer)
        self.type_last_layer = type_last_layer
        self.output_layer.append(classification_layer)

    def forward(self, x):
        batch_size = x.size(0)
        for layer in self.all_layers:
            classname = layer.__class__.__name__
            # Flatten 4-D feature maps before the first Linear layer.
            if classname.find("Linear") != -1 and len(x.size()) > 2:
                x = x.view(batch_size, x.size(1) * x.size(2) * x.size(3))
            x = layer(x)
        output_layer = x
        # If the last backbone layer is not fully connected ("FCL"), flatten before the head.
        if self.type_last_layer != "FCL":
            output_layer = x.view(batch_size, x.size(1) * x.size(2) * x.size(3))
        for layer in self.output_layer:
            output = layer(output_layer)
        return output
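
# Example (a sketch; layer sizes are hypothetical): with prev_layers =
# [nn.Linear(64, 32), nn.ReLU()], type_last_layer = "FCL", and
# classification_layer = nn.Linear(32, 5), a (B, 64) input yields (B, 5) logits.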


class CustomLoss(nn.Module):
    """Actor loss: the negated mean of the critic's value for the given state."""

    def __init__(self, actor, critic):
        super(CustomLoss, self).__init__()
        self.actor = actor
        self.critic = critic

    def forward(self, state):
        # Maximizing the critic's value is equivalent to minimizing its negation.
        loss = torch.mul(-1, self.critic(state))
        loss = loss.mean()
        return loss
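

if __name__ == "__main__":
    # Minimal usage sketch: the layer sizes and the toy actor/critic below are
    # hypothetical and only illustrate how the three helper classes fit together.
    encoder = StateEncoding(action_space=10, perf_space=4, output_layer=64)
    state = encoder(torch.randn(8, 10), torch.randn(8, 4))
    print("StateEncoding output:", state.shape)             # torch.Size([8, 64])

    backbone = [nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 16)]
    head = nn.Linear(16, 5)
    classifier = IntermediateClassifier(backbone, type_last_layer="FCL",
                                        classification_layer=head)
    logits = classifier(state)
    print("IntermediateClassifier output:", logits.shape)   # torch.Size([8, 5])

    critic = nn.Linear(64, 1)   # toy critic scoring an encoded state
    actor = nn.Identity()       # unused by forward(); kept for the constructor signature
    loss = CustomLoss(actor, critic)(state)
    print("CustomLoss value:", loss.item())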