feature(yzj): add multi-agent and structured observation env (GoBigger) #39
base: main
Changes from 56 commits
@@ -1,4 +1,4 @@
 from .game_buffer_muzero import MuZeroGameBuffer
 from .game_buffer_efficientzero import EfficientZeroGameBuffer
 from .game_buffer_sampled_efficientzero import SampledEfficientZeroGameBuffer
-from .game_buffer_gumbel_muzero import GumbelMuZeroGameBuffer
+from .game_buffer_gumbel_muzero import GumbelMuZeroGameBuffer
@@ -9,6 +9,8 @@
 from lzero.mcts.utils import prepare_observation
 from lzero.policy import to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform
 from .game_buffer_muzero import MuZeroGameBuffer
+from ding.torch_utils import to_device, to_tensor, to_ndarray
+from ding.utils.data import default_collate
 
 
 @BUFFER_REGISTRY.register('game_buffer_efficientzero')
@@ -44,6 +46,8 @@ def __init__(self, cfg: dict):
         self.base_idx = 0
         self.clear_time = 0
 
+        self.tmp_obs = None  # for value obs list [46 + 4(td_step)] not < 50(game_segment)
+
Review comment: Please improve this inline comment; make it complete and clear.
Reply: done
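A fuller version of the comment, along the lines the review asks for (wording is only a suggestion, not the PR's final text):

```python
# self.tmp_obs caches the most recent valid stacked observation. When the value-target
# position (current_index + td_steps) runs past the end of the game segment, this cached
# observation is used as a placeholder for the out-of-range obs; value_mask is set to 0
# for that position, so the placeholder never contributes to the learning target.
```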
     def sample(self, batch_size: int, policy: Any) -> List[Any]:
         """
         Overview:
@@ -100,7 +104,6 @@ def _prepare_reward_value_context(
             - reward_value_context (:obj:`list`): value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens,
               td_steps_list, action_mask_segment, to_play_segment
         """
-        zero_obs = game_segment_list[0].zero_obs()
         value_obs_list = []
         # the value is valid or not (out of trajectory)
         value_mask = []
@@ -148,11 +151,12 @@ def _prepare_reward_value_context(
                     end_index = beg_index + self._cfg.model.frame_stack_num
                     # the stacked obs in time t
                     obs = game_obs[beg_index:end_index]
+                    self.tmp_obs = obs  # will be masked
                 else:
                     value_mask.append(0)
-                    obs = zero_obs
+                    obs = self.tmp_obs  # will be masked
 
-                value_obs_list.append(obs)
+                value_obs_list.append(obs.tolist())
 
         reward_value_context = [
             value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list,
@@ -196,7 +200,13 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
             beg_index = self._cfg.mini_infer_size * i
             end_index = self._cfg.mini_infer_size * (i + 1)
 
-            m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device).float()
+            if self._cfg.model.model_type and self._cfg.model.model_type in ['conv', 'mlp']:
+                m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device).float()
+            elif self._cfg.model.model_type and self._cfg.model.model_type == 'structure':
+                m_obs = value_obs_list[beg_index:end_index]
+                m_obs = sum(m_obs, [])
+                m_obs = default_collate(m_obs)
+                m_obs = to_device(m_obs, self._cfg.device)
 
             # calculate the target value
             m_output = model.initial_inference(m_obs)
@@ -9,6 +9,8 @@
 from lzero.mcts.utils import prepare_observation
 from lzero.policy import to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform
 from .game_buffer import GameBuffer
+from ding.torch_utils import to_device, to_tensor
+from ding.utils.data import default_collate
 
 if TYPE_CHECKING:
     from lzero.policy import MuZeroPolicy, EfficientZeroPolicy, SampledEfficientZeroPolicy
@@ -48,6 +50,8 @@ def __init__(self, cfg: dict):
         self.game_pos_priorities = []
         self.game_segment_game_pos_look_up = []
 
+        self.tmp_obs = None  # a tmp value which records obs when value obs list [current_index + 4(td_step)] > 50(game_segment)
+
     def sample(
             self, batch_size: int, policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"]
     ) -> List[Any]:
@@ -198,7 +202,6 @@ def _prepare_reward_value_context(
             - reward_value_context (:obj:`list`): value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens,
               td_steps_list, action_mask_segment, to_play_segment
         """
-        zero_obs = game_segment_list[0].zero_obs()
         value_obs_list = []
         # the value is valid or not (out of game_segment)
         value_mask = []
@@ -238,11 +241,12 @@ def _prepare_reward_value_context(
                     end_index = beg_index + self._cfg.model.frame_stack_num
                     # the stacked obs in time t
                     obs = game_obs[beg_index:end_index]
+                    self.tmp_obs = obs  # will be masked
                 else:
                     value_mask.append(0)
-                    obs = zero_obs
+                    obs = self.tmp_obs  # will be masked
 
-                value_obs_list.append(obs)
+                value_obs_list.append(obs.tolist())
 
         reward_value_context = [
             value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list,
@@ -376,8 +380,14 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
         for i in range(slices):
             beg_index = self._cfg.mini_infer_size * i
             end_index = self._cfg.mini_infer_size * (i + 1)
 
-            m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device).float()
+            if self._cfg.model.model_type and self._cfg.model.model_type in ['conv', 'mlp']:
+                m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device).float()
+            elif self._cfg.model.model_type and self._cfg.model.model_type == 'structure':
+                m_obs = value_obs_list[beg_index:end_index]
+                m_obs = sum(m_obs, [])
+                m_obs = default_collate(m_obs)
+                m_obs = to_device(m_obs, self._cfg.device)
Review comment: Could this be abstracted into a data-processing helper function and placed in utils?
 
             # calculate the target value
             m_output = model.initial_inference(m_obs)
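Following the review suggestion, the model_type branch above (duplicated in both game buffers) could be factored into one shared helper, e.g. somewhere under lzero/mcts. A minimal sketch that mirrors the branch in the diff; the helper name and location are hypothetical, not part of this PR:

```python
import torch
from ding.torch_utils import to_device
from ding.utils.data import default_collate


def prepare_value_obs_batch(value_obs_slice, model_type: str, device):
    """
    Collate one mini-infer slice of value_obs_list into the input format expected by
    ``model.initial_inference``: a float tensor for 'conv'/'mlp' observations, or a
    flattened and collated batch moved to `device` for 'structure' (dict) observations.
    """
    if model_type in ['conv', 'mlp']:
        return torch.from_numpy(value_obs_slice).to(device).float()
    elif model_type == 'structure':
        batch = sum(value_obs_slice, [])  # flatten the per-position lists of dict observations
        batch = default_collate(batch)
        return to_device(batch, device)
    raise ValueError(f'Unsupported model_type: {model_type}')
```

Both `_compute_target_reward_value` implementations could then call `prepare_value_obs_batch(value_obs_list[beg_index:end_index], self._cfg.model.model_type, self._cfg.device)`.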
@@ -4,6 +4,8 @@
 import torch.nn as nn
 from ding.torch_utils import MLP
 from ding.utils import MODEL_REGISTRY, SequenceType
+from ding.utils.default_helper import get_shape0
+
 
 from .common import MZNetworkOutput, RepresentationNetworkMLP, PredictionNetworkMLP
 from .utils import renormalize, get_params_mean, get_dynamic_mean, get_reward_mean
@@ -34,6 +36,7 @@ def __init__(
             discrete_action_encoding_type: str = 'one_hot',
             norm_type: Optional[str] = 'BN',
             res_connection_in_dynamics: bool = False,
+            state_encoder=None,
Review comment: Add a type hint for state_encoder and a corresponding entry in the Arguments docstring.
Review comment: You can refer to the prompts at https://aicarrier.feishu.cn/wiki/N4bqwLRO5iyQcAkb4HCcflbgnpR to polish the docstrings.
Reply: done
             *args,
             **kwargs
     ):
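For the type-hint and docstring request above, one possible shape of the change, shown here on a stripped-down illustrative class rather than the actual model (class name, defaults, and the fallback encoder are placeholders):

```python
from typing import Optional

import torch.nn as nn


class ExampleMLPModel(nn.Module):
    """Illustrative only: a suggested type hint and docstring entry for ``state_encoder``."""

    def __init__(
        self,
        observation_shape: int = 8,
        latent_state_dim: int = 256,
        state_encoder: Optional[nn.Module] = None,
    ) -> None:
        """
        Arguments:
            - state_encoder (:obj:`Optional[nn.Module]`): An optional external representation \
                network, e.g. a GoBigger structured-observation encoder. If ``None``, the default \
                MLP representation network is constructed from ``observation_shape``.
        """
        super().__init__()
        if state_encoder is None:
            # default encoder; stands in for RepresentationNetworkMLP in the real model
            self.representation_network = nn.Sequential(
                nn.Linear(observation_shape, latent_state_dim), nn.ReLU()
            )
        else:
            self.representation_network = state_encoder
```

As a side note, `state_encoder is None` is the more idiomatic check than the `state_encoder == None` used in the hunk below.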
@@ -101,9 +104,12 @@ def __init__(
         self.state_norm = state_norm
         self.res_connection_in_dynamics = res_connection_in_dynamics
 
-        self.representation_network = RepresentationNetworkMLP(
-            observation_shape=observation_shape, hidden_channels=self.latent_state_dim, norm_type=norm_type
-        )
+        if state_encoder == None:
+            self.representation_network = RepresentationNetworkMLP(
+                observation_shape=observation_shape, hidden_channels=latent_state_dim, norm_type=norm_type
+            )
+        else:
+            self.representation_network = state_encoder
 
         self.dynamics_network = DynamicsNetwork(
             action_encoding_dim=self.action_encoding_dim,
@@ -166,7 +172,7 @@ def initial_inference(self, obs: torch.Tensor) -> MZNetworkOutput:
             - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size.
             - latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state.
         """
-        batch_size = obs.size(0)
+        batch_size = get_shape0(obs)
         latent_state = self._representation(obs)
         policy_logits, value = self._prediction(latent_state)
         return MZNetworkOutput(
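The switch from `obs.size(0)` to `get_shape0(obs)` is what lets `initial_inference` accept structured (dict) observations, where `obs` is not a single tensor. Conceptually, the batch size is resolved from the first tensor leaf, roughly as in this illustrative sketch (not ding's actual implementation):

```python
import torch


def batch_dim_of(obs):
    # Illustrative only: return the leading (batch) dimension of a tensor,
    # descending into dict / list containers until a tensor leaf is found.
    if isinstance(obs, torch.Tensor):
        return obs.shape[0]
    if isinstance(obs, dict):
        return batch_dim_of(next(iter(obs.values())))
    if isinstance(obs, (list, tuple)):
        return batch_dim_of(obs[0])
    raise TypeError(f'unsupported observation container: {type(obs)}')
```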
@@ -17,6 +17,9 @@
 from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \
     DiscreteSupport, select_action, to_torch_float_tensor, ez_network_output_unpack, negative_cosine_similarity, prepare_obs, \
     configure_optimizers
+from collections import defaultdict
+from ding.torch_utils import to_device, to_tensor
+from ding.utils.data import default_collate
 
 
 @POLICY_REGISTRY.register('efficientzero')
@@ -186,6 +189,9 @@ class EfficientZeroPolicy(Policy):
             # (int) The decay steps from start to end eps.
             decay=int(1e5),
         ),
+
+        # (bool) Whether it is a multi-agent environment.
+        multi_agent=False,
     )
 
     def default_model(self) -> Tuple[str, List[str]]:
@@ -302,7 +308,7 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]:
 
         target_value_prefix = target_value_prefix.view(self._cfg.batch_size, -1)
         target_value = target_value.view(self._cfg.batch_size, -1)
-        assert obs_batch.size(0) == self._cfg.batch_size == target_value_prefix.size(0)
+        assert self._cfg.batch_size == target_value_prefix.size(0)
 
         # ``scalar_transform`` to transform the original value to the scaled value,
         # i.e. h(.) function in paper https://arxiv.org/pdf/1805.11593.pdf.
@@ -397,6 +403,31 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]:
                     beg_index = self._cfg.model.observation_shape * step_i
                     end_index = self._cfg.model.observation_shape * (step_i + self._cfg.model.frame_stack_num)
                     network_output = self._learn_model.initial_inference(obs_target_batch[:, beg_index:end_index])
+                elif self._cfg.model.model_type == 'structure':
+                    obs_target_batch_new = {}
+                    for k, v in obs_target_batch.items():
+                        if k == 'action_mask':
+                            obs_target_batch_new[k] = v
+                            continue
+                        if isinstance(v, dict):
+                            obs_target_batch_new[k] = {}
+                            for k1, v1 in v.items():
+                                if len(v1.shape) == 1:
+                                    observation_shape = v1.shape[0]//self._cfg.num_unroll_steps
+                                    beg_index = observation_shape * step_i
+                                    end_index = observation_shape * (step_i + self._cfg.model.frame_stack_num)
+                                    obs_target_batch_new[k][k1] = v1[beg_index:end_index]
+                                else:
+                                    observation_shape = v1.shape[1]//self._cfg.num_unroll_steps
+                                    beg_index = observation_shape * step_i
+                                    end_index = observation_shape * (step_i + self._cfg.model.frame_stack_num)
+                                    obs_target_batch_new[k][k1] = v1[:, beg_index:end_index]
+                        else:
+                            observation_shape = v.shape[1]//self._cfg.num_unroll_steps
+                            beg_index = observation_shape * step_i
+                            end_index = observation_shape * (step_i + self._cfg.model.frame_stack_num)
+                            obs_target_batch_new[k] = v[:, beg_index:end_index]
+                    network_output = self._learn_model.initial_inference(obs_target_batch_new)
Review comment: The handling of structured observations above could perhaps be abstracted into a function.
 
                latent_state = to_tensor(latent_state)
                representation_state = to_tensor(network_output.latent_state)
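As the review comment suggests, the nested slicing above could be pulled into a helper that walks the (possibly nested) observation dict and slices every leaf along its time-flattened axis. A minimal sketch under the same assumptions as the diff; the helper name is hypothetical, and leaves are assumed to be tensors whose flattened axis stacks num_unroll_steps observations:

```python
def slice_structured_obs(obs_target_batch: dict, step_i: int, num_unroll_steps: int, frame_stack_num: int) -> dict:
    """
    Slice every leaf of a structured observation batch down to the window
    [step_i, step_i + frame_stack_num), mirroring the inline 'structure' branch above.
    'action_mask' entries are passed through unchanged.
    """

    def _slice(leaf):
        # 1-D leaves are flattened over unroll steps along dim 0; higher-dimensional
        # leaves keep the batch dim first and are flattened along dim 1.
        dim = 0 if len(leaf.shape) == 1 else 1
        obs_shape = leaf.shape[dim] // num_unroll_steps
        beg, end = obs_shape * step_i, obs_shape * (step_i + frame_stack_num)
        return leaf[beg:end] if dim == 0 else leaf[:, beg:end]

    sliced = {}
    for k, v in obs_target_batch.items():
        if k == 'action_mask':
            sliced[k] = v
        elif isinstance(v, dict):
            sliced[k] = {k1: _slice(v1) for k1, v1 in v.items()}
        else:
            sliced[k] = _slice(v)
    return sliced
```

The call site would then reduce to `network_output = self._learn_model.initial_inference(slice_structured_obs(obs_target_batch, step_i, self._cfg.num_unroll_steps, self._cfg.model.frame_stack_num))`.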
@@ -735,6 +766,7 @@ def _monitor_vars_learn(self) -> List[str]:
         """
         return [
             'collect_mcts_temperature',
+            'collect_epsilon',
             'cur_lr',
             'weighted_total_loss',
             'total_loss',
Review comment: Please merge the main branch and add the related MuZero/EfficientZero baseline results to the PR description. Then, once the code is polished, create a new multi-agent branch and push it to opendilab/lightzero, and note in this PR that the latest stable code lives on the multi-agent branch.