diff --git a/inference/model.py b/inference/model.py index c143e97..8b42009 100644 --- a/inference/model.py +++ b/inference/model.py @@ -15,6 +15,7 @@ rank = 0 block_size = 128 gemm_impl: Literal["bf16", "fp8"] = "bf16" attn_impl: Literal["naive", "absorb"] = "absorb" +moe_impl: Literal["mixture", "distribution"] = "distribution" @dataclass class ModelArgs: @@ -595,6 +596,54 @@ class Gate(nn.Module): return weights.type_as(x), indices +class Categorical(nn.Module): + """ + Stochastic Gating mechanism for routing inputs in an MoE model. + + Attributes: + dim (int): Dimensionality of input features. + score_func (str): Scoring function ('softmax' or 'sigmoid'). + route_scale (float): Scaling factor for routing weights. + weight (torch.nn.Parameter): Learnable weights for the gate. + """ + def __init__(self, args: ModelArgs): + """ + Initializes the Gate module. + + Args: + args (ModelArgs): Model arguments containing gating parameters. + """ + super().__init__() + self.dim = args.dim + self.topk = args.n_activated_experts + self.score_func = args.score_func + self.route_scale = args.route_scale + self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim)) + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for the gating mechanism. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + Routing weights and selected expert indices. + """ + scores = linear(x, self.weight) + if self.score_func == "softmax": + scores = scores.softmax(dim=-1, dtype=torch.float32) + else: + scores = scores.sigmoid() / scores.sum(dim=-1, keepdim=True) + # https://medium.com/@rjnclarke/6b602bae0f2c + indices = torch.multinomial( + scores, num_samples=self.topk, replacement=False) + scores = scores / scores.gather(1, indices).sum(dim=1, keepdim=True) \ + if self.topk > 1 else scores + return (scores * self.route_scale).type_as(x), indices + + class Expert(nn.Module): """ Expert layer for Mixture-of-Experts (MoE) models. @@ -658,7 +707,7 @@ class MoE(nn.Module): self.n_activated_experts = args.n_activated_experts self.experts_start_idx = rank * self.n_local_experts self.experts_end_idx = self.experts_start_idx + self.n_local_experts - self.gate = Gate(args) + self.gate = Gate(args) if moe_impl == "mixture" else Categorical(args) self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim) if self.experts_start_idx <= i < self.experts_end_idx else None for i in range(self.n_routed_experts)]) self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim)