feature(yzj): add multi-agent and structured observation env (GoBigger) #39

Open. Wants to merge 59 commits into base: main.
Changes from 30 commits
Commits
59 commits
ec0ba9d
feature(yzj): adapt multi agent env gobigger with ez
May 31, 2023
2c29842
fix(yzj): fix data device bug in gobigger ez pipeline
Jun 1, 2023
335b0fc
feature(yzj): add vsbot with ez pipeline and add eat-info in tensorboard
Jun 1, 2023
0875e74
feature(yzj): add vsbot with mz pipeline and polish model and buffer
Jun 1, 2023
d88d79c
polish(yzj): polish gobigger env
Jun 2, 2023
17992eb
feature(yzj): adapt multi agent env gobigger with sez
Jun 2, 2023
4925d01
feature(yzj): add gobigger visualization and polish gobigger eval config
Jun 7, 2023
b8e044e
fix(yzj): fix eval_episode_return and polish env
Jun 7, 2023
f229b6a
polish(yzj): polish gobigger env pytest
Jun 7, 2023
4bbbeb0
polish(yzj): polish gobigger env and eat info in evaluator
Jun 7, 2023
7529170
fix(yzj): fix np.pad bug, which need padding_num>0
Jun 12, 2023
85aeacf
polish(yzj): contain raw obs only on eval mode for save memory
Jun 13, 2023
f146c4d
fix(yzj): fix mcts ptree sampled value/value-prefix bug
Jun 13, 2023
47b145e
polish(yzj): polish gobigger encoder model
Jun 15, 2023
2772ffd
polish(yzj): polish gobigger encoder model with ding
Jun 16, 2023
e36e752
polish(yzj): polish gobigger entry evaluator
Jun 19, 2023
7098899
feature(yzj): add eps_greedy and random_collect_episode in gobigger ez
Jun 19, 2023
b94deae
fix(yzj): fix key bug in entry utils when random collect
Jun 20, 2023
dfa4671
fix(yzj): fix gobigger encoder bn bug
Jun 25, 2023
ff11821
polish(yzj): polish ez config and set eps as 1.5e4 learner iter
Jun 25, 2023
a95c19c
polish(yzj): polish code style by format.sh
Jun 25, 2023
6da2997
polish(yzj): polish code comments about gobigger in worker/policy/entry
Jun 25, 2023
a2ca5ee
feature(yzj): add eps_greedy and random_collect_episode in gobigger mz
Jun 25, 2023
249d88a
Merge branch 'main' of https://github.com/opendilab/LightZero into de…
Jun 25, 2023
8c4c5a0
polish(yzj): polish entry/buffer/policy/config/model/env comments and…
Jun 28, 2023
377f664
polish(yzj): use ding scatter_model, muzero_collector add multi_agent…
Jun 30, 2023
1ed22b2
fix(yzj): fix collector bug that observation_window_stack no for_loop…
Jul 3, 2023
35e7714
fix(yzj): fix ignore done in collector
Jul 5, 2023
4df3ada
polish(yzj): polish ez config ignore done
Jul 5, 2023
272611f
fix(yzj): add game_segment_pool clear()
Jul 5, 2023
cc54996
polish(yzj): add gobigger/entry , polish gobigger config and add defa…
Jul 12, 2023
39802f5
polish(yzj): polish eps greedy and random policy
Jul 17, 2023
58281d6
fix(yzj): fix random collect in gobigger ez policy
Jul 17, 2023
e1ba071
polish(yzj): merge main-branch eps and random collect, polish gobigge…
Aug 2, 2023
c29abaf
feature(yzj): add peetingzoo mz/ez algo, add multi agent buffer/polic…
Aug 4, 2023
e4667df
polish(yzj): polish multi agent muzero collector
Aug 4, 2023
b6dca69
polish(yzj): polish gobigger collector and config to support t2p3
Aug 8, 2023
09a4440
feature(yzj): add fc encoder on ptz env instead of identity
Aug 8, 2023
407329a
polish(yzj): polish buffer name and remove ignore done in atari config
Aug 10, 2023
592fab1
fix(yzj): fix ssl data bug and polish to_device code
Aug 14, 2023
3392d61
fix(yzj): fix policy utils obs batch
Aug 14, 2023
9337ce3
fix(yzj): fix collect mode and eval mode to device
Aug 14, 2023
deab811
fix(yzj): fix to device bug on policy utils
Aug 15, 2023
705b5f9
polish(yzj): polish multi agent game buffer code
Aug 15, 2023
43b2bb5
polish(yzj): polish code
Aug 15, 2023
3d88a17
fix(yzj): fix priority bug, polish priority related config, add all a…
Aug 15, 2023
a09517a
polish(yzj): polish train entry
Aug 15, 2023
714ba4b
polish(yzj): polish gobigger config
Aug 16, 2023
0ee0122
polish(yzj): polish best gobigger config on ez/mz
Aug 18, 2023
71ce58e
polish(yzj): polish collector to adapt multi-agent mode
Aug 18, 2023
05c025d
polish(yzj): polish evaluator conflicts
Aug 18, 2023
5bec18b
polish(yzj): polish multi agent model
Aug 18, 2023
5d310ba
polish(yzj): sync main
Aug 21, 2023
920dc38
polish(yzj): polish gobigger entry and evaluator
Aug 21, 2023
1c1fde9
feature(yzj): add pettingzoo visualization
Aug 29, 2023
72c669b
polish(yzj): polish ptz config and model
Aug 29, 2023
11ef08f
feature(yzj): add ptz simple ez config
Sep 4, 2023
1e143bc
polish(yzj): polish code base
jayyoung0802 Dec 7, 2023
3e1e62f
Merge remote-tracking branch 'origin' into dev-gobigger
jayyoung0802 Dec 7, 2023
4 changes: 3 additions & 1 deletion lzero/entry/__init__.py
@@ -3,4 +3,6 @@
from .train_muzero import train_muzero
from .eval_muzero import eval_muzero
from .eval_muzero_with_gym_env import eval_muzero_with_gym_env
from .train_muzero_with_gym_env import train_muzero_with_gym_env
from .train_muzero_with_gym_env import train_muzero_with_gym_env
from .train_muzero_gobigger import train_muzero_gobigger
from .eval_muzero_gobigger import eval_muzero_gobigger
117 changes: 117 additions & 0 deletions lzero/entry/eval_muzero_gobigger.py
@@ -0,0 +1,117 @@
import logging
import os
from functools import partial
from typing import List, Optional, Tuple
import numpy as np
import torch
from tensorboardX import SummaryWriter
import copy

from ding.config import compile_config
from ding.envs import create_env_manager
from ding.envs import get_vec_env_setting
from ding.policy import create_policy
from ding.utils import set_pkg_seed
from ding.worker import BaseLearner
from lzero.worker import GoBiggerMuZeroEvaluator


def eval_muzero_gobigger(
input_cfg: Tuple[dict, dict],
seed: int = 0,
model: Optional[torch.nn.Module] = None,
model_path: Optional[str] = None,
) -> Tuple[List, List]:
"""
Overview:
The eval entry for GoBigger MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero.
Arguments:
- input_cfg (:obj:`Tuple[dict, dict]`): Config in dict type.
``Tuple[dict, dict]`` type means [user_config, create_cfg].
- seed (:obj:`int`): Random seed.
- model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
- model_path (:obj:`Optional[str]`): The pretrained model path, which should
point to the ckpt file of the pretrained model, and an absolute path is recommended.
In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``.
Returns:
        - reward_sp (:obj:`List`): Evaluation episode returns in self-play mode.
        - reward_vsbot (:obj:`List`): Evaluation episode returns in vs-bot mode.
"""
cfg, create_cfg = input_cfg
    assert create_cfg.policy.type in ['gobigger_efficientzero', 'gobigger_muzero', 'gobigger_sampled_efficientzero'], \
        "eval_muzero_gobigger entry now only supports the following algorithms: 'gobigger_efficientzero', 'gobigger_muzero', 'gobigger_sampled_efficientzero'"

if create_cfg.policy.type == 'gobigger_efficientzero':
from lzero.mcts import GoBiggerEfficientZeroGameBuffer as GameBuffer
elif create_cfg.policy.type == 'gobigger_muzero':
from lzero.mcts import GoBiggerMuZeroGameBuffer as GameBuffer
elif create_cfg.policy.type == 'gobigger_sampled_efficientzero':
from lzero.mcts import GoBiggerSampledEfficientZeroGameBuffer as GameBuffer

if cfg.policy.cuda and torch.cuda.is_available():
cfg.policy.device = 'cuda'
else:
cfg.policy.device = 'cpu'

cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
# Create main components: env, policy
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])

env_cfg = copy.deepcopy(evaluator_env_cfg[0])
env_cfg.contain_raw_obs = True
vsbot_evaluator_env_cfg = [env_cfg for _ in range(len(evaluator_env_cfg))]
vsbot_evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in vsbot_evaluator_env_cfg])

collector_env.seed(cfg.seed)
evaluator_env.seed(cfg.seed, dynamic_seed=False)
vsbot_evaluator_env.seed(cfg.seed, dynamic_seed=False)
set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])

# load pretrained model
if model_path is not None:
policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device))

    # Create worker components: learner and evaluators.
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)

# ==============================================================
# MCTS+RL algorithms related core code
# ==============================================================
policy_config = cfg.policy
evaluator = GoBiggerMuZeroEvaluator(
eval_freq=cfg.policy.eval_freq,
n_evaluator_episode=cfg.env.n_evaluator_episode,
stop_value=cfg.env.stop_value,
env=evaluator_env,
policy=policy.eval_mode,
tb_logger=tb_logger,
exp_name=cfg.exp_name,
policy_config=policy_config
)
vsbot_evaluator = GoBiggerMuZeroEvaluator(
eval_freq=cfg.policy.eval_freq,
n_evaluator_episode=cfg.env.n_evaluator_episode,
stop_value=cfg.env.stop_value,
env=vsbot_evaluator_env,
policy=policy.eval_mode,
tb_logger=tb_logger,
exp_name=cfg.exp_name,
policy_config=policy_config,
instance_name='vsbot_evaluator'
)
# ==============================================================
# Main loop
# ==============================================================
# Learner's before_run hook.
learner.call_hook('before_run')
# ==============================================================
# eval trained model
# ==============================================================
_, reward_sp = evaluator.eval(learner.save_checkpoint, learner.train_iter)
_, reward_vsbot = vsbot_evaluator.eval_vsbot(learner.save_checkpoint, learner.train_iter)
return reward_sp, reward_vsbot
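
Usage sketch: a minimal, hedged example of how this eval entry could be invoked. The config import path and checkpoint location below are illustrative assumptions, not files defined in this diff.

# Hypothetical invocation: config module path and checkpoint path are assumptions.
from lzero.entry import eval_muzero_gobigger
from zoo.gobigger.config.gobigger_efficientzero_config import main_config, create_config  # assumed config module

reward_sp, reward_vsbot = eval_muzero_gobigger(
    (main_config, create_config),
    seed=0,
    model_path='./exp_name/ckpt/ckpt_best.pth.tar',  # hypothetical checkpoint path
)
print('self-play returns:', reward_sp)
print('vs-bot returns:', reward_vsbot)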
193 changes: 193 additions & 0 deletions lzero/entry/train_muzero_gobigger.py
@@ -0,0 +1,193 @@
import logging
import os
from functools import partial
from typing import Optional, Tuple

import torch
from ding.config import compile_config
from ding.envs import create_env_manager
from ding.envs import get_vec_env_setting
from ding.policy import create_policy
from ding.utils import set_pkg_seed
from ding.worker import BaseLearner
from tensorboardX import SummaryWriter
import copy
from ding.rl_utils import get_epsilon_greedy_fn
from lzero.entry.utils import log_buffer_memory_usage, random_collect
from lzero.policy import visit_count_temperature
from lzero.worker import GoBiggerMuZeroCollector, GoBiggerMuZeroEvaluator


def train_muzero_gobigger(
input_cfg: Tuple[dict, dict],
seed: int = 0,
model: Optional[torch.nn.Module] = None,
model_path: Optional[str] = None,
max_train_iter: Optional[int] = int(1e10),
max_env_step: Optional[int] = int(1e10),
) -> 'Policy': # noqa
"""
Overview:
The train entry for GoBigger MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero.
Arguments:
- input_cfg (:obj:`Tuple[dict, dict]`): Config in dict type.
``Tuple[dict, dict]`` type means [user_config, create_cfg].
- seed (:obj:`int`): Random seed.
- model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
- model_path (:obj:`Optional[str]`): The pretrained model path, which should
point to the ckpt file of the pretrained model, and an absolute path is recommended.
In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``.
- max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
- max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
Returns:
- policy (:obj:`Policy`): Converged policy.
"""

cfg, create_cfg = input_cfg
assert create_cfg.policy.type in ['gobigger_efficientzero', 'gobigger_muzero', 'gobigger_sampled_efficientzero']
if create_cfg.policy.type == 'gobigger_efficientzero':
from lzero.mcts import GoBiggerEfficientZeroGameBuffer as GameBuffer
elif create_cfg.policy.type == 'gobigger_muzero':
from lzero.mcts import GoBiggerMuZeroGameBuffer as GameBuffer
elif create_cfg.policy.type == 'gobigger_sampled_efficientzero':
from lzero.mcts import GoBiggerSampledEfficientZeroGameBuffer as GameBuffer

if cfg.policy.cuda and torch.cuda.is_available():
cfg.policy.device = 'cuda'
else:
cfg.policy.device = 'cpu'

cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
# Create main components: env, policy
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])

env_cfg = copy.deepcopy(evaluator_env_cfg[0])
env_cfg.contain_raw_obs = True
vsbot_evaluator_env_cfg = [env_cfg for _ in range(len(evaluator_env_cfg))]
vsbot_evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in vsbot_evaluator_env_cfg])

collector_env.seed(cfg.seed)
evaluator_env.seed(cfg.seed, dynamic_seed=False)
vsbot_evaluator_env.seed(cfg.seed, dynamic_seed=False)
set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])

# load pretrained model
if model_path is not None:
policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device))

# Create worker components: learner, collector, evaluator, replay buffer, commander.
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)

# ==============================================================
# MCTS+RL algorithms related core code
# ==============================================================
policy_config = cfg.policy
batch_size = policy_config.batch_size
# specific game buffer for MCTS+RL algorithms
replay_buffer = GameBuffer(policy_config)
collector = GoBiggerMuZeroCollector(
env=collector_env,
policy=policy.collect_mode,
tb_logger=tb_logger,
exp_name=cfg.exp_name,
policy_config=policy_config
)
evaluator = GoBiggerMuZeroEvaluator(
eval_freq=cfg.policy.eval_freq,
n_evaluator_episode=cfg.env.n_evaluator_episode,
stop_value=cfg.env.stop_value,
env=evaluator_env,
policy=policy.eval_mode,
tb_logger=tb_logger,
exp_name=cfg.exp_name,
policy_config=policy_config
)

vsbot_evaluator = GoBiggerMuZeroEvaluator(
eval_freq=cfg.policy.eval_freq,
n_evaluator_episode=cfg.env.n_evaluator_episode,
stop_value=cfg.env.stop_value,
env=vsbot_evaluator_env,
policy=policy.eval_mode,
tb_logger=tb_logger,
exp_name=cfg.exp_name,
policy_config=policy_config,
instance_name='vsbot_evaluator'
)

# ==============================================================
# Main loop
# ==============================================================
# Learner's before_run hook.
learner.call_hook('before_run')
if cfg.policy.random_collect_episode_num > 0:
random_collect(cfg.policy, policy, collector, collector_env, replay_buffer)
# reset the random_collect_episode_num to 0
cfg.policy.random_collect_episode_num = 0

while True:
log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger)
collect_kwargs = {}
# set temperature for visit count distributions according to the train_iter,
# please refer to Appendix D in MuZero paper for details.
collect_kwargs['temperature'] = visit_count_temperature(
policy_config.manual_temperature_decay,
policy_config.fixed_temperature_value,
policy_config.threshold_training_steps_for_final_temperature,
trained_steps=learner.train_iter
)
if policy_config.eps.eps_greedy_exploration_in_collect:
epsilon_greedy_fn = get_epsilon_greedy_fn(
start=policy_config.eps.start,
end=policy_config.eps.end,
decay=policy_config.eps.decay,
type_=policy_config.eps.type
)
collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep)
else:
collect_kwargs['epsilon'] = 0.0
# Evaluate policy performance.
if evaluator.should_eval(learner.train_iter):
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
stop, reward = vsbot_evaluator.eval_vsbot(learner.save_checkpoint, learner.train_iter, collector.envstep)
if stop:
break

# Collect data by default config n_sample/n_episode.
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
# save returned new_data collected by the collector
replay_buffer.push_game_segments(new_data)
# remove the oldest data if the replay buffer is full.
replay_buffer.remove_oldest_data_to_fit()

# Learn policy from collected data.
for i in range(cfg.policy.update_per_collect):
# Learner will train ``update_per_collect`` times in one iteration.
if replay_buffer.get_num_of_transitions() > batch_size:
train_data = replay_buffer.sample(batch_size, policy)
else:
logging.warning(
f'The data in replay_buffer is not sufficient to sample a mini-batch: '
f'batch_size: {batch_size}, '
f'{replay_buffer} '
f'continue to collect now ....'
)
break

# The core train steps for MCTS+RL algorithms.
log_vars = learner.train(train_data, collector.envstep)

if cfg.policy.use_priority:
replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig'])

if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
break

# Learner's after_run hook.
learner.call_hook('after_run')
return policy
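
Usage sketch: a minimal, hedged example of how this train entry could be launched. The config import path and the environment-step budget are illustrative assumptions, not values taken from this diff.

# Hypothetical invocation: config module path and max_env_step are assumptions.
from lzero.entry import train_muzero_gobigger
from zoo.gobigger.config.gobigger_efficientzero_config import main_config, create_config  # assumed config module

if __name__ == '__main__':
    policy = train_muzero_gobigger(
        (main_config, create_config),
        seed=0,
        max_env_step=int(1e6),  # illustrative budget, not a tuned value
    )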
50 changes: 50 additions & 0 deletions lzero/entry/utils.py
@@ -1,9 +1,59 @@
import os
from typing import Optional, Callable

import psutil
from easydict import EasyDict
from pympler.asizeof import asizeof
from tensorboardX import SummaryWriter

from lzero.policy.gobigger_random_policy import GoBiggerRandomPolicy


def random_collect(
policy_cfg: 'EasyDict', # noqa
policy: 'Policy', # noqa
collector: 'ISerialCollector', # noqa
collector_env: 'BaseEnvManager', # noqa
replay_buffer: 'IBuffer', # noqa
postprocess_data_fn: Optional[Callable] = None
) -> None: # noqa
"""
Overview:
Collect data by random policy.
Arguments:
- policy_cfg (:obj:`EasyDict`): The policy config.
- policy (:obj:`Policy`): The policy.
- collector (:obj:`ISerialCollector`): The collector.
- collector_env (:obj:`BaseEnvManager`): The collector env manager.
- replay_buffer (:obj:`IBuffer`): The replay buffer.
- postprocess_data_fn (:obj:`Optional[Callable]`): The postprocess function for the collected data.
"""
assert policy_cfg.random_collect_episode_num > 0

random_policy = GoBiggerRandomPolicy(cfg=policy_cfg)
# set the policy to random policy
collector.reset_policy(random_policy.collect_mode)

collect_kwargs = {}
    # Random collection uses a fixed visit count temperature and no epsilon-greedy noise.
collect_kwargs['temperature'] = 1
collect_kwargs['epsilon'] = 0.0

# Collect data by default config n_sample/n_episode.
new_data = collector.collect(train_iter=0, policy_kwargs=collect_kwargs)

if postprocess_data_fn is not None:
new_data = postprocess_data_fn(new_data)

# save returned new_data collected by the collector
replay_buffer.push_game_segments(new_data)
# remove the oldest data if the replay buffer is full.
replay_buffer.remove_oldest_data_to_fit()

# restore the policy
collector.reset_policy(policy.collect_mode)
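
Usage sketch: how the train entry above drives this helper; this mirrors the call in train_muzero_gobigger and introduces no new objects.

# Called once before the main training loop when random warm-up episodes are requested.
if cfg.policy.random_collect_episode_num > 0:
    # Fill the replay buffer with GoBiggerRandomPolicy episodes, then hand the collector back to the trained policy.
    random_collect(cfg.policy, policy, collector, collector_env, replay_buffer)
    cfg.policy.random_collect_episode_num = 0  # warm up only once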


def log_buffer_memory_usage(train_iter: int, buffer: "GameBuffer", writer: SummaryWriter) -> None:
"""
3 changes: 3 additions & 0 deletions lzero/mcts/buffer/__init__.py
@@ -1,4 +1,7 @@
from .game_buffer_muzero import MuZeroGameBuffer
from .game_buffer_efficientzero import EfficientZeroGameBuffer
from .game_buffer_sampled_efficientzero import SampledEfficientZeroGameBuffer
from .gobigger_game_buffer_muzero import GoBiggerMuZeroGameBuffer
from .gobigger_game_buffer_efficientzero import GoBiggerEfficientZeroGameBuffer
from .gobigger_game_buffer_sampled_efficientzero import GoBiggerSampledEfficientZeroGameBuffer
from .game_buffer_gumbel_muzero import GumbelMuZeroGameBuffer