feature(wrh): update soft modulization in unizero for mt #250

Open
wants to merge 5 commits into dev-unizero-multitask-v2

2 changes: 1 addition & 1 deletion lzero/entry/train_unizero.py
@@ -61,7 +61,7 @@ def train_unizero(
game_buffer_classes[create_cfg.policy.type])

    # Set device based on CUDA availability
-    cfg.policy.device = cfg.policy.model.world_model.device if torch.cuda.is_available() else 'cpu'
+    cfg.policy.device = cfg.policy.model.world_model_cfg.device if torch.cuda.is_available() else 'cpu'
    logging.info(f'cfg.policy.device: {cfg.policy.device}')

    # Compile the configuration
199 changes: 197 additions & 2 deletions lzero/model/unizero_world_models/utils.py
@@ -1,9 +1,11 @@
import hashlib
from typing import Optional, List, Tuple, Union
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .kv_caching import KeysValues

@@ -95,6 +97,20 @@ class WorldModelOutput:
    logits_value: torch.FloatTensor


@dataclass
class WorldModelOutputSoftModulization:
    output_sequence: torch.FloatTensor
    logits_observations: torch.FloatTensor
    logits_rewards: torch.FloatTensor
    logits_ends: torch.FloatTensor
    logits_policy: torch.FloatTensor
    logits_value: torch.FloatTensor
    task_id: int
    observation_weights_list: Optional[List[torch.FloatTensor]] = None
    reward_weights_list: Optional[List[torch.FloatTensor]] = None
    policy_weights_list: Optional[List[torch.FloatTensor]] = None
    value_weights_list: Optional[List[torch.FloatTensor]] = None

def init_weights(module, norm_type='BN'):
"""
Initialize the weights of the module based on the specified normalization type.
@@ -162,7 +178,11 @@ def __init__(self, latent_recon_loss_weight=0, perceptual_loss_weight=0, **kwarg
        self.value_loss_weight = 0.25
        self.policy_loss_weight = 1.
        self.ends_loss_weight = 0.

        # Added for soft modulization: keep the task id and observation routing weights alongside the losses
        self.task_id = kwargs.get("task_id", None)
        self.obs_soft_module_route_weights = kwargs.get("observation_weights_list", None)

        self.latent_recon_loss_weight = latent_recon_loss_weight
        self.perceptual_loss_weight = perceptual_loss_weight

@@ -186,11 +206,186 @@ def __init__(self, latent_recon_loss_weight=0, perceptual_loss_weight=0, **kwarg

        self.intermediate_losses = {
            k: v if isinstance(v, dict) or isinstance(v, np.ndarray) or isinstance(v, torch.Tensor) else (v if isinstance(v, float) else v.item())
-            for k, v in kwargs.items()
+            for k, v in kwargs.items() if k not in ["task_id", "observation_weights_list"]
        }

    def __truediv__(self, value):
        for k, v in self.intermediate_losses.items():
            self.intermediate_losses[k] = v / value
        self.loss_total = self.loss_total / value
        return self
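The filter added in the comprehension above exists because `task_id` and `observation_weights_list` are routing metadata rather than loss terms. A minimal, standalone sketch of that filtering behaviour with hypothetical kwargs (the real call sites pass the world model's loss tensors):

```python
import numpy as np
import torch

# Hypothetical kwargs mixing genuine loss tensors with the new routing metadata:
kwargs = {
    "loss_obs": torch.tensor(0.5),
    "loss_policy": torch.tensor(0.2),
    "task_id": 3,                                     # new: task index, not a loss term
    "observation_weights_list": [torch.rand(4, 16)],  # new: soft-modulization routing weights
}

# The amended comprehension keeps only loss entries for logging:
intermediate_losses = {
    k: v if isinstance(v, dict) or isinstance(v, np.ndarray) or isinstance(v, torch.Tensor) else (v if isinstance(v, float) else v.item())
    for k, v in kwargs.items() if k not in ["task_id", "observation_weights_list"]
}
assert set(intermediate_losses) == {"loss_obs", "loss_policy"}
```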


class SoftModulizationHead(nn.Module):
    """
    Overview:
        SoftModulizationHead is an nn.Module that implements soft modulization for multi-task reinforcement learning:
        a task-conditioned gating network produces softmax routing weights over a grid of shared base modules, and each
        layer's module inputs are weighted sums of the previous layer's module outputs.

    Arguments:
        - task_num (:obj:`int`): The number of tasks.
        - embed_dim (:obj:`int`): The embedding dimension.
        - gating_embed_mlp_num (:obj:`int`): The number of layers in the MLP that processes the gating embedding.
        - base_model_modulelists (:obj:`ModuleList[ModuleList[nn.Module]]`): A nested ModuleList of base model modules, indexed as [layer][module].
        - base_layers_num (:obj:`int`): The number of base layers in the model.
        - base_modules_num (:obj:`int`): The number of base modules per layer.
        - device (:obj:`torch.device`): The device on which computations run.
    """

    def __init__(self,
                 task_num: int,
                 embed_dim: int,
                 gating_embed_mlp_num: int,
                 base_model_modulelists: nn.ModuleList,
                 base_layers_num: int = 3,
                 base_modules_num: int = 4,
                 device: torch.device = torch.device("cpu")
                 ) -> None:
        super(SoftModulizationHead, self).__init__()
        self.task_num = task_num
        self.embed_dim = embed_dim
        self.gating_embed_mlp_num = gating_embed_mlp_num
        self.base_layers_num = base_layers_num
        self.base_modules_num = base_modules_num
        self.base_model_modulelists = base_model_modulelists
        self.device = device

        # Task embedding layer, e.g. (..., task_num) -> (..., embed_dim)
        self.task_embed_layer = nn.Linear(task_num, embed_dim)

        # Gating fully connected layers. Build fresh ReLU/Linear pairs in a loop:
        # multiplying the list would register the same Linear instance repeatedly and share its weights.
        gating_fc_layer_module = []
        for _ in range(self.gating_embed_mlp_num - 1):
            gating_fc_layer_module += [nn.ReLU(), nn.Linear(embed_dim, embed_dim)]
        self.gating_fcs = nn.Sequential(*gating_fc_layer_module)

        # Initial gating weight layer: (..., embed_dim) -> (..., base_modules_num ** 2).
        # The zeroed weight and 6 * identity bias make the initial softmax routing close to the identity mapping.
        self.gating_weight_fc_0 = nn.Linear(embed_dim, self.base_modules_num * self.base_modules_num)
        with torch.no_grad():
            self.gating_weight_fc_0.weight.zero_()
            bias_vector = 6 * torch.eye(self.base_modules_num, dtype=torch.float32).reshape(-1)
            self.gating_weight_fc_0.bias = nn.Parameter(bias_vector)

        # Conditional gating weight layers for the intermediate routing steps
        self.gating_weight_fcs = nn.ModuleList()
        self.gating_weight_cond_fcs = nn.ModuleList()
        for k in range(self.base_layers_num - 2):
            gating_weight_cond_fc_layer = nn.Linear((k + 1) * self.base_modules_num * self.base_modules_num, embed_dim)
            self.gating_weight_cond_fcs.append(gating_weight_cond_fc_layer)
            gating_weight_fc_layer = nn.Linear(embed_dim, self.base_modules_num * self.base_modules_num)
            with torch.no_grad():
                gating_weight_fc_layer.weight.zero_()
                bias_vector = 6 * torch.eye(self.base_modules_num, dtype=torch.float32).reshape(-1)
                gating_weight_fc_layer.bias = nn.Parameter(bias_vector)
            self.gating_weight_fcs.append(gating_weight_fc_layer)

        # E.g. with base_layers_num=4, base_modules_num=4, embed_dim=768:
        # gating_weight_cond_fcs = [Linear(16, 768), Linear(32, 768)], gating_weight_cond_last = Linear(48, 768),
        # gating_weight_fcs = [Linear(768, 16), Linear(768, 16)].

        # Final gating weight layers
        self.gating_weight_cond_last = nn.Linear((self.base_layers_num - 1) * self.base_modules_num * self.base_modules_num, embed_dim)
        self.gating_weight_last_fc = nn.Linear(embed_dim, self.base_modules_num)

    def forward(self, x: torch.Tensor, task_id: int,
                final_norm: Optional[nn.Module] = None, return_weight: bool = False
                ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
        """
        Overview:
            Forward pass for soft modulization.

        Arguments:
            - x (:obj:`torch.Tensor`): Input tensor of shape (..., embed_dim).
            - task_id (:obj:`int`): ID of the current task.
            - final_norm (:obj:`Optional[nn.Module]`): Optional normalization layer applied to the output.
            - return_weight (:obj:`bool`): Whether to also return the flattened routing weights.

        Returns:
            - Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]: The output tensor after soft modulization,
              or a tuple of the output tensor and the list of flattened routing weights if ``return_weight`` is True.
        """
# print(f"x.shape: {x.shape}")
task_id_vector = torch.zeros(self.task_num).to(self.device)
task_id_vector[task_id] = 1

# Process task embedding
task_embedding = self.task_embed_layer(task_id_vector).to(self.device)
# print(f"task_embedding.shape before * x : {task_embedding.shape}")
task_embedding = F.relu(task_embedding * x)
task_embedding = self.gating_fcs(task_embedding)
# print(f"task_embedding.shape after * x: {task_embedding.shape}")

weights = []
flatten_weights = []
base_shape = task_embedding.shape[:-1]
weight_shape = base_shape + torch.Size([self.base_modules_num, self.base_modules_num])
flatten_shape = base_shape + torch.Size([self.base_modules_num ** 2])

# Calculate weights between layers
raw_weight = self.gating_weight_fc_0(F.relu(task_embedding))
raw_weight = raw_weight.view(weight_shape)


softmax_weight = F.softmax(raw_weight, dim=-1)
# print(f"softmax_weight: {softmax_weight.shape}")
flatten_weight = softmax_weight.view(flatten_shape)
# print(f"flatten_weight: {flatten_weight.shape}")
weights.append(softmax_weight)
flatten_weights.append(flatten_weight)

for i, (gating_weight_fc, gating_weight_cond_fc) in enumerate(zip(self.gating_weight_fcs, self.gating_weight_cond_fcs)):
cond = F.relu(torch.cat(flatten_weights, dim=-1))
# print(f"cond_weight: {cond.shape}")
cond = gating_weight_cond_fc(cond)
# print(f"cond_weight: {cond.shape}")
cond = F.relu(cond * task_embedding)
# print(f"cond_weight: {cond.shape}")

raw_weight = gating_weight_fc(cond)
raw_weight = raw_weight.view(weight_shape)
softmax_weight = F.softmax(raw_weight, dim=-1)
flatten_weight = softmax_weight.view(flatten_shape)
weights.append(softmax_weight)
flatten_weights.append(flatten_weight)

cond = F.relu(torch.cat(flatten_weights, dim=-1))
# print(f"cond_weight: {cond.shape}")
cond = self.gating_weight_cond_last(cond)
cond = F.relu(cond * task_embedding)
# print(f"cond_weight: {cond.shape}")
raw_last_weight = self.gating_weight_last_fc(cond)
# print(f"cond_weight: {cond.shape}")
last_weight = F.softmax(raw_last_weight, dim=-1)

# Forward calculation
# print(f"self.base_modules_num = {self.base_modules_num}")
# print(f"len(self.base_model) = {len(self.base_model_modulelists)}")
# print(f"len(self.base_model[0]) = {len(self.base_model_modulelists[0])}")
obs_mid_layers = [self.base_model_modulelists[0][i] for i in range(self.base_modules_num)]
obs_mid_outputs = [obs_mid_layer(x).unsqueeze(-2) for obs_mid_layer in obs_mid_layers]
obs_mid_outputs = torch.cat(obs_mid_outputs, dim=-2)

for i in range(self.base_layers_num - 1):
new_module_outputs = []
obs_next_mid_layers = [self.base_model_modulelists[i+1][j] for j in range(self.base_modules_num)]

for j, next_layer_module in enumerate(obs_next_mid_layers):

# print(f"obs_mid_outputs.shape: {obs_mid_outputs.shape}")
# print(f"weights[{i}][..., {j}, :].unsqueeze(-1).shape: {weights[i][..., j, :].unsqueeze(-1).shape}")
next_module_input = F.relu((obs_mid_outputs * weights[i][..., j, :].unsqueeze(-1)).sum(dim=-2))
new_module_outputs.append((next_layer_module(next_module_input)).unsqueeze(-2))
# print([x.shape for x in new_module_outputs])
obs_mid_outputs = torch.cat(new_module_outputs, dim=-2)

obs_module_output = obs_mid_outputs
obs_output = (obs_module_output * last_weight.unsqueeze(-1)).sum(-2)

if final_norm is not None:
obs_output = final_norm(obs_output)

if return_weight:
return obs_output, flatten_weights
return obs_output
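To make the intended wiring concrete, below is a minimal usage sketch (not part of the diff). The base-module grid here is a toy stack of Linear+GELU blocks, and the sizes (`embed_dim=768`, batch of 2, sequence of 16) are illustrative assumptions; in the PR the module lists are expected to come from the UniZero world model itself.

```python
import torch
import torch.nn as nn

from lzero.model.unizero_world_models.utils import SoftModulizationHead

# Hypothetical base-module grid: base_layers_num layers, each holding base_modules_num blocks.
embed_dim, task_num = 768, 8
base_layers_num, base_modules_num = 3, 4
base_model_modulelists = nn.ModuleList([
    nn.ModuleList([
        nn.Sequential(nn.Linear(embed_dim, embed_dim), nn.GELU())
        for _ in range(base_modules_num)
    ])
    for _ in range(base_layers_num)
])

head = SoftModulizationHead(
    task_num=task_num,
    embed_dim=embed_dim,
    gating_embed_mlp_num=2,
    base_model_modulelists=base_model_modulelists,
    base_layers_num=base_layers_num,
    base_modules_num=base_modules_num,
    device=torch.device("cpu"),
)

x = torch.randn(2, 16, embed_dim)  # (batch, sequence, embed_dim) token features
out, routing_weights = head(x, task_id=3, return_weight=True)
print(out.shape)                            # torch.Size([2, 16, 768])
print([w.shape for w in routing_weights])   # each (2, 16, base_modules_num ** 2)
```

The returned `routing_weights` are the flattened per-transition softmax weights, which is presumably what `WorldModelOutputSoftModulization.observation_weights_list` is meant to carry.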