mirror of https://github.com/deepseek-ai/DeepSeek-V3.git
synced 2025-05-22 02:06:45 -04:00
stochastic gating
This commit is contained in:
parent 4cc6253d5c
commit 20c8b4f6b1
@@ -15,6 +15,7 @@ rank = 0
 block_size = 128
 gemm_impl: Literal["bf16", "fp8"] = "bf16"
 attn_impl: Literal["naive", "absorb"] = "absorb"
+moe_impl: Literal["mixture", "distribution"] = "distribution"
 
 @dataclass
 class ModelArgs:
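
The new module-level moe_impl flag mirrors the existing gemm_impl and attn_impl switches above it and is consumed once, in MoE.__init__ (last hunk below), to pick the gate implementation. A minimal sketch of that dispatch; the make_gate helper is hypothetical and not part of the commit:

from typing import Literal

# "mixture" keeps the deterministic top-k Gate; "distribution" selects the
# sampling-based Categorical gate added in the next hunk.
moe_impl: Literal["mixture", "distribution"] = "distribution"

def make_gate(args):  # hypothetical helper; the commit inlines this in MoE.__init__
    return Gate(args) if moe_impl == "mixture" else Categorical(args)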
@@ -595,6 +596,62 @@ class Gate(nn.Module):
         return weights.type_as(x), indices
 
 
+class Categorical(nn.Module):
+    """
+    Stochastic Gating mechanism for routing inputs in an MoE model.
+
+    Attributes:
+        dim (int): Dimensionality of input features.
+        topk (int): Number of experts activated per token.
+        score_func (str): Scoring function ('softmax' or 'sigmoid').
+        route_scale (float): Scaling factor for routing weights.
+        weight (torch.nn.Parameter): Learnable weights for the gate.
+    """
+    def __init__(self, args: ModelArgs):
+        """
+        Initializes the Categorical gate module.
+
+        Args:
+            args (ModelArgs): Model arguments containing gating parameters.
+        """
+        super().__init__()
+        self.dim = args.dim
+        self.topk = args.n_activated_experts
+        self.score_func = args.score_func
+        self.route_scale = args.route_scale
+        self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
+
+    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Forward pass for the gating mechanism.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]:
+                Routing weights and selected expert indices.
+        """
+        scores = linear(x, self.weight)
+        if self.score_func == "softmax":
+            scores = scores.softmax(dim=-1, dtype=torch.float32)
+        else:
+            # Normalize the sigmoid outputs themselves; dividing by the sum of
+            # the raw logits could yield negative or invalid sampling weights.
+            scores = scores.sigmoid()
+            scores = scores / scores.sum(dim=-1, keepdim=True)
+        # https://medium.com/@rjnclarke/6b602bae0f2c
+        # Sample topk distinct experts per token from the routing distribution.
+        indices = torch.multinomial(
+            scores, num_samples=self.topk, replacement=False)
+        # Gather the sampled experts' weights so the output shape matches
+        # Gate's (n_tokens, topk), then renormalize them to sum to one.
+        weights = scores.gather(1, indices)
+        if self.topk > 1:
+            weights = weights / weights.sum(dim=-1, keepdim=True)
+        return (weights * self.route_scale).type_as(x), indices
+
+
 class Expert(nn.Module):
     """
     Expert layer for Mixture-of-Experts (MoE) models.
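
A self-contained sketch (not part of the commit) of the sampling step above, with illustrative shapes: each token draws topk distinct experts from its routing distribution via torch.multinomial, and the sampled probabilities are renormalized, so the same input can route differently across forward passes.

import torch

torch.manual_seed(0)
n_tokens, n_experts, topk = 4, 8, 2

logits = torch.randn(n_tokens, n_experts)
scores = logits.softmax(dim=-1, dtype=torch.float32)

# Unlike torch.topk, multinomial draws a fresh sample of topk distinct
# experts per token on every call.
indices = torch.multinomial(scores, num_samples=topk, replacement=False)

# Keep only the sampled experts' probabilities and renormalize, mirroring
# what Categorical.forward returns as routing weights.
weights = scores.gather(1, indices)
weights = weights / weights.sum(dim=-1, keepdim=True)

print(indices.shape, weights.shape)  # torch.Size([4, 2]) torch.Size([4, 2])
print(weights.sum(dim=-1))           # each row sums to 1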
@@ -658,7 +715,7 @@ class MoE(nn.Module):
         self.n_activated_experts = args.n_activated_experts
         self.experts_start_idx = rank * self.n_local_experts
         self.experts_end_idx = self.experts_start_idx + self.n_local_experts
-        self.gate = Gate(args)
+        self.gate = Gate(args) if moe_impl == "mixture" else Categorical(args)
         self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim) if self.experts_start_idx <= i < self.experts_end_idx else None
                                       for i in range(self.n_routed_experts)])
         self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim)
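
Finally, a hedged smoke test for the new path, assuming it runs inside model.py (which provides Categorical and the linear helper its forward calls); the stub args fields and values are illustrative, and the explicit init stands in for checkpoint loading since the gate's weight is created with torch.empty.

from types import SimpleNamespace
import torch

args = SimpleNamespace(dim=16, n_routed_experts=8, n_activated_experts=2,
                       score_func="softmax", route_scale=1.0)
gate = Categorical(args)
torch.nn.init.normal_(gate.weight, std=0.02)  # torch.empty leaves it uninitialized

x = torch.randn(4, args.dim)
w1, i1 = gate(x)
w2, i2 = gate(x)
# The "distribution" gate is stochastic: the same input can select
# different experts across calls, unlike the deterministic Gate.
print(i1, i2, torch.equal(i1, i2))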