Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add ExpertParallel with Top1 routing #49

Merged
merged 4 commits into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions pipegoose/nn/expert_parallel/expert_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from __future__ import annotations
from typing import List

from torchtyping import TensorType


class ExpertContext:
    """Shared accumulator for the auxiliary and z losses pushed during a forward pass.

    Losses are appended with ``push_aux_loss`` / ``push_z_loss`` and later
    drained in one go with ``pop_all_aux_loss`` / ``pop_all_z_loss``.
    Use ``ExpertContext.get_instance()`` to obtain the shared instance rather
    than constructing the class directly.
    """

    # Lazily created shared instance; stays None until get_instance() is called.
    _instance = None

    def __init__(self):
        # Pending losses, in the order they were pushed.
        self.aux_loss = []
        self.z_loss = []

    def push_aux_loss(self, aux_loss: TensorType):
        """Append one auxiliary loss value to the pending list."""
        self.aux_loss.append(aux_loss)

    def pop_all_aux_loss(self) -> List[TensorType]:
        """Return every pending auxiliary loss and reset the pending list to empty."""
        # Swap in a fresh list so the caller owns the returned one.
        aux_loss, self.aux_loss = self.aux_loss, []
        return aux_loss

    def push_z_loss(self, z_loss: TensorType):
        """Append one z loss value to the pending list."""
        self.z_loss.append(z_loss)

    def pop_all_z_loss(self) -> List[TensorType]:
        """Return every pending z loss and reset the pending list to empty."""
        z_loss, self.z_loss = self.z_loss, []
        return z_loss

    @classmethod
    def get_instance(cls) -> ExpertContext:
        """Return the shared ``ExpertContext``, creating it on first use."""
        # Identity check against None: a truthiness test would discard (and
        # re-create) an existing instance that happens to evaluate as falsy.
        if cls._instance is None:
            cls._instance = ExpertContext()
        return cls._instance
4 changes: 2 additions & 2 deletions pipegoose/nn/expert_parallel/expert_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(
router: Union[int, Callable] = 1,
# noise_policy: Union[str, Callable],
enable_tensor_parallelism: bool = False,
parallel_context: ParallelContext = None,
parallel_context: ParallelContext = None
):
tensor_parallel_size = parallel_context.get_world_size(ParallelMode.TENSOR)
assert parallel_context is not None, "parallel_context must be provided"
Expand Down Expand Up @@ -64,7 +64,7 @@ def parallelize(self) -> nn.Module:
module if self.expert is None else self.expert,
self.router,
self.enable_tensor_parallelism,
self.parallel_context,
self.parallel_context
)
getattr(self.module, "transformer").h[layer_idx].mlp = expert_layer

Expand Down
10 changes: 7 additions & 3 deletions pipegoose/nn/expert_parallel/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pipegoose.nn.expert_parallel.experts import Experts
from pipegoose.nn.expert_parallel.routers import Router
from pipegoose.nn.expert_parallel.utils import get_num_local_experts
from pipegoose.nn.expert_parallel.expert_context import ExpertContext


class ExpertLayer(nn.Module):
Expand All @@ -20,7 +21,7 @@ def __init__(
expert: nn.Module,
router: Router,
enable_tensor_parallel: bool,
parallel_context: ParallelContext,
parallel_context: ParallelContext
):
super().__init__()
self.router = router
Expand All @@ -39,6 +40,9 @@ def experts(self) -> nn.ModuleList:
def forward(self, *args, **kwargs) -> TensorType["batch_size", "seq_len", "d_model"]:
# TODO: use torch.fx to extract the inputs from args, and kwargs
inputs = args[0]
dispatching_order, _, _ = self.router(inputs)
outputs = self._experts(inputs, dispatching_order, *args, **kwargs)
router_output = self.router(inputs)
expert_context = ExpertContext.get_instance()
expert_context.push_aux_loss(router_output.aux_loss)
expert_context.push_z_loss(router_output.z_loss)
outputs = self._experts(inputs, router_output.dispatching_order, *args, **kwargs)
return outputs
17 changes: 12 additions & 5 deletions pipegoose/nn/expert_parallel/loss.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
from typing import Callable

import torch
from torchtyping import TensorType

from pipegoose.nn.expert_parallel.expert_context import ExpertContext


class ExpertLoss:
def __init__(self, loss: Callable, aux_weight: float):
self.loss = loss
def __init__(self, loss_func: Callable, aux_weight: float, z_weight: float):
self.loss_func = loss_func
self.aux_weight = aux_weight
self.z_weight = z_weight

def __call__(self) -> torch.Tensor:
pass
def __call__(self, *args, **kwargs) -> TensorType:
loss = self.loss_func(*args, **kwargs)
expert_context = ExpertContext.get_instance()
loss += self.aux_weight * sum(expert_context.pop_all_aux_loss())
loss += self.z_weight * sum(expert_context.pop_all_z_loss())
return loss
32 changes: 24 additions & 8 deletions pipegoose/nn/expert_parallel/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch.nn.functional as F
from torch import nn
from torchtyping import TensorType
from dataclasses import dataclass


class RouterExplorationNoisePolicy(ABC):
Expand All @@ -32,6 +33,14 @@ def sample_like(self, input: TensorType) -> TensorType:
return noise


