feature(pu): add Go env, AlphaZero ctree and league training #65

Open · wants to merge 41 commits into main

Commits (41)
1d43e5e
feature(pu): add init league version of alphazero config and collecto…
puyuan1996 May 3, 2023
473cfd6
sync code
puyuan1996 May 3, 2023
7355fc8
fix(pu): fix eval_episode_return in self_play_mode
puyuan1996 May 4, 2023
9502336
feature(pu): add rule_bot as the fist init historical_player in league
puyuan1996 May 10, 2023
0897677
fix(pu): fix gomoku_rule_bot_v0
puyuan1996 May 11, 2023
43e3abf
Merge branch 'main' of https://github.com/opendilab/LightZero into de…
puyuan1996 May 16, 2023
8d6490a
Merge branch 'main' of https://github.com/opendilab/LightZero into de…
puyuan1996 May 16, 2023
71d6eac
feature(pu): add snapshot_the_player_in_iter_zero and one_phase_step …
puyuan1996 May 16, 2023
52b5d8c
fix(pu): only use the main_player transitions to tran main_player, po…
puyuan1996 May 17, 2023
31376c6
fix(pu): fix save transitions in battle_alphazero_collector
puyuan1996 May 17, 2023
0d38129
fix(pu): fix new_data bug
puyuan1996 May 17, 2023
cc3be43
polish(pu): polish league config
puyuan1996 May 18, 2023
aed4f30
Merge branch 'main' of https://github.com/opendilab/LightZero into de…
puyuan1996 May 18, 2023
36ecf62
polish(pu): polish gomoku_alphazero_league_config
puyuan1996 May 24, 2023
5dfa97c
polish(pu): polish gomoku eval render
puyuan1996 May 30, 2023
73c18fd
feature(pu): add init version of go env and related alphazero sp config
puyuan1996 Jun 22, 2023
aa41a55
feature(pu): add init version of go muzero config
puyuan1996 Jun 24, 2023
e10e849
poliah(pu): polish go alphazero/muzero configs
puyuan1996 Jun 25, 2023
b0509d4
feature(pu): add eval_go_mcts_bot_speed_win-rate.py, polish test_go_m…
puyuan1996 Jun 28, 2023
689e1c7
feature(pu): add init versio of katago pure policy based bot (without…
puyuan1996 Jun 30, 2023
34025e7
Merge branch 'main' of https://github.com/opendilab/LightZero into de…
puyuan1996 Jul 3, 2023
c9359a3
Merge branch 'dev-league' into dev-go-league
puyuan1996 Jul 3, 2023
51df0d5
polish(pu): polish katago_policy, polish test_katago_bot, fix go_env …
puyuan1996 Jul 5, 2023
5b68ba9
fix(pu): fix katago_policy cuda bug
puyuan1996 Jul 5, 2023
b975913
polish(pu): polish katago_policy init
puyuan1996 Jul 6, 2023
2349e28
polish katago python import
puyuan1996 Jul 6, 2023
2e62926
Merge branch 'dev-go' into dev-go-league
puyuan1996 Jul 6, 2023
d6083e4
feature(pu): add alphazero ctree init version
puyuan1996 Jul 20, 2023
0f4803c
feature(pu): add init version of alphazero ctree
puyuan1996 Jul 21, 2023
6ef40bd
polish(pu): ignore build directory
puyuan1996 Jul 21, 2023
db6cca1
feature(pu): add go_bot_policy_v0 and ao_alphazero_league_config, pol…
puyuan1996 Jul 25, 2023
5372979
fix(pu): fix simulate_env_copy.battle_mode and polish softmax in mcts…
puyuan1996 Jul 26, 2023
aec6805
fix(pu): add reset_katago_game_state() method in go_env reset()
puyuan1996 Jul 26, 2023
bfeeaf5
polish(pu): polish visit_count_to_action_dist method, polish reset_ka…
puyuan1996 Aug 2, 2023
d965be2
sync code
puyuan1996 Aug 2, 2023
3adc32b
polish(pu): optimize the get_next_action method by replacing the deep…
puyuan1996 Aug 4, 2023
a696efd
sync code
puyuan1996 Aug 7, 2023
5138dfa
Merge branch 'main' of https://github.com/opendilab/LightZero into de…
puyuan1996 Aug 11, 2023
6bccaeb
fix(pu): use would_be_suicide check in katago legal_action, fix Katag…
puyuan1996 Aug 12, 2023
bca67bc
feature(pu): add get_katago_statistics_for_dataset method
puyuan1996 Aug 13, 2023
475df17
polish(pu): polish get_katago_statistics_for_dataset
puyuan1996 Aug 13, 2023
Files changed
6 changes: 5 additions & 1 deletion .gitignore
@@ -1444,4 +1444,8 @@ events.*
 !/lzero/mcts/**/lib/*.h
 **/tb/*
 **/mcts/ctree/tests_cpp/*
-**/*tmp*
+build/
+
+**/*tmp*
+
+/lzero/mcts/**/pybind11/*
3 changes: 3 additions & 0 deletions .gitmodules
@@ -0,0 +1,3 @@
+[submodule "lzero/mcts/ctree/ctree_alphazero/pybind11"]
+	path = lzero/mcts/ctree/ctree_alphazero/pybind11
+	url = https://github.com/pybind/pybind11.git
1 change: 1 addition & 0 deletions lzero/entry/__init__.py
@@ -1,4 +1,5 @@
 from .train_alphazero import train_alphazero
+from .train_alphazero_league import train_alphazero_league
 from .eval_alphazero import eval_alphazero
 from .train_muzero import train_muzero
 from .eval_muzero import eval_muzero
6 changes: 6 additions & 0 deletions lzero/entry/eval_alphazero.py
@@ -12,6 +12,7 @@
 from ding.policy import create_policy
 from ding.utils import set_pkg_seed
 from lzero.worker import AlphaZeroEvaluator
+from zoo.board_games.go.envs.katago_policy import KatagoPolicy


 def eval_alphazero(
@@ -47,6 +48,11 @@ def eval_alphazero(
         cfg.policy.device = 'cpu'

     cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
+
+    if cfg.env.use_katago_bot:
+        cfg.env.katago_policy = KatagoPolicy(checkpoint_path=cfg.env.katago_checkpoint_path, board_size=cfg.env.board_size,
+                                             ignore_pass_if_have_other_legal_actions=cfg.env.ignore_pass_if_have_other_legal_actions, device=cfg.policy.device)
+
     # Create main components: env, policy
     env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
     collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
7 changes: 7 additions & 0 deletions lzero/entry/train_alphazero.py
@@ -14,6 +14,7 @@

 from lzero.policy import visit_count_temperature
 from lzero.worker import AlphaZeroCollector, AlphaZeroEvaluator
+from zoo.board_games.go.envs.katago_policy import KatagoPolicy


 def train_alphazero(
@@ -51,8 +52,14 @@ def train_alphazero(
         cfg.policy.device = 'cpu'

     cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
+
+    if cfg.env.use_katago_bot:
+        cfg.env.katago_policy = KatagoPolicy(checkpoint_path=cfg.env.katago_checkpoint_path, board_size=cfg.env.board_size,
+                                             ignore_pass_if_have_other_legal_actions=cfg.env.ignore_pass_if_have_other_legal_actions, device=cfg.policy.device)
+
     # Create main components: env, policy
     env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+
     collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
     evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
     collector_env.seed(cfg.seed)
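Both entry points now build a KatagoPolicy from the env config when cfg.env.use_katago_bot is set, and attach it to cfg.env so the env workers can use it. The sketch below is illustrative only (it is not part of this PR); it lists the env-config fields these entries read, with placeholder values.

# Illustrative sketch of the env-config fields consumed above (values are placeholders).
from easydict import EasyDict

go_env_cfg = EasyDict(dict(
    board_size=9,                                   # read as cfg.env.board_size
    use_katago_bot=True,                            # gates the KatagoPolicy construction
    katago_checkpoint_path='/path/to/katago_model.ckpt',  # placeholder path
    ignore_pass_if_have_other_legal_actions=True,   # forwarded to KatagoPolicy
))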
317 changes: 317 additions & 0 deletions lzero/entry/train_alphazero_league.py
@@ -0,0 +1,317 @@
import copy
import logging
import os
import shutil
from functools import partial
from typing import Optional

import torch
from ding.config import compile_config
from ding.envs import SyncSubprocessEnvManager
from ding.league import BaseLeague, ActivePlayer
from ding.policy import create_policy
from ding.utils import set_pkg_seed
from ding.worker import BaseLearner, create_buffer
from ding.worker import NaiveReplayBuffer
from easydict import EasyDict
from tensorboardX import SummaryWriter

from lzero.policy.alphazero import AlphaZeroPolicy
from lzero.worker import AlphaZeroEvaluator
from lzero.worker import BattleAlphaZeroCollector
from lzero.policy import visit_count_temperature
from zoo.board_games.go.envs.katago_policy import KatagoPolicy


def win_loss_draw(episode_info):
    """
    Overview:
        Get the win/loss/draw result from episode info.
    Arguments:
        - episode_info (:obj:`list`): List of episode info.
    Returns:
        - win_loss_result (:obj:`list`): List of win/loss/draw results.
    Examples:
        >>> episode_info = [{'eval_episode_return': 1}, {'eval_episode_return': 0}, {'eval_episode_return': -1}]
        >>> win_loss_draw(episode_info)
        ['wins', 'draws', 'losses']
    """
    win_loss_result = []
    for e in episode_info:
        if e['eval_episode_return'] == 1:
            result = 'wins'
        elif e['eval_episode_return'] == 0:
            result = 'draws'
        else:
            result = 'losses'
        win_loss_result.append(result)

    return win_loss_result


class AlphaZeroLeague(BaseLeague):
    # override
    def _get_job_info(self, player: ActivePlayer, eval_flag: bool = False) -> dict:
        assert isinstance(player, ActivePlayer), player.__class__
        player_job_info = EasyDict(player.get_job(eval_flag))
        return {
            'agent_num': 2,
            # home player_id
            'launch_player': player.player_id,
            # include home and away player_id
            'player_id': [player.player_id, player_job_info.opponent.player_id],
            'checkpoint_path': [player.checkpoint_path, player_job_info.opponent.checkpoint_path],
            'player_active_flag': [isinstance(p, ActivePlayer) for p in [player, player_job_info.opponent]],
        }

    # override
    def _mutate_player(self, player: ActivePlayer):
        # no mutate operation
        pass

    # override
    def _update_player(self, player: ActivePlayer, player_info: dict) -> None:
        assert isinstance(player, ActivePlayer)
        if 'learner_step' in player_info:
            player.total_agent_step = player_info['learner_step']
        # torch.save(player_info['state_dict'], player.checkpoint_path)

    # override
    @staticmethod
    def save_checkpoint(src_checkpoint_path: str, dst_checkpoint_path: str) -> None:
        shutil.copy(src_checkpoint_path, dst_checkpoint_path)
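
# Illustrative sketch, not part of this file: the shape of the job dict returned by
# AlphaZeroLeague._get_job_info() above, assuming the main player has been matched
# against a historical snapshot. The keys mirror the return value above; the concrete
# player ids and checkpoint paths are placeholders.
#
# {
#     'agent_num': 2,
#     'launch_player': 'main_player_default_0',
#     'player_id': ['main_player_default_0', 'main_player_default_0_historical'],
#     'checkpoint_path': ['./ckpt/main_player_default_0.pth', './ckpt/historical_0.pth'],
#     'player_active_flag': [True, False],
# }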


def train_alphazero_league(cfg, Env, seed=0, max_train_iter: Optional[int] = int(1e10),
                           max_env_step: Optional[int] = int(1e10)) -> None:
    """
    Overview:
        Train an AlphaZero league.
    Arguments:
        - cfg (:obj:`EasyDict`): Config dict.
        - Env (:obj:`BaseEnv`): Env class.
        - seed (:obj:`int`): Random seed.
        - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
        - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
    Returns:
        - None
    """
    if cfg.policy.cuda and torch.cuda.is_available():
        cfg.policy.device = 'cuda'
    else:
        cfg.policy.device = 'cpu'

    # prepare config
    cfg = compile_config(
        cfg,
        SyncSubprocessEnvManager,
        AlphaZeroPolicy,
        BaseLearner,
        BattleAlphaZeroCollector,
        AlphaZeroEvaluator,
        NaiveReplayBuffer,
        save_cfg=True
    )

    if cfg.env.use_katago_bot:
        # for eval
        cfg.env.katago_policy = KatagoPolicy(checkpoint_path=cfg.env.katago_checkpoint_path, board_size=cfg.env.board_size,
                                             ignore_pass_if_have_other_legal_actions=cfg.env.ignore_pass_if_have_other_legal_actions, device=cfg.policy.device)
        # for collect
        cfg.policy.collect.katago_policy = KatagoPolicy(checkpoint_path=cfg.env.katago_checkpoint_path, board_size=cfg.env.board_size,
                                                        ignore_pass_if_have_other_legal_actions=cfg.env.ignore_pass_if_have_other_legal_actions, device=cfg.policy.device)

    collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
    collector_env_cfg = copy.deepcopy(cfg.env)
    evaluator_env_cfg = copy.deepcopy(cfg.env)
    evaluator_env_cfg.battle_mode = 'eval_mode'
    tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))

    # TODO(pu): use different replay buffer for different players
    # create replay buffer
    replay_buffer = create_buffer(cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)

    # create league
    league = AlphaZeroLeague(cfg.policy.league)
    policies, learners, collectors = {}, {}, {}

    # create players
    for player_id in league.active_players_ids:
        policy = create_policy(cfg.policy, enable_field=['learn', 'collect', 'eval'])
        policies[player_id] = policy
        collector_env = SyncSubprocessEnvManager(
            env_fn=[partial(Env, collector_env_cfg) for _ in range(collector_env_num)], cfg=cfg.env.manager
        )
        collector_env.seed(seed)

        learners[player_id] = BaseLearner(
            cfg.policy.learn.learner,
            policy.learn_mode,
            tb_logger,
            exp_name=cfg.exp_name,
            instance_name=player_id + '_learner'
        )
        collectors[player_id] = BattleAlphaZeroCollector(
            cfg.policy.collect.collector,
            collector_env, [policy.collect_mode, policy.collect_mode],
            tb_logger,
            exp_name=cfg.exp_name,
            instance_name=player_id + '_collector'
        )

    # create policy
    policy = create_policy(cfg.policy, enable_field=['learn', 'collect', 'eval'])
    main_key = [k for k in learners.keys() if k.startswith('main_player')][0]
    main_player = league.get_player_by_id(main_key)
    main_learner = learners[main_key]
    main_collector = collectors[main_key]

    policies['historical'] = policy

    # create bot policy
    cfg.policy.type = cfg.policy.league.player_category[0] + '_bot_policy_v0'
    bot_policy = create_policy(cfg.policy, enable_field=['learn', 'collect', 'eval'])
    policies['bot'] = bot_policy

    # create evaluator
    evaluator_env = SyncSubprocessEnvManager(
        env_fn=[partial(Env, evaluator_env_cfg) for _ in range(evaluator_env_num)], cfg=cfg.env.manager
    )
    evaluator_env.seed(seed, dynamic_seed=False)
    evaluator_cfg = copy.deepcopy(cfg.policy.eval.evaluator)
    evaluator_cfg.stop_value = cfg.env.stop_value
    evaluator = AlphaZeroEvaluator(
        eval_freq=cfg.policy.eval_freq,
        n_evaluator_episode=cfg.env.n_evaluator_episode,
        stop_value=cfg.env.stop_value,
        env=evaluator_env,
        policy=policy.eval_mode,
        tb_logger=tb_logger,
        exp_name=cfg.exp_name,
        instance_name='vs_bot_evaluator'
    )

    def load_checkpoint_fn(player_id: str, ckpt_path: str):
        state_dict = torch.load(ckpt_path)
        policies[player_id].learn_mode.load_state_dict(state_dict)

    league.load_checkpoint = load_checkpoint_fn

    if cfg.policy.league.snapshot_the_player_in_iter_zero:
        # snapshot the initial player as the first historical player
        for player_id, player_ckpt_path in zip(league.active_players_ids, league.active_players_ckpts):
            torch.save(policies[player_id].collect_mode.state_dict(), player_ckpt_path)
            league.judge_snapshot(player_id, force=True)

    set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
    league_iter = 0
    while True:
        if evaluator.should_eval(main_learner.train_iter):
            stop_flag, eval_episode_info = evaluator.eval(
                main_learner.save_checkpoint, main_learner.train_iter, main_collector.envstep
            )
            win_loss_result = win_loss_draw(eval_episode_info)

            # set eval bot rating as 100.
            main_player.rating = league.metric_env.rate_1vsC(
                main_player.rating, league.metric_env.create_rating(mu=100, sigma=1e-8), win_loss_result
            )
            if stop_flag:
                break

        for player_id, player_ckpt_path in zip(league.active_players_ids, league.active_players_ckpts):
            tb_logger.add_scalar(
                'league/{}_trueskill'.format(player_id),
                league.get_player_by_id(player_id).rating.exposure, main_collector.envstep
            )
            collector, learner = collectors[player_id], learners[player_id]

            job = league.get_job_info(player_id)
            opponent_player_id = job['player_id'][1]
            # print('job player: {}'.format(job['player_id']))
            if 'historical' in opponent_player_id and 'bot' not in opponent_player_id:
                opponent_policy = policies['historical'].collect_mode
                opponent_path = job['checkpoint_path'][1]
                opponent_policy.load_state_dict(torch.load(opponent_path, map_location='cpu'))
                opponent_policy_info = {
                    'policy': opponent_policy,
                    'policy_id': opponent_player_id,
                    'policy_type': 'historical'
                }
            elif 'bot' in opponent_player_id:
                opponent_policy = policies['bot'].collect_mode
                opponent_policy_info = {
                    'policy': opponent_policy,
                    'policy_id': opponent_player_id,
                    'policy_type': 'bot'
                }
            else:
                opponent_policy = policies[opponent_player_id].collect_mode
                opponent_policy_info = {
                    'policy': opponent_policy,
                    'policy_id': opponent_player_id,
                    'policy_type': 'main'
                }

            collector.reset_policy([policies[player_id].collect_mode, opponent_policy_info])

            collect_kwargs = {}
            # set temperature for visit count distributions according to the train_iter,
            # please refer to Appendix D in the MuZero paper for details.
            collect_kwargs['temperature'] = visit_count_temperature(
                cfg.policy.manual_temperature_decay,
                cfg.policy.fixed_temperature_value,
                cfg.policy.threshold_training_steps_for_final_temperature,
                trained_steps=learner.train_iter
            )

            new_data, episode_info = collector.collect(
                train_iter=learner.train_iter, n_episode=cfg.policy.n_episode, policy_kwargs=collect_kwargs
            )
            new_data_0 = sum(new_data[0], [])
            new_data_1 = sum(new_data[1], [])
            new_data = new_data_0 + new_data_1

            replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
            # Learn policy from collected data.
            for i in range(cfg.policy.update_per_collect):
                # Learner will train ``update_per_collect`` times in one iteration.
                train_data = replay_buffer.sample(cfg.policy.batch_size, learner.train_iter)
                if train_data is None:
                    logging.warning(
                        'The data in replay_buffer is not sufficient to sample a mini-batch. '
                        'Continue to collect now ...'
                    )
                    break
                learner.train(train_data, collector.envstep)

            # update the learner_step for the current active player, i.e. the main player in most cases.
            player_info = learner.learn_info
            player_info['player_id'] = player_id
            league.update_active_player(player_info)

            # player_info['state_dict'] = policies[player_id].learn_mode.state_dict()

            league.judge_snapshot(player_id)

            # set eval_flag=True to enable trueskill update
            win_loss_result = win_loss_draw(episode_info[0])

            job_finish_info = {
                'eval_flag': True,
                'launch_player': job['launch_player'],
                'player_id': job['player_id'],
                'result': win_loss_result,
            }
            league.finish_job(job_finish_info, league_iter)

            if league_iter % cfg.policy.league.log_freq_for_payoff_rank == 0:
                payoff_string = repr(league.payoff)
                rank_string = league.player_rank(string=True)
                tb_logger.add_text('payoff_step', payoff_string, main_collector.envstep)
                tb_logger.add_text('rank_step', rank_string, main_collector.envstep)

        league_iter += 1

        if main_collector.envstep >= max_env_step or main_learner.train_iter >= max_train_iter:
            break
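
For context, a minimal launch script for this entry might look like the following sketch. The env class and config module paths are assumptions based on the commit history (for example, the gomoku_alphazero_league_config mentioned there) and are not confirmed by this diff.

# Hypothetical launch sketch; module paths and config names are assumptions, not taken from this PR.
from lzero.entry import train_alphazero_league
from zoo.board_games.gomoku.envs.gomoku_env import GomokuEnv  # assumed env class path
# assumed config module, based on the commit "polish(pu): polish gomoku_alphazero_league_config"
from zoo.board_games.gomoku.config.gomoku_alphazero_league_config import main_config

if __name__ == "__main__":
    # cfg is the single merged EasyDict that compile_config consumes inside the entry.
    train_alphazero_league(main_config, GomokuEnv, seed=0, max_env_step=int(1e6))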