stochastic gating

Kent Slaney 2025-05-06 21:38:27 -07:00
parent 4cc6253d5c
commit 20c8b4f6b1


@@ -15,6 +15,7 @@ rank = 0
block_size = 128
gemm_impl: Literal["bf16", "fp8"] = "bf16"
attn_impl: Literal["naive", "absorb"] = "absorb"
moe_impl: Literal["mixture", "distribution"] = "distribution"

@dataclass
class ModelArgs:
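The new moe_impl flag sits alongside gemm_impl and attn_impl at module scope, and MoE reads it once at construction time, so it must be set before the model is built. A minimal sketch of toggling it, assuming the file is importable as model:

import model

# The default "distribution" selects the new stochastic Categorical gate;
# "mixture" restores the original deterministic top-k Gate.
model.moe_impl = "mixture"

args = model.ModelArgs()
moe = model.MoE(args)  # reads moe_impl once, at construction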
@@ -595,6 +596,54 @@ class Gate(nn.Module):
        return weights.type_as(x), indices

class Categorical(nn.Module):
    """
    Stochastic gating mechanism for routing inputs in an MoE model.

    Unlike Gate, which always picks the top-k scoring experts, this module
    samples k distinct experts per token from the score distribution, so
    lower-scoring experts are still visited occasionally.

    Attributes:
        dim (int): Dimensionality of input features.
        topk (int): Number of experts sampled per token.
        score_func (str): Scoring function ('softmax' or 'sigmoid').
        route_scale (float): Scaling factor for routing weights.
        weight (torch.nn.Parameter): Learnable weights for the gate.
    """
    def __init__(self, args: ModelArgs):
        """
        Initializes the Categorical gate module.

        Args:
            args (ModelArgs): Model arguments containing gating parameters.
        """
        super().__init__()
        self.dim = args.dim
        self.topk = args.n_activated_experts
        self.score_func = args.score_func
        self.route_scale = args.route_scale
        self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass for the gating mechanism.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                Routing weights and sampled expert indices.
        """
        scores = linear(x, self.weight)
        if self.score_func == "softmax":
            scores = scores.softmax(dim=-1, dtype=torch.float32)
        else:
            # Normalize the sigmoid outputs, not the raw logits, so each
            # token's scores form a valid sampling distribution.
            scores = scores.sigmoid()
            scores = scores / scores.sum(dim=-1, keepdim=True)
        # Sample k distinct experts per token rather than taking the top-k.
        # https://medium.com/@rjnclarke/6b602bae0f2c
        indices = torch.multinomial(
            scores, num_samples=self.topk, replacement=False)
        weights = scores.gather(1, indices)
        if self.topk > 1:
            # Renormalize so the sampled experts' weights sum to 1 per token.
            weights = weights / weights.sum(dim=-1, keepdim=True)
        return (weights * self.route_scale).type_as(x), indices

class Expert(nn.Module):
    """
    Expert layer for Mixture-of-Experts (MoE) models.
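The behavioral difference from the deterministic Gate is easiest to see in isolation: topk always returns the same experts for a given distribution, while multinomial sampling without replacement occasionally routes to low-probability experts. A self-contained sketch with toy numbers:

import torch

torch.manual_seed(0)
probs = torch.tensor([[0.05, 0.50, 0.30, 0.15]])  # one token, four experts
topk = 2

# Deterministic routing: always experts 1 and 2 for this distribution.
det_idx = probs.topk(topk, dim=-1).indices

# Stochastic routing: two distinct experts drawn per token, so experts
# 0 and 3 are still chosen some of the time.
sto_idx = torch.multinomial(probs, num_samples=topk, replacement=False)

# Renormalize the sampled weights per token, as Categorical.forward does
# when topk > 1.
weights = probs.gather(1, sto_idx)
weights = weights / weights.sum(dim=-1, keepdim=True)
print(det_idx, sto_idx, weights)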
@@ -658,7 +707,7 @@ class MoE(nn.Module):
        self.n_activated_experts = args.n_activated_experts
        self.experts_start_idx = rank * self.n_local_experts
        self.experts_end_idx = self.experts_start_idx + self.n_local_experts
        self.gate = Gate(args) if moe_impl == "mixture" else Categorical(args)
        self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim) if self.experts_start_idx <= i < self.experts_end_idx else None
                                      for i in range(self.n_routed_experts)])
        self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim)
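MoE.forward (outside this hunk) dispatches per expert with torch.where(indices == i) and indexes the routing weights as weights[idx, top], so both gate implementations must return weights of shape (tokens, topk) aligned with indices; this is why Categorical gathers before returning. A standalone sketch of that contract, with toy sizes:

import torch

tokens, n_experts, topk = 8, 16, 2
scores = torch.rand(tokens, n_experts).softmax(dim=-1)
indices = torch.multinomial(scores, num_samples=topk, replacement=False)
weights = scores.gather(1, indices)
weights = weights / weights.sum(dim=-1, keepdim=True)
assert weights.shape == (tokens, topk) and indices.shape == (tokens, topk)

# Per-expert dispatch: which tokens sampled expert i, and in which slot.
i = int(indices[0, 0])
idx, top = torch.where(indices == i)
assert weights[idx, top].shape == idx.shape  # one weight per routed token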