-
Notifications
You must be signed in to change notification settings - Fork 118
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feature(yzj): add multi-agent and structured observation env (GoBigger) #39
base: main
Are you sure you want to change the base?
Changes from all commits
ec0ba9d
2c29842
335b0fc
0875e74
d88d79c
17992eb
4925d01
b8e044e
f229b6a
4bbbeb0
7529170
85aeacf
f146c4d
47b145e
2772ffd
e36e752
7098899
b94deae
dfa4671
ff11821
a95c19c
6da2997
a2ca5ee
249d88a
8c4c5a0
377f664
1ed22b2
35e7714
4df3ada
272611f
cc54996
39802f5
58281d6
e1ba071
c29abaf
e4667df
b6dca69
09a4440
407329a
592fab1
3392d61
9337ce3
deab811
705b5f9
43b2bb5
3d88a17
a09517a
714ba4b
0ee0122
71ce58e
05c025d
5bec18b
5d310ba
920dc38
1c1fde9
72c669b
11ef08f
1e143bc
3e1e62f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,8 @@ | |
import torch.nn as nn | ||
from ding.torch_utils import MLP | ||
from ding.utils import MODEL_REGISTRY, SequenceType | ||
from ding.utils.default_helper import get_shape0 | ||
|
||
|
||
from .common import MZNetworkOutput, RepresentationNetworkMLP, PredictionNetworkMLP | ||
from .utils import renormalize, get_params_mean, get_dynamic_mean, get_reward_mean | ||
|
@@ -34,6 +36,7 @@ def __init__( | |
discrete_action_encoding_type: str = 'one_hot', | ||
norm_type: Optional[str] = 'BN', | ||
res_connection_in_dynamics: bool = False, | ||
state_encoder=None, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 增加state_encoder的Type Hints以及相应的arguments注释 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. https://aicarrier.feishu.cn/wiki/N4bqwLRO5iyQcAkb4HCcflbgnpR 可以参考这里的提示词优化注释哈 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
*args, | ||
**kwargs | ||
): | ||
|
@@ -66,6 +69,7 @@ def __init__( | |
- discrete_action_encoding_type (:obj:`str`): The encoding type of discrete action, which can be 'one_hot' or 'not_one_hot'. | ||
- norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. | ||
- res_connection_in_dynamics (:obj:`bool`): Whether to use residual connection for dynamics network, default set it to False. | ||
- state_encoder (:obj:`Optional[nn.Module]`): The state encoder network, which is used to encode the raw observation to latent state. | ||
""" | ||
super(MuZeroModelMLP, self).__init__() | ||
self.categorical_distribution = categorical_distribution | ||
|
@@ -101,9 +105,12 @@ def __init__( | |
self.state_norm = state_norm | ||
self.res_connection_in_dynamics = res_connection_in_dynamics | ||
|
||
self.representation_network = RepresentationNetworkMLP( | ||
observation_shape=observation_shape, hidden_channels=self.latent_state_dim, norm_type=norm_type | ||
) | ||
if state_encoder == None: | ||
self.representation_network = RepresentationNetworkMLP( | ||
observation_shape=observation_shape, hidden_channels=latent_state_dim, norm_type=norm_type | ||
) | ||
else: | ||
self.representation_network = state_encoder | ||
|
||
self.dynamics_network = DynamicsNetwork( | ||
action_encoding_dim=self.action_encoding_dim, | ||
|
@@ -166,7 +173,7 @@ def initial_inference(self, obs: torch.Tensor) -> MZNetworkOutput: | |
- policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. | ||
- latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state. | ||
""" | ||
batch_size = obs.size(0) | ||
batch_size = get_shape0(obs) | ||
latent_state = self._representation(obs) | ||
policy_logits, value = self._prediction(latent_state) | ||
return MZNetworkOutput( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,7 +17,13 @@ | |
DiscreteSupport, select_action, to_torch_float_tensor, ez_network_output_unpack, negative_cosine_similarity, \ | ||
prepare_obs, \ | ||
configure_optimizers | ||
<<<<<<< HEAD | ||
from collections import defaultdict | ||
from ding.torch_utils import to_device, to_tensor | ||
from ding.utils.data import default_collate | ||
======= | ||
from lzero.policy.muzero import MuZeroPolicy | ||
>>>>>>> origin | ||
|
||
|
||
@POLICY_REGISTRY.register('efficientzero') | ||
|
@@ -191,6 +197,9 @@ class EfficientZeroPolicy(MuZeroPolicy): | |
# (int) The decay steps from start to end eps. | ||
decay=int(1e5), | ||
), | ||
|
||
# (bool) Whether it is a multi-agent environment. | ||
multi_agent=False, | ||
) | ||
|
||
def default_model(self) -> Tuple[str, List[str]]: | ||
|
@@ -309,7 +318,7 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: | |
|
||
target_value_prefix = target_value_prefix.view(self._cfg.batch_size, -1) | ||
target_value = target_value.view(self._cfg.batch_size, -1) | ||
assert obs_batch.size(0) == self._cfg.batch_size == target_value_prefix.size(0) | ||
assert self._cfg.batch_size == target_value_prefix.size(0) | ||
|
||
# ``scalar_transform`` to transform the original value to the scaled value, | ||
# i.e. h(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. | ||
|
@@ -395,9 +404,40 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: | |
# calculate consistency loss for the next ``num_unroll_steps`` unroll steps. | ||
# ============================================================== | ||
if self._cfg.ssl_loss_weight > 0: | ||
# obtain the oracle latent states from representation function. | ||
beg_index, end_index = self._get_target_obs_index_in_step_k(step_k) | ||
network_output = self._learn_model.initial_inference(obs_target_batch[:, beg_index:end_index]) | ||
# obtain the oracle hidden states from representation function. | ||
if self._cfg.model.model_type == 'conv': | ||
beg_index = self._cfg.model.image_channel * step_i | ||
end_index = self._cfg.model.image_channel * (step_i + self._cfg.model.frame_stack_num) | ||
network_output = self._learn_model.initial_inference(obs_target_batch[:, beg_index:end_index, :, :]) | ||
elif self._cfg.model.model_type == 'mlp': | ||
beg_index = self._cfg.model.observation_shape * step_i | ||
end_index = self._cfg.model.observation_shape * (step_i + self._cfg.model.frame_stack_num) | ||
network_output = self._learn_model.initial_inference(obs_target_batch[:, beg_index:end_index]) | ||
elif self._cfg.model.model_type == 'structure': | ||
obs_target_batch_new = {} | ||
for k, v in obs_target_batch.items(): | ||
if k == 'action_mask': | ||
obs_target_batch_new[k] = v | ||
continue | ||
if isinstance(v, dict): | ||
obs_target_batch_new[k] = {} | ||
for k1, v1 in v.items(): | ||
if len(v1.shape) == 1: | ||
observation_shape = v1.shape[0]//self._cfg.num_unroll_steps | ||
beg_index = observation_shape * step_i | ||
end_index = observation_shape * (step_i + self._cfg.model.frame_stack_num) | ||
obs_target_batch_new[k][k1] = v1[beg_index:end_index] | ||
else: | ||
observation_shape = v1.shape[1]//self._cfg.num_unroll_steps | ||
beg_index = observation_shape * step_i | ||
end_index = observation_shape * (step_i + self._cfg.model.frame_stack_num) | ||
obs_target_batch_new[k][k1] = v1[:, beg_index:end_index] | ||
else: | ||
observation_shape = v.shape[1]//self._cfg.num_unroll_steps | ||
beg_index = observation_shape * step_i | ||
end_index = observation_shape * (step_i + self._cfg.model.frame_stack_num) | ||
obs_target_batch_new[k] = v[:, beg_index:end_index] | ||
network_output = self._learn_model.initial_inference(obs_target_batch_new) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 上面对结构化观察的处理或许可以抽象为一个函数 |
||
|
||
latent_state = to_tensor(latent_state) | ||
representation_state = to_tensor(network_output.latent_state) | ||
|
@@ -735,6 +775,7 @@ def _monitor_vars_learn(self) -> List[str]: | |
""" | ||
return [ | ||
'collect_mcts_temperature', | ||
'collect_epsilon', | ||
'cur_lr', | ||
'weighted_total_loss', | ||
'total_loss', | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
抽象为一个数据处理函数,放在utils中?