feature(yzj): add multi-agent and structured observation env (GoBigger) #39

Open

wants to merge 59 commits into base: main

Commits (59)
ec0ba9d
feature(yzj): adapt multi agent env gobigger with ez
May 31, 2023
2c29842
fix(yzj): fix data device bug in gobigger ez pipeline
Jun 1, 2023
335b0fc
feature(yzj): add vsbot with ez pipeline and add eat-info in tensorboard
Jun 1, 2023
0875e74
feature(yzj): add vsbot with mz pipeline and polish model and buffer
Jun 1, 2023
d88d79c
polish(yzj): polish gobigger env
Jun 2, 2023
17992eb
feature(yzj): adapt multi agent env gobigger with sez
Jun 2, 2023
4925d01
feature(yzj): add gobigger visualization and polish gobigger eval config
Jun 7, 2023
b8e044e
fix(yzj): fix eval_episode_return and polish env
Jun 7, 2023
f229b6a
polish(yzj): polish gobigger env pytest
Jun 7, 2023
4bbbeb0
polish(yzj): polish gobigger env and eat info in evaluator
Jun 7, 2023
7529170
fix(yzj): fix np.pad bug, which need padding_num>0
Jun 12, 2023
85aeacf
polish(yzj): contain raw obs only on eval mode for save memory
Jun 13, 2023
f146c4d
fix(yzj): fix mcts ptree sampled value/value-prefix bug
Jun 13, 2023
47b145e
polish(yzj): polish gobigger encoder model
Jun 15, 2023
2772ffd
polish(yzj): polish gobigger encoder model with ding
Jun 16, 2023
e36e752
polish(yzj): polish gobigger entry evaluator
Jun 19, 2023
7098899
feature(yzj): add eps_greedy and random_collect_episode in gobigger ez
Jun 19, 2023
b94deae
fix(yzj): fix key bug in entry utils when random collect
Jun 20, 2023
dfa4671
fix(yzj): fix gobigger encoder bn bug
Jun 25, 2023
ff11821
polish(yzj): polish ez config and set eps as 1.5e4 learner iter
Jun 25, 2023
a95c19c
polish(yzj): polish code style by format.sh
Jun 25, 2023
6da2997
polish(yzj): polish code comments about gobigger in worker/policy/entry
Jun 25, 2023
a2ca5ee
feature(yzj): add eps_greedy and random_collect_episode in gobigger mz
Jun 25, 2023
249d88a
Merge branch 'main' of https://github.com/opendilab/LightZero into de…
Jun 25, 2023
8c4c5a0
polish(yzj): polish entry/buffer/policy/config/model/env comments and…
Jun 28, 2023
377f664
polish(yzj): use ding scatter_model, muzero_collector add multi_agent…
Jun 30, 2023
1ed22b2
fix(yzj): fix collector bug that observation_window_stack no for_loop…
Jul 3, 2023
35e7714
fix(yzj): fix ignore done in collector
Jul 5, 2023
4df3ada
polish(yzj): polish ez config ignore done
Jul 5, 2023
272611f
fix(yzj): add game_segment_pool clear()
Jul 5, 2023
cc54996
polish(yzj): add gobigger/entry , polish gobigger config and add defa…
Jul 12, 2023
39802f5
polish(yzj): polish eps greedy and random policy
Jul 17, 2023
58281d6
fix(yzj): fix random collect in gobigger ez policy
Jul 17, 2023
e1ba071
polish(yzj): merge main-branch eps and random collect, polish gobigge…
Aug 2, 2023
c29abaf
feature(yzj): add peetingzoo mz/ez algo, add multi agent buffer/polic…
Aug 4, 2023
e4667df
polish(yzj): polish multi agent muzero collector
Aug 4, 2023
b6dca69
polish(yzj): polish gobigger collector and config to support t2p3
Aug 8, 2023
09a4440
feature(yzj): add fc encoder on ptz env instead of identity
Aug 8, 2023
407329a
polish(yzj): polish buffer name and remove ignore done in atari config
Aug 10, 2023
592fab1
fix(yzj): fix ssl data bug and polish to_device code
Aug 14, 2023
3392d61
fix(yzj): fix policy utils obs batch
Aug 14, 2023
9337ce3
fix(yzj): fix collect mode and eval mode to device
Aug 14, 2023
deab811
fix(yzj): fix to device bug on policy utils
Aug 15, 2023
705b5f9
polish(yzj): polish multi agent game buffer code
Aug 15, 2023
43b2bb5
polish(yzj): polish code
Aug 15, 2023
3d88a17
fix(yzj): fix priority bug, polish priority related config, add all a…
Aug 15, 2023
a09517a
polish(yzj): polish train entry
Aug 15, 2023
714ba4b
polish(yzj): polish gobigger config
Aug 16, 2023
0ee0122
polish(yzj): polish best gobigger config on ez/mz
Aug 18, 2023
71ce58e
polish(yzj): polish collector to adapt multi-agent mode
Aug 18, 2023
05c025d
polish(yzj): polish evaluator conflicts
Aug 18, 2023
5bec18b
polish(yzj): polish multi agent model
Aug 18, 2023
5d310ba
polish(yzj): sync main
Aug 21, 2023
920dc38
polish(yzj): polish gobigger entry and evaluator
Aug 21, 2023
1c1fde9
feature(yzj): add pettingzoo visualization
Aug 29, 2023
72c669b
polish(yzj): polish ptz config and model
Aug 29, 2023
11ef08f
feature(yzj): add ptz simple ez config
Sep 4, 2023
1e143bc
polish(yzj): polish code base
jayyoung0802 Dec 7, 2023
3e1e62f
Merge remote-tracking branch 'origin' into dev-gobigger
jayyoung0802 Dec 7, 2023
12 changes: 8 additions & 4 deletions lzero/entry/train_muzero.py
@@ -50,9 +50,9 @@ def train_muzero(
assert create_cfg.policy.type in ['efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero', 'stochastic_muzero'], \
"train_muzero entry now only support the following algo.: 'efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero'"

if create_cfg.policy.type == 'muzero':
if create_cfg.policy.type == 'muzero' or create_cfg.policy.type == 'multi_agent_muzero':
from lzero.mcts import MuZeroGameBuffer as GameBuffer
elif create_cfg.policy.type == 'efficientzero':
elif create_cfg.policy.type == 'efficientzero' or create_cfg.policy.type == 'multi_agent_efficientzero':
from lzero.mcts import EfficientZeroGameBuffer as GameBuffer
elif create_cfg.policy.type == 'sampled_efficientzero':
from lzero.mcts import SampledEfficientZeroGameBuffer as GameBuffer
@@ -125,7 +125,11 @@ def train_muzero(
# Exploration: Collecting random data helps the agent explore the environment and avoid getting stuck in a suboptimal policy prematurely.
# Comparison: By observing the agent's performance during random action-taking, we can establish a baseline to evaluate the effectiveness of reinforcement learning algorithms.
if cfg.policy.random_collect_episode_num > 0:
random_collect(cfg.policy, policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer)
if policy_config.multi_agent:
from lzero.policy.multi_agent_random_policy import MultiAgentLightZeroRandomPolicy as RandomPolicy
else:
from lzero.policy.random_policy import LightZeroRandomPolicy as RandomPolicy
random_collect(cfg.policy, policy, RandomPolicy, collector, collector_env, replay_buffer)

while True:
log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger)
@@ -192,4 +196,4 @@ def train_muzero(

# Learner's after_run hook.
learner.call_hook('after_run')
return policy
return policy
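The two if/elif ladders above (GameBuffer selection by create_cfg.policy.type and random-policy selection by policy_config.multi_agent) could also be expressed as a lookup table. A minimal sketch under that assumption, using only the buffer classes imported in this diff; the helper name and the mapping itself are illustrative, not code from this PR:

```python
# Illustrative sketch only: table-driven game-buffer selection for train_muzero.
from lzero.mcts import EfficientZeroGameBuffer, MuZeroGameBuffer, SampledEfficientZeroGameBuffer

# Both multi-agent policy types reuse the single-agent buffer classes,
# mirroring the if/elif chain in the diff above.
GAME_BUFFER_MAP = {
    'muzero': MuZeroGameBuffer,
    'multi_agent_muzero': MuZeroGameBuffer,
    'efficientzero': EfficientZeroGameBuffer,
    'multi_agent_efficientzero': EfficientZeroGameBuffer,
    'sampled_efficientzero': SampledEfficientZeroGameBuffer,
}


def select_game_buffer(policy_type: str):
    """Return the GameBuffer class registered for ``policy_type`` (sketch)."""
    if policy_type not in GAME_BUFFER_MAP:
        raise KeyError(f'train_muzero does not support policy type: {policy_type}')
    return GAME_BUFFER_MAP[policy_type]
```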
18 changes: 14 additions & 4 deletions lzero/mcts/buffer/game_buffer_efficientzero.py
@@ -9,6 +9,8 @@
from lzero.mcts.utils import prepare_observation
from lzero.policy import to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform
from .game_buffer_muzero import MuZeroGameBuffer
from ding.torch_utils import to_device, to_tensor, to_ndarray
from ding.utils.data import default_collate


@BUFFER_REGISTRY.register('game_buffer_efficientzero')
@@ -44,6 +46,8 @@ def __init__(self, cfg: dict):
self.base_idx = 0
self.clear_time = 0

self.tmp_obs = None  # placeholder: near the segment end, the value obs index (e.g. 46 + 4 td_steps) can reach game_segment_length (50), so padding with the last valid obs is needed

def sample(self, batch_size: int, policy: Any) -> List[Any]:
"""
Overview:
@@ -100,7 +104,6 @@ def _prepare_reward_value_context(
- reward_value_context (:obj:`list`): value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens,
td_steps_list, action_mask_segment, to_play_segment
"""
zero_obs = game_segment_list[0].zero_obs()
value_obs_list = []
# the value is valid or not (out of trajectory)
value_mask = []
@@ -148,11 +151,12 @@
end_index = beg_index + self._cfg.model.frame_stack_num
# the stacked obs in time t
obs = game_obs[beg_index:end_index]
self.tmp_obs = obs # will be masked
else:
value_mask.append(0)
obs = zero_obs
obs = self.tmp_obs # will be masked

value_obs_list.append(obs)
value_obs_list.append(obs.tolist())

reward_value_context = [
value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list,
@@ -196,7 +200,13 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
beg_index = self._cfg.mini_infer_size * i
end_index = self._cfg.mini_infer_size * (i + 1)

m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device).float()
if self._cfg.model.model_type and self._cfg.model.model_type in ['conv', 'mlp']:
m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device).float()
elif self._cfg.model.model_type and self._cfg.model.model_type == 'structure':
m_obs = value_obs_list[beg_index:end_index]
m_obs = sum(m_obs, [])
m_obs = default_collate(m_obs)
m_obs = to_device(m_obs, self._cfg.device)

# calculate the target value
m_output = model.initial_inference(m_obs)
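For intuition on the `tmp_obs` placeholder added above, here is a tiny self-contained illustration of the boundary case the inline comment mentions; the segment length and td_steps values are assumptions taken from that comment, not from any config:

```python
# Illustrative arithmetic only (not code from this PR): with a stored game
# segment of length 50 and td_steps = 4, positions >= 46 bootstrap past the
# end of the segment. Those entries get value_mask = 0, and instead of a
# zero observation the last valid stacked obs (self.tmp_obs) is reused as a
# masked placeholder, which keeps structured (dict) observations collatable.
game_segment_length = 50
td_steps = 4

for pos in range(44, 50):
    bootstrap_index = pos + td_steps
    in_segment = bootstrap_index < game_segment_length
    print(pos, bootstrap_index, 'value_mask=1' if in_segment else 'value_mask=0 -> tmp_obs')
```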
20 changes: 15 additions & 5 deletions lzero/mcts/buffer/game_buffer_muzero.py
@@ -9,6 +9,8 @@
from lzero.mcts.utils import prepare_observation
from lzero.policy import to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform
from .game_buffer import GameBuffer
from ding.torch_utils import to_device, to_tensor
from ding.utils.data import default_collate

if TYPE_CHECKING:
from lzero.policy import MuZeroPolicy, EfficientZeroPolicy, SampledEfficientZeroPolicy
@@ -48,6 +50,8 @@ def __init__(self, cfg: dict):
self.game_pos_priorities = []
self.game_segment_game_pos_look_up = []

self.tmp_obs = None  # temporary value recording the last valid obs, used when the value obs index (current_index + td_steps) exceeds the game segment length (e.g. 50)

def sample(
self, batch_size: int, policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"]
) -> List[Any]:
@@ -198,7 +202,6 @@ def _prepare_reward_value_context(
- reward_value_context (:obj:`list`): value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens,
td_steps_list, action_mask_segment, to_play_segment
"""
zero_obs = game_segment_list[0].zero_obs()
value_obs_list = []
# the value is valid or not (out of game_segment)
value_mask = []
@@ -238,11 +241,12 @@
end_index = beg_index + self._cfg.model.frame_stack_num
# the stacked obs in time t
obs = game_obs[beg_index:end_index]
self.tmp_obs = obs # will be masked
else:
value_mask.append(0)
obs = zero_obs
obs = self.tmp_obs # will be masked

value_obs_list.append(obs)
value_obs_list.append(obs.tolist())

reward_value_context = [
value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list,
@@ -376,8 +380,14 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
for i in range(slices):
beg_index = self._cfg.mini_infer_size * i
end_index = self._cfg.mini_infer_size * (i + 1)

m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device).float()

if self._cfg.model.model_type and self._cfg.model.model_type in ['conv', 'mlp']:
m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device).float()
elif self._cfg.model.model_type and self._cfg.model.model_type == 'structure':
m_obs = value_obs_list[beg_index:end_index]
m_obs = sum(m_obs, [])
m_obs = default_collate(m_obs)
m_obs = to_device(m_obs, self._cfg.device)
Collaborator comment: Could this be abstracted into a data-processing helper function and placed in utils?
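A minimal sketch of the helper suggested here, assuming it would live in lzero/mcts/utils.py; the function name is hypothetical and the body mirrors the 'structure' branch in the diff above:

```python
from typing import Any, List

from ding.torch_utils import to_device
from ding.utils.data import default_collate


def collate_structured_value_obs(value_obs_list: List[Any], beg_index: int,
                                 end_index: int, device: str) -> Any:
    """Collate a slice of structured (dict-like) value observations into one batch."""
    m_obs = value_obs_list[beg_index:end_index]
    m_obs = sum(m_obs, [])          # flatten the per-step lists of dicts
    m_obs = default_collate(m_obs)  # list of dicts -> dict of batched tensors
    return to_device(m_obs, device)
```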


# calculate the target value
m_output = model.initial_inference(m_obs)
5 changes: 4 additions & 1 deletion lzero/mcts/utils.py
@@ -92,7 +92,7 @@ def prepare_observation(observation_list, model_type='conv'):
- observation_list (:obj:`List`): list of observations.
- model_type (:obj:`str`): type of the model. (default is 'conv')
"""
assert model_type in ['conv', 'mlp']
assert model_type in ['conv', 'mlp', 'structure']
observation_array = np.array(observation_list)

if model_type == 'conv':
@@ -127,6 +127,9 @@
observation_array = observation_array.reshape(observation_array.shape[0], -1)
# print(observation_array.shape)

elif model_type == 'structure':
PaParaZz1 marked this conversation as resolved.
return observation_list

return observation_array
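A short usage sketch of the new 'structure' branch: structured observations pass through `prepare_observation` unchanged, and batching is deferred to `default_collate` in the game buffers (see the buffer diffs above). The observation layout below is an illustrative assumption, not the GoBigger schema:

```python
import numpy as np

from lzero.mcts.utils import prepare_observation

# Hypothetical structured observations: one list of dicts per sample.
struct_obs = [[{'scalar': np.zeros(5), 'units': np.zeros((3, 7))}] for _ in range(8)]

out = prepare_observation(struct_obs, model_type='structure')
assert out is struct_obs  # returned as-is; no stacking or flattening is applied
```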


19 changes: 13 additions & 6 deletions lzero/model/efficientzero_model_mlp.py
@@ -4,6 +4,8 @@
import torch.nn as nn
from ding.torch_utils import MLP
from ding.utils import MODEL_REGISTRY, SequenceType
from ding.utils.default_helper import get_shape0

from numpy import ndarray

from .common import EZNetworkOutput, RepresentationNetworkMLP, PredictionNetworkMLP
@@ -36,6 +38,7 @@ def __init__(
norm_type: Optional[str] = 'BN',
discrete_action_encoding_type: str = 'one_hot',
res_connection_in_dynamics: bool = False,
state_encoder=None,
*args,
**kwargs,
):
@@ -104,9 +107,12 @@ def __init__(
self.state_norm = state_norm
self.res_connection_in_dynamics = res_connection_in_dynamics

self.representation_network = RepresentationNetworkMLP(
observation_shape=observation_shape, hidden_channels=latent_state_dim, norm_type=norm_type
)
if state_encoder is None:
self.representation_network = RepresentationNetworkMLP(
observation_shape=observation_shape, hidden_channels=latent_state_dim, norm_type=norm_type
)
else:
self.representation_network = state_encoder

self.dynamics_network = DynamicsNetworkMLP(
action_encoding_dim=self.action_encoding_dim,
@@ -171,15 +177,16 @@ def initial_inference(self, obs: torch.Tensor) -> EZNetworkOutput:
- latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state.
- reward_hidden_state (:obj:`Tuple[torch.Tensor]`): The shape of each element is :math:`(1, B, lstm_hidden_size)`, where B is batch_size.
"""
batch_size = obs.size(0)
batch_size = get_shape0(obs)
latent_state = self._representation(obs)
device = latent_state.device
policy_logits, value = self._prediction(latent_state)
# zero initialization for reward hidden states
# (hn, cn), each element shape is (layer_num=1, batch_size, lstm_hidden_size)
reward_hidden_state = (
torch.zeros(1, batch_size,
self.lstm_hidden_size).to(obs.device), torch.zeros(1, batch_size,
self.lstm_hidden_size).to(obs.device)
self.lstm_hidden_size).to(device), torch.zeros(1, batch_size,
self.lstm_hidden_size).to(device)
)
return EZNetworkOutput(value, [0. for _ in range(batch_size)], policy_logits, latent_state, reward_hidden_state)
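`obs.size(0)` only exists when the observation is a single tensor; with structured observations the model receives a dict of tensors, which is why the diff switches to ding's `get_shape0`. A hedged sketch of the underlying idea (the `_leading_batch_size` helper below is hypothetical, not ding's implementation):

```python
from typing import Any, Dict, Union

import torch


def _leading_batch_size(obs: Union[torch.Tensor, Dict[str, Any]]) -> int:
    """Return the leading (batch) dimension of a tensor or a (nested) dict of tensors."""
    if isinstance(obs, torch.Tensor):
        return obs.shape[0]
    # structured observation: recurse into any leaf until a tensor is found
    return _leading_batch_size(next(iter(obs.values())))
```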

15 changes: 11 additions & 4 deletions lzero/model/muzero_model_mlp.py
@@ -4,6 +4,8 @@
import torch.nn as nn
from ding.torch_utils import MLP
from ding.utils import MODEL_REGISTRY, SequenceType
from ding.utils.default_helper import get_shape0


from .common import MZNetworkOutput, RepresentationNetworkMLP, PredictionNetworkMLP
from .utils import renormalize, get_params_mean, get_dynamic_mean, get_reward_mean
@@ -34,6 +36,7 @@ def __init__(
discrete_action_encoding_type: str = 'one_hot',
norm_type: Optional[str] = 'BN',
res_connection_in_dynamics: bool = False,
state_encoder=None,
Collaborator comment: Add a type hint for state_encoder and a corresponding entry in the arguments documentation.
Collaborator comment: See https://aicarrier.feishu.cn/wiki/N4bqwLRO5iyQcAkb4HCcflbgnpR for prompt suggestions on polishing the comments.
Collaborator (author) reply: done
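A sketch of what the requested type hint and argument documentation could look like; the class below is a stand-in and the default encoder is simplified, so treat everything except the `state_encoder` hint and docstring wording as illustrative:

```python
from typing import Optional

import torch.nn as nn


class MuZeroModelMLPSketch(nn.Module):
    """Stand-in class showing only the hinted ``state_encoder`` argument."""

    def __init__(self, observation_shape: int = 8, latent_state_dim: int = 256,
                 state_encoder: Optional[nn.Module] = None, *args, **kwargs) -> None:
        """
        Arguments:
            - state_encoder (:obj:`Optional[nn.Module]`): Optional custom encoder that maps the raw \
                (possibly structured) observation to the latent state. If ``None``, the default \
                ``RepresentationNetworkMLP`` is constructed from ``observation_shape`` and ``latent_state_dim``.
        """
        super().__init__()
        if state_encoder is None:
            # simplified stand-in for RepresentationNetworkMLP in this sketch
            self.representation_network = nn.Sequential(
                nn.Linear(observation_shape, latent_state_dim), nn.ReLU()
            )
        else:
            self.representation_network = state_encoder
```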

*args,
**kwargs
):
@@ -66,6 +69,7 @@ def __init__(
- discrete_action_encoding_type (:obj:`str`): The encoding type of discrete action, which can be 'one_hot' or 'not_one_hot'.
- norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'.
- res_connection_in_dynamics (:obj:`bool`): Whether to use residual connection for dynamics network, default set it to False.
- state_encoder (:obj:`Optional[nn.Module]`): The state encoder network, which is used to encode the raw observation to latent state.
"""
super(MuZeroModelMLP, self).__init__()
self.categorical_distribution = categorical_distribution
@@ -101,9 +105,12 @@ def __init__(
self.state_norm = state_norm
self.res_connection_in_dynamics = res_connection_in_dynamics

self.representation_network = RepresentationNetworkMLP(
observation_shape=observation_shape, hidden_channels=self.latent_state_dim, norm_type=norm_type
)
if state_encoder is None:
self.representation_network = RepresentationNetworkMLP(
observation_shape=observation_shape, hidden_channels=latent_state_dim, norm_type=norm_type
)
else:
self.representation_network = state_encoder

self.dynamics_network = DynamicsNetwork(
action_encoding_dim=self.action_encoding_dim,
@@ -166,7 +173,7 @@ def initial_inference(self, obs: torch.Tensor) -> MZNetworkOutput:
- policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size.
- latent_state (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch_size, H is the dimension of latent state.
"""
batch_size = obs.size(0)
batch_size = get_shape0(obs)
latent_state = self._representation(obs)
policy_logits, value = self._prediction(latent_state)
return MZNetworkOutput(
49 changes: 45 additions & 4 deletions lzero/policy/efficientzero.py
@@ -17,7 +17,13 @@
DiscreteSupport, select_action, to_torch_float_tensor, ez_network_output_unpack, negative_cosine_similarity, \
prepare_obs, \
configure_optimizers
from collections import defaultdict
from ding.torch_utils import to_device, to_tensor
from ding.utils.data import default_collate
from lzero.policy.muzero import MuZeroPolicy


@POLICY_REGISTRY.register('efficientzero')
@@ -191,6 +197,9 @@ class EfficientZeroPolicy(MuZeroPolicy):
# (int) The decay steps from start to end eps.
decay=int(1e5),
),

# (bool) Whether it is a multi-agent environment.
multi_agent=False,
)

def default_model(self) -> Tuple[str, List[str]]:
Expand Down Expand Up @@ -309,7 +318,7 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]:

target_value_prefix = target_value_prefix.view(self._cfg.batch_size, -1)
target_value = target_value.view(self._cfg.batch_size, -1)
assert obs_batch.size(0) == self._cfg.batch_size == target_value_prefix.size(0)
assert self._cfg.batch_size == target_value_prefix.size(0)

# ``scalar_transform`` to transform the original value to the scaled value,
# i.e. h(.) function in paper https://arxiv.org/pdf/1805.11593.pdf.
@@ -395,9 +404,40 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]:
# calculate consistency loss for the next ``num_unroll_steps`` unroll steps.
# ==============================================================
if self._cfg.ssl_loss_weight > 0:
# obtain the oracle latent states from representation function.
beg_index, end_index = self._get_target_obs_index_in_step_k(step_k)
network_output = self._learn_model.initial_inference(obs_target_batch[:, beg_index:end_index])
# obtain the oracle hidden states from representation function.
if self._cfg.model.model_type == 'conv':
beg_index = self._cfg.model.image_channel * step_i
end_index = self._cfg.model.image_channel * (step_i + self._cfg.model.frame_stack_num)
network_output = self._learn_model.initial_inference(obs_target_batch[:, beg_index:end_index, :, :])
elif self._cfg.model.model_type == 'mlp':
beg_index = self._cfg.model.observation_shape * step_i
end_index = self._cfg.model.observation_shape * (step_i + self._cfg.model.frame_stack_num)
network_output = self._learn_model.initial_inference(obs_target_batch[:, beg_index:end_index])
elif self._cfg.model.model_type == 'structure':
obs_target_batch_new = {}
for k, v in obs_target_batch.items():
if k == 'action_mask':
obs_target_batch_new[k] = v
continue
if isinstance(v, dict):
obs_target_batch_new[k] = {}
for k1, v1 in v.items():
if len(v1.shape) == 1:
observation_shape = v1.shape[0]//self._cfg.num_unroll_steps
beg_index = observation_shape * step_i
end_index = observation_shape * (step_i + self._cfg.model.frame_stack_num)
obs_target_batch_new[k][k1] = v1[beg_index:end_index]
else:
observation_shape = v1.shape[1]//self._cfg.num_unroll_steps
beg_index = observation_shape * step_i
end_index = observation_shape * (step_i + self._cfg.model.frame_stack_num)
obs_target_batch_new[k][k1] = v1[:, beg_index:end_index]
else:
observation_shape = v.shape[1]//self._cfg.num_unroll_steps
beg_index = observation_shape * step_i
end_index = observation_shape * (step_i + self._cfg.model.frame_stack_num)
obs_target_batch_new[k] = v[:, beg_index:end_index]
network_output = self._learn_model.initial_inference(obs_target_batch_new)
Collaborator comment: The structured-observation handling above could perhaps be abstracted into a helper function.
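A minimal sketch of such a helper, mirroring the nested loop in the diff above; the function name and its placement are assumptions:

```python
from typing import Any, Dict


def slice_structured_obs_target(obs_target_batch: Dict[str, Any], step_i: int,
                                num_unroll_steps: int, frame_stack_num: int) -> Dict[str, Any]:
    """Slice the step_i-th stacked observation out of a flattened structured target batch."""

    def _slice(v):
        # 1-D entries are flattened over unroll steps; 2-D entries keep a leading batch dim.
        if len(v.shape) == 1:
            obs_shape = v.shape[0] // num_unroll_steps
            return v[obs_shape * step_i:obs_shape * (step_i + frame_stack_num)]
        obs_shape = v.shape[1] // num_unroll_steps
        return v[:, obs_shape * step_i:obs_shape * (step_i + frame_stack_num)]

    obs_target_batch_new = {}
    for k, v in obs_target_batch.items():
        if k == 'action_mask':
            obs_target_batch_new[k] = v
        elif isinstance(v, dict):
            obs_target_batch_new[k] = {k1: _slice(v1) for k1, v1 in v.items()}
        else:
            obs_target_batch_new[k] = _slice(v)
    return obs_target_batch_new
```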


latent_state = to_tensor(latent_state)
representation_state = to_tensor(network_output.latent_state)
@@ -735,6 +775,7 @@ def _monitor_vars_learn(self) -> List[str]:
"""
return [
'collect_mcts_temperature',
'collect_epsilon',
'cur_lr',
'weighted_total_loss',
'total_loss',