Commit 93dfb32

Merge pull request #49 from xrsrke/feature/moe

[Feature] Add ExpertParallel with Top1 routing

xrsrke authored Nov 29, 2023 · 2 parents 6f6cdfd + c28fc89

Showing 10 changed files with 256 additions and 94 deletions.
32 changes: 32 additions & 0 deletions pipegoose/nn/expert_parallel/expert_context.py
@@ -0,0 +1,32 @@
from __future__ import annotations
from typing import List

from torchtyping import TensorType


class ExpertContext:
    _instance = None

    def __init__(self):
        self.aux_loss = []
        self.z_loss = []

    def push_aux_loss(self, aux_loss: TensorType):
        self.aux_loss.append(aux_loss)

    def pop_all_aux_loss(self) -> List[TensorType]:
        aux_loss, self.aux_loss = self.aux_loss, []
        return aux_loss

    def push_z_loss(self, z_loss: TensorType):
        self.z_loss.append(z_loss)

    def pop_all_z_loss(self) -> List[TensorType]:
        z_loss, self.z_loss = self.z_loss, []
        return z_loss

    @classmethod
    def get_instance(cls) -> ExpertContext:
        if not cls._instance:
            cls._instance = ExpertContext()
        return cls._instance
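ExpertContext is a process-wide accumulator: each ExpertLayer pushes its router's auxiliary losses during the forward pass, and the loss wrapper drains them once per step, so MoE losses can cross module boundaries without changing any forward signatures. A minimal sketch of that contract (the tensor values are stand-ins, not taken from this PR):

import torch

from pipegoose.nn.expert_parallel.expert_context import ExpertContext

ctx = ExpertContext.get_instance()

# Forward pass: each MoE layer pushes one aux loss and one z-loss.
ctx.push_aux_loss(torch.tensor(0.01))
ctx.push_z_loss(torch.tensor(0.02))

# Loss computation: pop_all_* returns the accumulated list and resets the
# buffer, so losses from one step never leak into the next.
assert len(ctx.pop_all_aux_loss()) == 1
assert ctx.pop_all_aux_loss() == []  # already drained
assert len(ctx.pop_all_z_loss()) == 1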
4 changes: 2 additions & 2 deletions pipegoose/nn/expert_parallel/expert_parallel.py
@@ -27,7 +27,7 @@ def __init__(
         router: Union[int, Callable] = 1,
         # noise_poligy: Union[str, Callable],
         enable_tensor_parallelism: bool = False,
-        parallel_context: ParallelContext = None,
+        parallel_context: ParallelContext = None
     ):
         tensor_parallel_size = parallel_context.get_world_size(ParallelMode.TENSOR)
         assert parallel_context is not None, "parallel_context must be provided"
@@ -64,7 +64,7 @@ def parallelize(self) -> nn.Module:
                 module if self.expert is None else self.expert,
                 self.router,
                 self.enable_tensor_parallelism,
-                self.parallel_context,
+                self.parallel_context
             )
             getattr(self.module, "transformer").h[layer_idx].mlp = expert_layer

10 changes: 7 additions & 3 deletions pipegoose/nn/expert_parallel/layers.py
@@ -5,6 +5,7 @@
 from pipegoose.nn.expert_parallel.experts import Experts
 from pipegoose.nn.expert_parallel.routers import Router
 from pipegoose.nn.expert_parallel.utils import get_num_local_experts
+from pipegoose.nn.expert_parallel.expert_context import ExpertContext


 class ExpertLayer(nn.Module):
@@ -20,7 +21,7 @@ def __init__(
         expert: nn.Module,
         router: Router,
         enable_tensor_parallel: bool,
-        parallel_context: ParallelContext,
+        parallel_context: ParallelContext
     ):
         super().__init__()
         self.router = router
@@ -39,6 +40,9 @@ def experts(self) -> nn.ModuleList:
     def forward(self, *args, **kwargs) -> TensorType["batch_size", "seq_len", "d_model"]:
         # TODO: use torch.fx to extract the inputs from args, and kwargs
         inputs = args[0]
-        dispatching_order, _, _ = self.router(inputs)
-        outputs = self._experts(inputs, dispatching_order, *args, **kwargs)
+        router_output = self.router(inputs)
+        expert_context = ExpertContext.get_instance()
+        expert_context.push_aux_loss(router_output.aux_loss)
+        expert_context.push_z_loss(router_output.z_loss)
+        outputs = self._experts(inputs, router_output.dispatching_order, *args, **kwargs)
         return outputs
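Design note: ExpertLayer replaces a transformer block's mlp in place, so its forward must keep the MLP's input/output contract; the aux and z losses therefore leave through ExpertContext as a side effect rather than through the return value. A sketch of what the new forward does with a router's output (the toy router below is for illustration only, not part of the PR):

import torch

from pipegoose.nn.expert_parallel.expert_context import ExpertContext
from pipegoose.nn.expert_parallel.routers import RouterOutput

# Toy stand-in for a trained router, for illustration only.
def toy_router(inputs):
    num_tokens = inputs.shape[0] * inputs.shape[1]
    num_experts = 4
    dispatching_order = torch.zeros(num_tokens, num_experts, dtype=torch.bool)
    dispatching_order[:, 0] = True  # route every token to expert 0
    return RouterOutput(
        dispatching_order=dispatching_order,
        weight=dispatching_order.float(),
        aux_loss=torch.tensor(0.1),
        z_loss=torch.tensor(0.2),
    )

inputs = torch.randn(2, 16, 64)  # (batch_size, seq_len, d_model); shapes assumed
router_output = toy_router(inputs)

# This mirrors what ExpertLayer.forward now does before dispatching to experts:
ctx = ExpertContext.get_instance()
ctx.push_aux_loss(router_output.aux_loss)
ctx.push_z_loss(router_output.z_loss)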
17 changes: 12 additions & 5 deletions pipegoose/nn/expert_parallel/loss.py
@@ -1,12 +1,19 @@
 from typing import Callable

-import torch
+from torchtyping import TensorType
+
+from pipegoose.nn.expert_parallel.expert_context import ExpertContext


 class ExpertLoss:
-    def __init__(self, loss: Callable, aux_weight: float):
-        self.loss = loss
+    def __init__(self, loss_func: Callable, aux_weight: float, z_weight: float):
+        self.loss_func = loss_func
         self.aux_weight = aux_weight
+        self.z_weight = z_weight

-    def __call__(self) -> torch.Tensor:
-        pass
+    def __call__(self, *args, **kwargs) -> TensorType:
+        loss = self.loss_func(*args, **kwargs)
+        expert_context = ExpertContext.get_instance()
+        loss += self.aux_weight * sum(expert_context.pop_all_aux_loss())
+        loss += self.z_weight * sum(expert_context.pop_all_z_loss())
+        return loss
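ExpertLoss can then stand in for the criterion in a training loop: it computes the task loss and folds in everything the routers accumulated, i.e. total = task_loss + aux_weight * sum(aux_losses) + z_weight * sum(z_losses). A minimal sketch (the weights 0.01 and 0.001 and the pushed values are placeholders, not from the PR):

import torch
from torch import nn

from pipegoose.nn.expert_parallel import ExpertLoss
from pipegoose.nn.expert_parallel.expert_context import ExpertContext

criterion = ExpertLoss(nn.CrossEntropyLoss(), aux_weight=0.01, z_weight=0.001)

# Pretend one MoE layer ran during the forward pass and pushed its losses.
ctx = ExpertContext.get_instance()
ctx.push_aux_loss(torch.tensor(0.5))
ctx.push_z_loss(torch.tensor(0.3))

logits = torch.randn(8, 10, requires_grad=True)
targets = torch.randint(0, 10, (8,))

# task loss + 0.01 * sum(aux) + 0.001 * sum(z); the context buffers are
# drained as a side effect, ready for the next step.
loss = criterion(logits, targets)
loss.backward()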
32 changes: 24 additions & 8 deletions pipegoose/nn/expert_parallel/routers.py
@@ -6,6 +6,7 @@
 import torch.nn.functional as F
 from torch import nn
 from torchtyping import TensorType
+from dataclasses import dataclass


 class RouterExplorationNoisePolicy(ABC):
@@ -32,6 +33,14 @@ def sample_like(self, input: TensorType) -> TensorType:
         return noise


+@dataclass
+class RouterOutput:
+    dispatching_order: TensorType["batch_size * seq_len", "num_experts"]
+    weight: TensorType["batch_size * seq_len", "num_experts"]
+    aux_loss: TensorType["1"]
+    z_loss: TensorType["1"]
+
+
 class Router(ABC, nn.Module):
     pass

@@ -93,9 +102,7 @@ def _expert_capacity(self, total_tokens: int) -> int:

     def forward(
         self, inputs: TensorType["batch_size", "seq_len", "d_model"]
-    ) -> Tuple[
-        TensorType["batch_size*seq_len", "num_experts"], TensorType["batch_size*seq_len", "num_experts"], TensorType["1"]
-    ]:
+    ) -> RouterOutput:
         orig_dtype = inputs.dtype
         total_tokens = inputs.shape[0] * inputs.shape[1]

@@ -115,15 +122,19 @@ def forward(
         topk_expert_mask = topk_expert_mask.scatter_(1, topk_idxs, True)

         # calculate router loss
-        loss = self.aux_loss_weight * self._aux_loss(router_prob, topk_expert_mask) + self.z_loss_weight * self._z_loss(
-            router_logits
-        )
+        aux_loss = self._aux_loss(router_prob, topk_expert_mask)
+        z_loss = self._z_loss(router_logits)

         if not self.expert_capacity:
             # we don't limit the capacity of the experts
             topk_weight = router_prob * topk_expert_mask
             topk_weight = topk_weight.to(orig_dtype)
-            return topk_expert_mask, topk_weight, loss
+            return RouterOutput(
+                dispatching_order=topk_expert_mask,
+                weight=topk_weight,
+                aux_loss=aux_loss,
+                z_loss=z_loss
+            )

         # limit the number of tokens per expert
         position_in_expert = torch.cumsum(topk_expert_mask, dim=0) * topk_expert_mask
@@ -137,7 +148,12 @@ def forward(
         topk_weight = router_prob * capacity_limited_topk_expert_mask
         topk_weight = topk_weight.to(orig_dtype)

-        return capacity_limited_topk_expert_mask, topk_weight, loss
+        return RouterOutput(
+            dispatching_order=capacity_limited_topk_expert_mask,
+            weight=topk_weight,
+            aux_loss=aux_loss,
+            z_loss=z_loss
+        )


class Top1Router(_TopKRouter):
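For reference, the diff does not show the bodies of _aux_loss and _z_loss, but the names follow the standard MoE formulations: the Switch Transformer load-balancing loss and the ST-MoE router z-loss. A sketch of those standard definitions (an assumption about intent; the repo's exact implementation may differ):

import torch

def load_balancing_loss(router_prob, expert_mask):
    # Switch Transformer-style aux loss: num_experts * sum over experts of
    # (fraction of tokens dispatched to the expert) *
    # (mean router probability assigned to the expert).
    # router_prob, expert_mask: (batch_size * seq_len, num_experts)
    num_experts = router_prob.shape[-1]
    fraction_tokens = expert_mask.float().mean(dim=0)
    fraction_prob = router_prob.mean(dim=0)
    return num_experts * torch.sum(fraction_tokens * fraction_prob)

def router_z_loss(router_logits):
    # ST-MoE z-loss: penalizes large router logits so the softmax stays
    # well-conditioned and routing stays numerically stable.
    return torch.mean(torch.logsumexp(router_logits, dim=-1) ** 2)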
20 changes: 20 additions & 0 deletions tests/nn/expert_parallel/test_expert_context.py
@@ -0,0 +1,20 @@
from pipegoose.nn.expert_parallel.expert_context import ExpertContext


def test_expert_context():
    expert_context = ExpertContext.get_instance()

    expert_context.push_aux_loss(1.01)
    expert_context.push_z_loss(2.01)

    expert_context.push_aux_loss(1.02)
    expert_context.push_z_loss(2.02)

    # make sure that we have a singleton!
    expert_context = ExpertContext.get_instance()

    assert expert_context.pop_all_aux_loss() == [1.01, 1.02]
    assert expert_context.pop_all_aux_loss() == []

    assert expert_context.pop_all_z_loss() == [2.01, 2.02]
    assert expert_context.pop_all_z_loss() == []
31 changes: 21 additions & 10 deletions tests/nn/expert_parallel/test_expert_loss.py
@@ -1,24 +1,35 @@
+import torch
 from torch import nn
+import torch.nn.functional as F

 from pipegoose.nn.expert_parallel import ExpertLoss
+from pipegoose.nn.expert_parallel.expert_context import ExpertContext


 def test_expert_loss():
-    loss_func = nn.CrossEntropyLoss()
+    torch.manual_seed(42)
+    logits = torch.randn((10, 5))
+    gt = torch.randn((10, 5))

-    expert_loss = ExpertLoss(loss_func, aux_weight=0.1)
+    loss_func = nn.MSELoss()

+    expert_loss = ExpertLoss(loss_func, aux_weight=0.1, z_weight=0.2)
+    expert_context = ExpertContext.get_instance()

     assert expert_loss.aux_weight == 0.1
+    assert expert_loss.z_weight == 0.2
-    assert expert_loss.loss == loss_func
+    assert expert_loss.loss_func == loss_func

-    ExpertLoss.add_aux_loss(1.01)
-    ExpertLoss.add_z_loss(2.01)
+    expert_context.push_aux_loss(1.01)
+    expert_context.push_z_loss(2.01)

+    expert_context.push_aux_loss(1.02)
+    expert_context.push_z_loss(2.02)

-    assert expert_loss.get_aux_loss() == [1.01]
-    assert expert_loss.get_z_loss() == [2.01]
+    expected_loss = F.mse_loss(logits, gt) + 0.1 * (1.01 + 1.02) + 0.2 * (2.01 + 2.02)
+    loss = expert_loss(logits, gt)

-    ExpertLoss.add_aux_loss(1.02)
-    ExpertLoss.add_z_loss(2.02)
+    assert torch.allclose(loss, expected_loss)

-    assert expert_loss.get_aux_loss() == [1.01, 1.02]
-    assert expert_loss.get_z_loss() == [2.01, 2.02]
+    assert expert_context.aux_loss == []
+    assert expert_context.z_loss == []