-
Notifications
You must be signed in to change notification settings - Fork 17
/
AlphaZeroMCTS.py
46 lines (36 loc) · 1.43 KB
/
AlphaZeroMCTS.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
# -*- coding: utf-8 -*-
from MCTS import MCTS
import numpy as np
def softmax(x):
    '''
    Return the softmax of x, normalised over all of its elements.

    The maximum of x is subtracted before exponentiating so the largest
    exponent is exp(0) == 1, which keeps np.exp from overflowing on
    large inputs. The shift cancels out in the normalisation.
    '''
    shifted = x - np.max(x)
    exp_vals = np.exp(shifted)
    # Normalise so the entries sum to one.
    return exp_vals / np.sum(exp_vals)
class AlphaZeroMCTS(MCTS):
    '''
    MCTS subclass that evaluates states with a caller-supplied
    policy/value function.
    '''

    def __init__(self, policy_value_fn=None, nplays=1000, cpuct=5, epsilon=0, alpha=0.3, is_selfplay=False):
        '''
        policy_value_fn: callable invoked as policy_value_fn(state) and
            expected to return an (action_probs, leaf_value) pair.
        nplays, cpuct, epsilon, alpha, is_selfplay: forwarded unchanged
            to MCTS.__init__.
        '''
        MCTS.__init__(self, nplays, cpuct, epsilon, alpha, is_selfplay=is_selfplay)
        self._policy_value_fn = policy_value_fn

    def _evaluate(self, state):
        '''
        Evaluate state with the policy/value function.

        Returns (is_end, action_probs, leaf_value). If state.game_end()
        reports that the game is over, the predicted leaf_value is replaced
        by the actual outcome: 0.0 for a tie (winner == -1), otherwise 1.0
        when the winner is state.get_current_player() and -1.0 when not.
        '''
        action_probs, leaf_value = self._policy_value_fn(state)
        is_end, winner = state.game_end()
        if not is_end:
            # Game still running: keep the function's own value estimate.
            return is_end, action_probs, leaf_value

        # Game over: the real result overrides the estimate.
        if winner == -1:  # tie
            leaf_value = 0.0
        elif winner == state.get_current_player():
            leaf_value = 1.0
        else:
            leaf_value = -1.0
        return is_end, action_probs, leaf_value

    def _play(self, temp=1e-3):
        '''
        Compute move probabilities from the visit counts of the root's children.

        temp: temperature parameter; the log visit counts are scaled by
            1.0 / temp before the softmax.

        Returns (acts, pi): the tuple of child actions and the matching
        array of probabilities.
        '''
        root_children = self._root._children
        acts, visits = zip(*((act, node._n_visits) for act, node in root_children.items()))
        inv_temp = 1.0 / temp
        # The 1e-10 offset keeps np.log finite for children with zero visits.
        log_visits = np.log(np.array(visits) + 1e-10)
        # softmax over scaled log-counts is the stable form of normalising
        # visits ** (1 / temp).
        pi = softmax(inv_temp * log_visits)
        return acts, pi

    def __str__(self):
        '''Return the fixed display name of this search class.'''
        return "AlphaZeroMCTS"