forked from junxiaosong/AlphaZero_Gomoku
-
Notifications
You must be signed in to change notification settings - Fork 0
/
policy_value_net_numpy.py
115 lines (92 loc) · 3.94 KB
/
policy_value_net_numpy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# -*- coding: utf-8 -*-
"""
Implement the policy value network using numpy, so that we can play with the
trained AI model without installing any DL framework
@author: Junxiao Song
"""
from __future__ import print_function
import numpy as np
# some utility functions
def softmax(x):
    """Return the softmax of ``x`` computed over all of its elements.

    The global maximum is subtracted before exponentiating so that large
    inputs do not overflow; this shift does not change the result.
    """
    exps = np.exp(x - np.max(x))
    return exps / np.sum(exps)
def relu(X):
    """Rectified linear unit: element-wise ``max(X, 0)``."""
    return np.maximum(X, 0)
def conv_forward(X, W, b, stride=1, padding=1):
    """Forward pass of a 2-D convolution layer built on im2col.

    X is unpacked as a 4-D array (batch, channels, height, width) and W as
    (n_filters, channels, filter_height, filter_width). ``b`` is added once
    per filter (presumably one bias value per filter -- confirm with caller).

    Returns an array of shape (batch, n_filters, h_out, w_out).
    """
    num_filters, _, kernel_h, kernel_w = W.shape
    batch_size, _, in_h, in_w = X.shape

    # theano conv2d flips the filters (rotates them by 180 degrees) before
    # doing the calculation, so the same flip is applied here.
    flipped = W[:, :, ::-1, ::-1]

    out_h = int((in_h - kernel_h + 2 * padding) / stride + 1)
    out_w = int((in_w - kernel_w + 2 * padding) / stride + 1)

    # One column per receptive field, one row per (channel, ky, kx) entry.
    patches = im2col_indices(X, kernel_h, kernel_w,
                             padding=padding, stride=stride)
    kernels = flipped.reshape(num_filters, -1)

    # Transposing puts the filter axis last so that ``b`` broadcasts over it.
    response = (np.dot(kernels, patches).T + b).T
    response = response.reshape(num_filters, out_h, out_w, batch_size)
    return response.transpose(3, 0, 1, 2)
def fc_forward(X, W, b):
    """Forward pass of a fully connected layer: ``dot(X, W) + b``."""
    projected = np.dot(X, W)
    return projected + b
def get_im2col_indices(x_shape, field_height,
                       field_width, padding=1, stride=1):
    """Build the fancy-indexing arrays used by im2col.

    Args:
        x_shape: 4-tuple (N, C, H, W) describing the un-padded input.
        field_height: height of the sliding window (filter).
        field_width: width of the sliding window (filter).
        padding: zero padding applied to each side of H and W.
        stride: step of the sliding window.

    Returns:
        A tuple ``(k, i, j)`` of integer arrays where ``k`` has shape
        (C * field_height * field_width, 1) and holds channel indices, and
        ``i`` / ``j`` have shape
        (C * field_height * field_width, out_height * out_width) and hold the
        row / column indices into the padded input.

    Raises:
        AssertionError: if the window does not tile the padded input evenly
            along the height or the width for the given stride.
    """
    # First figure out what the size of the output should be
    _, C, H, W = x_shape
    assert (H + 2 * padding - field_height) % stride == 0
    # The width check must use field_width; it previously reused
    # field_height, which mis-validated non-square windows.
    assert (W + 2 * padding - field_width) % stride == 0
    out_height = int((H + 2 * padding - field_height) / stride + 1)
    out_width = int((W + 2 * padding - field_width) / stride + 1)

    # Row offsets inside one window, repeated for every channel.
    i0 = np.repeat(np.arange(field_height), field_width)
    i0 = np.tile(i0, C)
    # Row position of the top-left corner of every window.
    i1 = stride * np.repeat(np.arange(out_height), out_width)
    # Column offsets inside one window, repeated for every channel.
    j0 = np.tile(np.arange(field_width), field_height * C)
    # Column position of the top-left corner of every window.
    j1 = stride * np.tile(np.arange(out_width), out_height)

    # Broadcasting (offset within window) + (window origin) gives absolute
    # coordinates for every (window entry, window position) pair.
    i = i0.reshape(-1, 1) + i1.reshape(1, -1)
    j = j0.reshape(-1, 1) + j1.reshape(1, -1)

    k = np.repeat(np.arange(C), field_height * field_width).reshape(-1, 1)

    return (k.astype(int), i.astype(int), j.astype(int))
def im2col_indices(x, field_height, field_width, padding=1, stride=1):
    """Rearrange sliding windows of ``x`` into columns via fancy indexing.

    ``x`` is zero-padded by ``padding`` on both sides of its last two axes,
    then every (field_height x field_width) window across all channels is
    gathered. The result has ``field_height * field_width * C`` rows, where
    C is ``x.shape[1]``; the remaining elements form the column axis.
    """
    # Zero-pad only the two trailing (spatial) axes.
    pad_spec = ((0, 0), (0, 0), (padding, padding), (padding, padding))
    padded = np.pad(x, pad_spec, mode='constant')

    chan_idx, row_idx, col_idx = get_im2col_indices(
        x.shape, field_height, field_width, padding, stride)

    # Gather every window entry for every sample in a single indexing step.
    gathered = padded[:, chan_idx, row_idx, col_idx]

    rows_per_patch = field_height * field_width * x.shape[1]
    return gathered.transpose(1, 2, 0).reshape(rows_per_patch, -1)
class PolicyValueNetNumpy:
    """Policy-value network evaluated with plain numpy.

    ``net_params`` is indexed as a flat sequence of (weights, bias) pairs:
    0-5 for the three shared conv layers, 6-7 for the policy conv layer,
    8-9 for the policy fully connected layer, 10-11 for the value conv
    layer, and 12-15 for the two value fully connected layers.
    """

    def __init__(self, board_width, board_height, net_params):
        self.board_width = board_width
        self.board_height = board_height
        self.params = net_params

    def policy_value_fn(self, board):
        """Evaluate the network on ``board``.

        input: board (must expose ``availables`` and ``current_state()``)
        output: the ``zip`` of (action, probability) pairs for each
            available action, and the score of the board state
        """
        weights = self.params
        legal_positions = board.availables
        state = board.current_state()
        # The state is viewed as a batch of 4-plane feature maps.
        features = state.reshape(-1, 4, self.board_width, self.board_height)

        # Shared trunk: three conv layers, each followed by ReLU.
        for idx in range(0, 6, 2):
            features = relu(
                conv_forward(features, weights[idx], weights[idx + 1]))

        # Policy head: unpadded conv -> ReLU -> fully connected -> softmax.
        policy = relu(
            conv_forward(features, weights[6], weights[7], padding=0))
        policy_logits = fc_forward(policy.flatten(), weights[8], weights[9])
        move_probs = softmax(policy_logits)

        # Value head: unpadded conv -> ReLU -> fc -> ReLU -> fc -> tanh.
        val = relu(
            conv_forward(features, weights[10], weights[11], padding=0))
        val = relu(fc_forward(val.flatten(), weights[12], weights[13]))
        value = np.tanh(fc_forward(val, weights[14], weights[15]))[0]

        # Keep only the probabilities of the currently legal moves.
        legal_probs = move_probs.flatten()[legal_positions]
        return zip(legal_positions, legal_probs), value