-
Notifications
You must be signed in to change notification settings - Fork 2
/
deep_q_network.py
54 lines (41 loc) · 1.7 KB
/
deep_q_network.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import os
import torch as T
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
class DeepQNetwork(nn.Module):
    """Convolutional Q-network mapping a stacked observation to one value per action.

    Architecture: three Conv2d layers (7x7/s4 -> 5x5/s2 -> 3x3/s1), each followed
    by ReLU, then a 512-unit fully connected layer with ReLU and a linear output
    layer with ``n_actions`` units. The network owns its RMSprop optimizer and an
    MSE loss, and moves itself to CUDA when available.
    """

    def __init__(self, lr, n_actions, name, input_dims, chkpt_dir):
        """Build the network, its optimizer and loss, and place it on a device.

        Args:
            lr: Learning rate for the RMSprop optimizer.
            n_actions: Number of output units (one estimated value per action).
            name: File name for the checkpoint inside ``chkpt_dir``.
            input_dims: Shape of one observation without the batch axis,
                ``(channels, height, width)``; ``input_dims[0]`` is the
                number of input channels of the first convolution.
            chkpt_dir: Directory where checkpoints are saved and loaded.
        """
        super().__init__()
        self.checkpoint_dir = chkpt_dir
        self.checkpoint_file = os.path.join(self.checkpoint_dir, name)

        self.conv1 = nn.Conv2d(input_dims[0], 32, 7, stride=4)
        self.conv2 = nn.Conv2d(32, 64, 5, stride=2)
        self.conv3 = nn.Conv2d(64, 64, 3, stride=1)

        # Size of the flattened conv output depends on the spatial input size,
        # so it is measured with a dummy forward pass instead of hard-coded.
        fc_input_dims = self.calculate_conv_output_dims(input_dims)
        self.fc1 = nn.Linear(fc_input_dims, 512)
        self.fc2 = nn.Linear(512, n_actions)

        self.optimizer = optim.RMSprop(self.parameters(), lr=lr)
        self.loss = nn.MSELoss()
        self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
        self.to(self.device)

    def calculate_conv_output_dims(self, input_dims):
        """Return the number of features produced by the conv stack for one sample.

        Args:
            input_dims: Observation shape without the batch axis.

        Returns:
            int: Product of the conv3 output shape for a batch of size 1.
        """
        # No gradients are needed for this shape probe. The dummy input is
        # created next to the weights so the call also works after .to(cuda).
        with T.no_grad():
            state = T.zeros(1, *input_dims, device=self.conv1.weight.device)
            dims = self.conv3(self.conv2(self.conv1(state)))
        return int(dims.numel())

    def forward(self, state):
        """Compute action values for a batch of observations.

        Args:
            state: Batched input tensor whose trailing dimensions match the
                ``input_dims`` given to ``__init__``.

        Returns:
            Tensor of shape ``(batch, n_actions)`` with raw (unactivated) outputs.
        """
        conv1 = F.relu(self.conv1(state))
        conv2 = F.relu(self.conv2(conv1))
        conv3 = F.relu(self.conv3(conv2))
        # Keep the batch axis; collapse channels x height x width into one axis.
        conv_state = conv3.flatten(start_dim=1)
        flat1 = F.relu(self.fc1(conv_state))
        actions = self.fc2(flat1)
        return actions

    def save_checkpoint(self):
        """Write the module's ``state_dict`` to ``self.checkpoint_file``.

        The checkpoint directory is created first if it does not exist yet.
        """
        print('... saving checkpoint ...')
        if self.checkpoint_dir:
            os.makedirs(self.checkpoint_dir, exist_ok=True)
        T.save(self.state_dict(), self.checkpoint_file)

    def load_checkpoint(self):
        """Load parameters from ``self.checkpoint_file`` into this module.

        Tensors are mapped onto ``self.device``, so a checkpoint written on a
        GPU machine can be restored on a CPU-only machine and vice versa.
        """
        print('... loading checkpoint ...')
        self.load_state_dict(T.load(self.checkpoint_file, map_location=self.device))