@dataclass
class RouterOutput:
    """Bundle of values returned by a router's ``forward`` call."""

    # Expert-selection mask; the router fills this with its top-k expert mask
    # (the capacity-limited variant when an expert capacity is set).
    dispatching_order: TensorType["batch_size * seq_len", "num_experts"]
    # Router probabilities multiplied by the selection mask, cast back to the
    # input dtype.
    weight: TensorType["batch_size * seq_len", "num_experts"]
    # Auxiliary loss as returned by the router's `_aux_loss`.
    aux_loss: TensorType["1"]
    # Z loss as returned by the router's `_z_loss`.
    z_loss: TensorType["1"]


class Router(ABC, nn.Module):
pass

Expand Down Expand Up @@ -93,9 +102,7 @@ def _expert_capacity(self, total_tokens: int) -> int:

def forward(
self, inputs: TensorType["batch_size", "seq_len", "d_model"]
) -> Tuple[
TensorType["batch_size*seq_len", "num_experts"], TensorType["batch_size*seq_len", "num_experts"], TensorType["1"]
]:
) -> RouterOutput:
orig_dtype = inputs.dtype
total_tokens = inputs.shape[0] * inputs.shape[1]

Expand All @@ -115,15 +122,19 @@ def forward(
topk_expert_mask = topk_expert_mask.scatter_(1, topk_idxs, True)

# calculate router loss
loss = self.aux_loss_weight * self._aux_loss(router_prob, topk_expert_mask) + self.z_loss_weight * self._z_loss(
router_logits
)
aux_loss = self._aux_loss(router_prob, topk_expert_mask)
z_loss = self._z_loss(router_logits)

if not self.expert_capacity:
# we don't limit the capacity of the experts
topk_weight = router_prob * topk_expert_mask
topk_weight = topk_weight.to(orig_dtype)
return topk_expert_mask, topk_weight, loss
return RouterOutput(
dispatching_order=topk_expert_mask,
weight=topk_weight,
aux_loss=aux_loss,
z_loss=z_loss
)

# limit the number of tokens per expert
position_in_expert = torch.cumsum(topk_expert_mask, dim=0) * topk_expert_mask
Expand All @@ -137,7 +148,12 @@ def forward(
topk_weight = router_prob * capacity_limited_topk_expert_mask
topk_weight = topk_weight.to(orig_dtype)

return capacity_limited_topk_expert_mask, topk_weight, loss
return RouterOutput(
dispatching_order=capacity_limited_topk_expert_mask,
weight=topk_weight,
aux_loss=aux_loss,
z_loss=z_loss
)


class Top1Router(_TopKRouter):
Expand Down
20 changes: 20 additions & 0 deletions tests/nn/expert_parallel/test_expert_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from pipegoose.nn.expert_parallel.expert_context import ExpertContext


def test_expert_context():
    """Check that ExpertContext is a singleton and that pop_all_* drains pushed losses."""
    expert_context = ExpertContext.get_instance()

    # The context is shared by the whole process, so discard anything an
    # earlier test may have left behind; otherwise the exact-list assertions
    # below would depend on test execution order.
    expert_context.pop_all_aux_loss()
    expert_context.pop_all_z_loss()

    expert_context.push_aux_loss(1.01)
    expert_context.push_z_loss(2.01)

    expert_context.push_aux_loss(1.02)
    expert_context.push_z_loss(2.02)

    # make sure that we have a singleton!
    same_context = ExpertContext.get_instance()
    assert same_context is expert_context

    # Values pushed through the first handle are visible through the second,
    # and popping leaves the context empty.
    assert same_context.pop_all_aux_loss() == [1.01, 1.02]
    assert same_context.pop_all_aux_loss() == []

    assert same_context.pop_all_z_loss() == [2.01, 2.02]
    assert same_context.pop_all_z_loss() == []
31 changes: 21 additions & 10 deletions tests/nn/expert_parallel/test_expert_loss.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,35 @@
import torch
from torch import nn
import torch.nn.functional as F

from pipegoose.nn.expert_parallel import ExpertLoss
from pipegoose.nn.expert_parallel.expert_context import ExpertContext


def test_expert_loss():
loss_func = nn.CrossEntropyLoss()
torch.manual_seed(42)
logits = torch.randn((10, 5))
gt = torch.randn((10, 5))

expert_loss = ExpertLoss(loss_func, aux_weight=0.1)
loss_func = nn.MSELoss()

expert_loss = ExpertLoss(loss_func, aux_weight=0.1, z_weight=0.2)
expert_context = ExpertContext.get_instance()

assert expert_loss.aux_weight == 0.1
assert expert_loss.z_weight == 0.2
assert expert_loss.loss_func == loss_func

ExpertLoss.add_aux_loss(1.01)
ExpertLoss.add_z_loss(2.01)
expert_context.push_aux_loss(1.01)
expert_context.push_z_loss(2.01)

expert_context.push_aux_loss(1.02)
expert_context.push_z_loss(2.02)

assert expert_loss.get_aux_loss() == [1.01]
assert expert_loss.get_z_loss() == [2.01]
expected_loss = F.mse_loss(logits, gt) + 0.1 * (1.01 + 1.02) + 0.2 * (2.01 + 2.02)
loss = expert_loss(logits, gt)

ExpertLoss.add_aux_loss(1.02)
ExpertLoss.add_z_loss(2.02)
assert torch.allclose(loss, expected_loss)

assert expert_loss.get_aux_loss() == [1.01, 1.02]
assert expert_loss.get_z_loss() == [2.01, 2.02]
assert expert_context.aux_loss == []
assert expert_context.z_loss == []
Loading
Loading