Mirror of https://github.com/deepseek-ai/DeepSeek-V3.git (synced 2025-05-21 17:56:45 -04:00)

Merge 20c8b4f6b1 into 4cc6253d5c
This commit is contained in: commit cbd58be7b7
@@ -15,6 +15,7 @@ rank = 0
 block_size = 128
 gemm_impl: Literal["bf16", "fp8"] = "bf16"
 attn_impl: Literal["naive", "absorb"] = "absorb"
+moe_impl: Literal["mixture", "distribution"] = "distribution"
 
 @dataclass
 class ModelArgs:
@@ -595,6 +596,54 @@ class Gate(nn.Module):
         return weights.type_as(x), indices
 
 
+class Categorical(nn.Module):
+    """
+    Stochastic gating mechanism for routing inputs in an MoE model.
+
+    Attributes:
+        dim (int): Dimensionality of input features.
+        topk (int): Number of experts activated per input token.
+        score_func (str): Scoring function ('softmax' or 'sigmoid').
+        route_scale (float): Scaling factor for routing weights.
+        weight (torch.nn.Parameter): Learnable weights for the gate.
+    """
+    def __init__(self, args: ModelArgs):
+        """
+        Initializes the Categorical gating module.
+
+        Args:
+            args (ModelArgs): Model arguments containing gating parameters.
+        """
+        super().__init__()
+        self.dim = args.dim
+        self.topk = args.n_activated_experts
+        self.score_func = args.score_func
+        self.route_scale = args.route_scale
+        self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
+
+    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Forward pass for the gating mechanism.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]:
+                Routing weights and selected expert indices.
+        """
+        scores = linear(x, self.weight)
+        if self.score_func == "softmax":
+            scores = scores.softmax(dim=-1, dtype=torch.float32)
+        else:
+            # Normalize the sigmoid outputs (not the raw logits) so that
+            # torch.multinomial receives a valid probability distribution.
+            scores = scores.sigmoid()
+            scores = scores / scores.sum(dim=-1, keepdim=True)
+        # https://medium.com/@rjnclarke/6b602bae0f2c
+        indices = torch.multinomial(
+            scores, num_samples=self.topk, replacement=False)
+        # Gather the probabilities of the sampled experts and renormalize so
+        # that each token's routing weights sum to 1, matching Gate's output
+        # shape of (n_tokens, topk).
+        weights = scores.gather(1, indices)
+        if self.topk > 1:
+            weights = weights / weights.sum(dim=-1, keepdim=True)
+        return (weights * self.route_scale).type_as(x), indices
+
+
 class Expert(nn.Module):
     """
     Expert layer for Mixture-of-Experts (MoE) models.
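Reviewer note, not part of the commit: the new gate replaces deterministic top-k selection with sampling without replacement from the score distribution. A minimal standalone sketch contrasting the two strategies on illustrative toy scores (plain torch calls only, no model code assumed):

import torch

torch.manual_seed(0)
# Toy routing scores: 2 tokens x 4 experts, each row a probability distribution.
scores = torch.tensor([[0.1, 0.6, 0.2, 0.1],
                       [0.4, 0.1, 0.4, 0.1]])
topk = 2

# Deterministic routing (Gate): always picks the k highest-scoring experts.
det_weights, det_indices = torch.topk(scores, topk, dim=-1)

# Stochastic routing (Categorical): samples k distinct experts in proportion
# to their scores, so lower-scoring experts are occasionally chosen as well.
sto_indices = torch.multinomial(scores, num_samples=topk, replacement=False)
sto_weights = scores.gather(1, sto_indices)
sto_weights = sto_weights / sto_weights.sum(dim=-1, keepdim=True)

print(det_indices)   # tensor([[1, 2], [0, 2]]) -- fixed for the given scores
print(sto_indices)   # varies run to run, biased toward high-scoring experts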
@@ -658,7 +707,7 @@ class MoE(nn.Module):
         self.n_activated_experts = args.n_activated_experts
         self.experts_start_idx = rank * self.n_local_experts
         self.experts_end_idx = self.experts_start_idx + self.n_local_experts
-        self.gate = Gate(args)
+        self.gate = Gate(args) if moe_impl == "mixture" else Categorical(args)
         self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim) if self.experts_start_idx <= i < self.experts_end_idx else None
                                       for i in range(self.n_routed_experts)])
         self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim)
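Usage note: moe_impl is a module-level flag (alongside gemm_impl and attn_impl) read when MoE is constructed, so the router can be switched before building the model. A hypothetical sketch, assuming the diff above applies to inference/model.py:

import model

model.moe_impl = "mixture"        # original deterministic top-k Gate
args = model.ModelArgs()
deterministic_moe = model.MoE(args)

model.moe_impl = "distribution"   # new stochastic Categorical gate
stochastic_moe = model.MoE(args)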