feature(yzj): add multi-agent and structured observation env (GoBigger) #39

Open · wants to merge 59 commits into base: main

Changes from 38 commits

Commits (59)
ec0ba9d
feature(yzj): adapt multi agent env gobigger with ez
May 31, 2023
2c29842
fix(yzj): fix data device bug in gobigger ez pipeline
Jun 1, 2023
335b0fc
feature(yzj): add vsbot with ez pipeline and add eat-info in tensorboard
Jun 1, 2023
0875e74
feature(yzj): add vsbot with mz pipeline and polish model and buffer
Jun 1, 2023
d88d79c
polish(yzj): polish gobigger env
Jun 2, 2023
17992eb
feature(yzj): adapt multi agent env gobigger with sez
Jun 2, 2023
4925d01
feature(yzj): add gobigger visualization and polish gobigger eval config
Jun 7, 2023
b8e044e
fix(yzj): fix eval_episode_return and polish env
Jun 7, 2023
f229b6a
polish(yzj): polish gobigger env pytest
Jun 7, 2023
4bbbeb0
polish(yzj): polish gobigger env and eat info in evaluator
Jun 7, 2023
7529170
fix(yzj): fix np.pad bug, which need padding_num>0
Jun 12, 2023
85aeacf
polish(yzj): contain raw obs only on eval mode for save memory
Jun 13, 2023
f146c4d
fix(yzj): fix mcts ptree sampled value/value-prefix bug
Jun 13, 2023
47b145e
polish(yzj): polish gobigger encoder model
Jun 15, 2023
2772ffd
polish(yzj): polish gobigger encoder model with ding
Jun 16, 2023
e36e752
polish(yzj): polish gobigger entry evaluator
Jun 19, 2023
7098899
feature(yzj): add eps_greedy and random_collect_episode in gobigger ez
Jun 19, 2023
b94deae
fix(yzj): fix key bug in entry utils when random collect
Jun 20, 2023
dfa4671
fix(yzj): fix gobigger encoder bn bug
Jun 25, 2023
ff11821
polish(yzj): polish ez config and set eps as 1.5e4 learner iter
Jun 25, 2023
a95c19c
polish(yzj): polish code style by format.sh
Jun 25, 2023
6da2997
polish(yzj): polish code comments about gobigger in worker/policy/entry
Jun 25, 2023
a2ca5ee
feature(yzj): add eps_greedy and random_collect_episode in gobigger mz
Jun 25, 2023
249d88a
Merge branch 'main' of https://github.com/opendilab/LightZero into de…
Jun 25, 2023
8c4c5a0
polish(yzj): polish entry/buffer/policy/config/model/env comments and…
Jun 28, 2023
377f664
polish(yzj): use ding scatter_model, muzero_collector add multi_agent…
Jun 30, 2023
1ed22b2
fix(yzj): fix collector bug that observation_window_stack no for_loop…
Jul 3, 2023
35e7714
fix(yzj): fix ignore done in collector
Jul 5, 2023
4df3ada
polish(yzj): polish ez config ignore done
Jul 5, 2023
272611f
fix(yzj): add game_segment_pool clear()
Jul 5, 2023
cc54996
polish(yzj): add gobigger/entry , polish gobigger config and add defa…
Jul 12, 2023
39802f5
polish(yzj): polish eps greedy and random policy
Jul 17, 2023
58281d6
fix(yzj): fix random collect in gobigger ez policy
Jul 17, 2023
e1ba071
polish(yzj): merge main-branch eps and random collect, polish gobigge…
Aug 2, 2023
c29abaf
feature(yzj): add pettingzoo mz/ez algo, add multi agent buffer/polic…
Aug 4, 2023
e4667df
polish(yzj): polish multi agent muzero collector
Aug 4, 2023
b6dca69
polish(yzj): polish gobigger collector and config to support t2p3
Aug 8, 2023
09a4440
feature(yzj): add fc encoder on ptz env instead of identity
Aug 8, 2023
407329a
polish(yzj): polish buffer name and remove ignore done in atari config
Aug 10, 2023
592fab1
fix(yzj): fix ssl data bug and polish to_device code
Aug 14, 2023
3392d61
fix(yzj): fix policy utils obs batch
Aug 14, 2023
9337ce3
fix(yzj): fix collect mode and eval mode to device
Aug 14, 2023
deab811
fix(yzj): fix to device bug on policy utils
Aug 15, 2023
705b5f9
polish(yzj): polish multi agent game buffer code
Aug 15, 2023
43b2bb5
polish(yzj): polish code
Aug 15, 2023
3d88a17
fix(yzj): fix priority bug, polish priority related config, add all a…
Aug 15, 2023
a09517a
polish(yzj): polish train entry
Aug 15, 2023
714ba4b
polish(yzj): polish gobigger config
Aug 16, 2023
0ee0122
polish(yzj): polish best gobigger config on ez/mz
Aug 18, 2023
71ce58e
polish(yzj): polish collector to adapt multi-agent mode
Aug 18, 2023
05c025d
polish(yzj): polish evaluator conflicts
Aug 18, 2023
5bec18b
polish(yzj): polish multi agent model
Aug 18, 2023
5d310ba
polish(yzj): sync main
Aug 21, 2023
920dc38
polish(yzj): polish gobigger entry and evaluator
Aug 21, 2023
1c1fde9
feature(yzj): add pettingzoo visualization
Aug 29, 2023
72c669b
polish(yzj): polish ptz config and model
Aug 29, 2023
11ef08f
feature(yzj): add ptz simple ez config
Sep 4, 2023
1e143bc
polish(yzj): polish code base
jayyoung0802 Dec 7, 2023
3e1e62f
Merge remote-tracking branch 'origin' into dev-gobigger
jayyoung0802 Dec 7, 2023
2 changes: 1 addition & 1 deletion lzero/entry/__init__.py
@@ -3,4 +3,4 @@
from .train_muzero import train_muzero
PaParaZz1 marked this conversation as resolved.
from .eval_muzero import eval_muzero
from .eval_muzero_with_gym_env import eval_muzero_with_gym_env
from .train_muzero_with_gym_env import train_muzero_with_gym_env
from .train_muzero_with_gym_env import train_muzero_with_gym_env
25 changes: 19 additions & 6 deletions lzero/entry/train_muzero.py
@@ -15,9 +15,6 @@

from lzero.entry.utils import log_buffer_memory_usage
from lzero.policy import visit_count_temperature
from lzero.policy.random_policy import LightZeroRandomPolicy
from lzero.worker import MuZeroCollector as Collector
from lzero.worker import MuZeroEvaluator as Evaluator
from .utils import random_collect


@@ -47,8 +44,8 @@ def train_muzero(
"""
Collaborator:
Please merge the main branch and add the mz/ez baseline results to the PR description. Then, once the code is polished, create a new multi-agent branch, push it to opendilab/LightZero, and note in this PR that the latest stable code lives on the multi-agent branch.


cfg, create_cfg = input_cfg
assert create_cfg.policy.type in ['efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero'], \
"train_muzero entry now only support the following algo.: 'efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero'"
assert create_cfg.policy.type in ['efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero', 'multi_agent_efficientzero', 'multi_agent_muzero'], \
"train_muzero entry now only support the following algo.: 'efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero', 'multi_agent_efficientzero', 'multi_agent_muzero'"

if create_cfg.policy.type == 'muzero':
from lzero.mcts import MuZeroGameBuffer as GameBuffer
@@ -58,6 +55,10 @@
from lzero.mcts import SampledEfficientZeroGameBuffer as GameBuffer
elif create_cfg.policy.type == 'gumbel_muzero':
from lzero.mcts import GumbelMuZeroGameBuffer as GameBuffer
elif create_cfg.policy.type == 'multi_agent_muzero':
from lzero.mcts import MultiAgentMuZeroGameBuffer as GameBuffer
elif create_cfg.policy.type == 'multi_agent_efficientzero':
from lzero.mcts import MultiAgentSampledEfficientZeroGameBuffer as GameBuffer

if cfg.policy.cuda and torch.cuda.is_available():
cfg.policy.device = 'cuda'
@@ -92,6 +93,14 @@
batch_size = policy_config.batch_size
# specific game buffer for MCTS+RL algorithms
replay_buffer = GameBuffer(policy_config)

if policy_config.multi_agent:
from lzero.worker import MultiAgentMuZeroCollector as Collector
from lzero.worker import MuZeroEvaluator as Evaluator
else:
from lzero.worker import MuZeroCollector as Collector
from lzero.worker import MuZeroEvaluator as Evaluator

collector = Collector(
env=collector_env,
policy=policy.collect_mode,
@@ -123,7 +132,11 @@
# Exploration: The collection of random data aids the agent in exploring the environment and prevents premature convergence to a suboptimal policy.
# Comparation: The agent's performance during random action-taking can be used as a reference point to evaluate the efficacy of reinforcement learning algorithms.
if cfg.policy.random_collect_episode_num > 0:
random_collect(cfg.policy, policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer)
if policy_config.multi_agent:
from lzero.policy.multi_agent_random_policy import MultiAgentLightZeroRandomPolicy as RandomPolicy
else:
from lzero.policy.random_policy import LightZeroRandomPolicy as RandomPolicy
random_collect(cfg.policy, policy, RandomPolicy, collector, collector_env, replay_buffer)

while True:
log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger)
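The multi-agent support touches train_muzero at three separate dispatch points (game buffer, collector, random-collect policy). The standalone sketch below consolidates them for readability; the helper function itself is hypothetical and not part of the PR, but the imports mirror the ones added in this diff and assume the PR branch of LightZero is installed.

def select_components(create_cfg, policy_config):
    # 1. Game buffer: the new multi-agent policy types get dedicated buffers.
    if create_cfg.policy.type == 'multi_agent_muzero':
        from lzero.mcts import MultiAgentMuZeroGameBuffer as GameBuffer
    elif create_cfg.policy.type == 'multi_agent_efficientzero':
        from lzero.mcts import MultiAgentSampledEfficientZeroGameBuffer as GameBuffer
    else:
        from lzero.mcts import MuZeroGameBuffer as GameBuffer  # single-agent default
    # 2. Collector: the multi-agent collector handles per-agent observations;
    #    the evaluator is shared by both modes.
    if policy_config.multi_agent:
        from lzero.worker import MultiAgentMuZeroCollector as Collector
    else:
        from lzero.worker import MuZeroCollector as Collector
    from lzero.worker import MuZeroEvaluator as Evaluator
    # 3. Random policy used for the initial random-collect episodes.
    if policy_config.multi_agent:
        from lzero.policy.multi_agent_random_policy import MultiAgentLightZeroRandomPolicy as RandomPolicy
    else:
        from lzero.policy.random_policy import LightZeroRandomPolicy as RandomPolicy
    return GameBuffer, Collector, Evaluator, RandomPolicy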
3 changes: 2 additions & 1 deletion lzero/entry/utils.py
@@ -1,6 +1,8 @@
import os
from typing import Optional, Callable

import psutil
from easydict import EasyDict
from pympler.asizeof import asizeof
from tensorboardX import SummaryWriter
from typing import Optional, Callable
@@ -39,7 +41,6 @@ def random_collect(
# restore the policy
collector.reset_policy(policy.collect_mode)


def log_buffer_memory_usage(train_iter: int, buffer: "GameBuffer", writer: SummaryWriter) -> None:
"""
Overview:
2 changes: 2 additions & 0 deletions lzero/mcts/buffer/__init__.py
@@ -2,3 +2,5 @@
from .game_buffer_efficientzero import EfficientZeroGameBuffer
from .game_buffer_sampled_efficientzero import SampledEfficientZeroGameBuffer
from .game_buffer_gumbel_muzero import GumbelMuZeroGameBuffer
from .multi_agent_game_buffer_muzero import MultiAgentMuZeroGameBuffer
from .multi_agent_game_buffer_efficientzero import MultiAgentSampledEfficientZeroGameBuffer
8 changes: 5 additions & 3 deletions lzero/mcts/buffer/game_buffer_efficientzero.py
@@ -44,6 +44,8 @@ def __init__(self, cfg: dict):
self.base_idx = 0
self.clear_time = 0

self.tmp_obs = None  # caches the last valid stacked obs; reused (and masked via value_mask) when the value-obs index (e.g. 46 + 4 td_steps) reaches beyond the game segment length (e.g. 50)
Collaborator:

Improve this comment; comments should be as complete and clear as possible.

Collaborator (Author):

done


def sample(self, batch_size: int, policy: Any) -> List[Any]:
"""
Overview:
@@ -100,7 +102,6 @@ def _prepare_reward_value_context(
- reward_value_context (:obj:`list`): value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens,
td_steps_list, action_mask_segment, to_play_segment
"""
zero_obs = game_segment_list[0].zero_obs()
value_obs_list = []
# the value is valid or not (out of trajectory)
value_mask = []
@@ -148,11 +149,12 @@
end_index = beg_index + self._cfg.model.frame_stack_num
# the stacked obs in time t
obs = game_obs[beg_index:end_index]
self.tmp_obs = obs # will be masked
else:
value_mask.append(0)
obs = zero_obs
obs = self.tmp_obs # will be masked

value_obs_list.append(obs)
value_obs_list.append(obs.tolist())

reward_value_context = [
value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list,
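To make the role of self.tmp_obs concrete, here is a small self-contained sketch (hypothetical data and simplified indices, not PR code) of the padding behaviour: when the bootstrap index pos + td_steps runs past the end of the game segment, the previous code appended zero_obs, whereas this change reuses the last valid stacked obs and sets value_mask to 0 so the corresponding value target is ignored downstream. Reusing a real observation keeps every batch entry the same structure, which matters for structured (dict) observations such as GoBigger's, where a zero observation is awkward to construct.

frame_stack_num = 4
td_steps = 5
game_segment_length = 10
# hypothetical stacked-observation source for one game segment
game_obs = ['obs_%d' % i for i in range(game_segment_length + frame_stack_num - 1)]

value_obs_list, value_mask = [], []
tmp_obs = None
for pos in range(game_segment_length):
    bootstrap_index = pos + td_steps
    if bootstrap_index < game_segment_length:
        value_mask.append(1)
        # cache the last valid stacked obs so later out-of-range steps can reuse it
        tmp_obs = game_obs[bootstrap_index:bootstrap_index + frame_stack_num]
    else:
        value_mask.append(0)  # the value target at this position is masked out downstream
    value_obs_list.append(tmp_obs)

print(value_mask)  # [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]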
8 changes: 5 additions & 3 deletions lzero/mcts/buffer/game_buffer_muzero.py
@@ -48,6 +48,8 @@ def __init__(self, cfg: dict):
self.game_pos_priorities = []
self.game_segment_game_pos_look_up = []

self.tmp_obs = None  # caches the last valid stacked obs; reused (and masked via value_mask) when the value-obs index (current_index + td_steps) exceeds the game segment length (e.g. 50)

def sample(
self, batch_size: int, policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"]
) -> List[Any]:
@@ -198,7 +200,6 @@ def _prepare_reward_value_context(
- reward_value_context (:obj:`list`): value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens,
td_steps_list, action_mask_segment, to_play_segment
"""
zero_obs = game_segment_list[0].zero_obs()
value_obs_list = []
# the value is valid or not (out of game_segment)
value_mask = []
@@ -238,11 +239,12 @@
end_index = beg_index + self._cfg.model.frame_stack_num
# the stacked obs in time t
obs = game_obs[beg_index:end_index]
self.tmp_obs = obs # will be masked
else:
value_mask.append(0)
obs = zero_obs
obs = self.tmp_obs # will be masked

value_obs_list.append(obs)
value_obs_list.append(obs.tolist())

reward_value_context = [
value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list,