diff --git a/inference/model.py b/inference/model.py index 9ea60c9..b2f0e7f 100644 --- a/inference/model.py +++ b/inference/model.py @@ -8,6 +8,19 @@ import torch.nn.functional as F import torch.distributed as dist from kernel import act_quant, weight_dequant, fp8_gemm +from flash_attn import flash_attn_func + +# Using float16 (FP16) or bfloat16 (BF16) speeds up computations without losing much accuracy. +from torch.cuda.amp import autocast, GradScaler +scaler = GradScaler() + +with autocast(): + logits = model(tokens) + loss = loss_fn(logits, targets) + +scaler.scale(loss).backward() +scaler.step(optimizer) +scaler.update() world_size = 1 @@ -95,12 +108,12 @@ class ParallelEmbedding(nn.Module): def __init__(self, vocab_size: int, dim: int): super().__init__() self.vocab_size = vocab_size - self.dim = dim + self.dim = dim assert vocab_size % world_size == 0 self.part_vocab_size = (vocab_size // world_size) self.vocab_start_idx = rank * self.part_vocab_size - self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size - self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim)) + self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size + self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim)) def forward(self, x: torch.Tensor) -> torch.Tensor: """ @@ -117,7 +130,7 @@ class ParallelEmbedding(nn.Module): """ if world_size > 1: mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx) - x = x - self.vocab_start_idx + x = x - self.vocab_start_idx x[mask] = 0 y = F.embedding(x, self.weight) if world_size > 1: @@ -175,13 +188,13 @@ class Linear(nn.Module): def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None): super().__init__() - self.in_features = in_features + self.in_features = in_features self.out_features = out_features - self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype or Linear.dtype)) + self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype or Linear.dtype)) if self.weight.element_size() == 1: scale_out_features = (out_features + block_size - 1) // block_size - scale_in_features = (in_features + block_size - 1) // block_size - self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float32)) + scale_in_features = (in_features + block_size - 1) // block_size + self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float32)) else: self.register_parameter("scale", None) if bias: @@ -265,174 +278,90 @@ class RowParallelLinear(Linear): class RMSNorm(nn.Module): - """ - Root Mean Square Layer Normalization (RMSNorm). + """Root Mean Square Layer Normalization (RMSNorm).""" - Args: - dim (int): Dimension of the input tensor. - eps (float): Epsilon value for numerical stability. Defaults to 1e-6. - """ def __init__(self, dim: int, eps: float = 1e-6): super().__init__() - self.dim = dim - self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) + self.eps = eps def forward(self, x: torch.Tensor): - """ - Forward pass for RMSNorm. - - Args: - x (torch.Tensor): Input tensor. - - Returns: - torch.Tensor: Normalized tensor with the same shape as input. - """ - return F.rms_norm(x, (self.dim,), self.weight, self.eps) + return F.rms_norm(x, (x.shape[-1],), self.weight, self.eps) -def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor: - """ - Precomputes frequency-based complex exponential values for rotary positional embeddings. +def precompute_freqs_cis(args): + """Precomputes frequency-based complex exponential values for rotary embeddings.""" + dim, seqlen, base, factor = args.qk_rope_head_dim, args.max_seq_len, args.rope_theta, args.rope_factor + beta_fast, beta_slow, orig_seq_len = args.beta_fast, args.beta_slow, args.original_seq_len - Args: - args (ModelArgs): Model arguments containing positional embedding parameters. + log_base = 2 * math.log(base) + inv_dim = 1 / dim # Precompute inverse to avoid division in tensor ops - Returns: - torch.Tensor: Precomputed complex exponential values for positional embeddings. - """ - dim = args.qk_rope_head_dim - seqlen = args.max_seq_len - beta_fast = args.beta_fast - beta_slow = args.beta_slow - base = args.rope_theta - factor = args.rope_factor + def correction_dim(rotations): + return dim * (math.log(orig_seq_len / (rotations * 2 * math.pi)) / log_base) - def find_correction_dim(num_rotations, dim, base, max_seq_len): - """ - Computes the correction dimension for a given number of rotations in the rotary positional embedding. + def correction_range(low_rot, high_rot): + return torch.clamp(torch.round(torch.tensor([correction_dim(low_rot), correction_dim(high_rot)])), 0, dim - 1).int() - Args: - num_rotations (float): Number of rotations to compute the correction for. - dim (int): Dimensionality of the embedding space. - base (float): Base value for the exponential computation. - max_seq_len (int): Maximum sequence length. + def linear_ramp(min_v, max_v, size): + return torch.clamp((torch.arange(size, dtype=torch.float32) - min_v) / (max_v - min_v + 1e-3), 0, 1) - Returns: - float: The correction dimension based on the input parameters. - """ - return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base)) + freqs = (base ** (-torch.arange(0, dim, 2, dtype=torch.float32) * inv_dim)).reciprocal() - def find_correction_range(low_rot, high_rot, dim, base, max_seq_len): - """ - Computes the range of correction dimensions for rotary positional embeddings. + if seqlen > orig_seq_len: + low, high = correction_range(beta_fast, beta_slow) + smooth = 1 - linear_ramp(low, high, dim // 2) + freqs = freqs * (smooth + (1 - smooth) / factor) - Args: - low_rot (float): Lower bound for the number of rotations. - high_rot (float): Upper bound for the number of rotations. - dim (int): Dimensionality of the embedding space. - base (float): Base value for the exponential computation. - max_seq_len (int): Maximum sequence length. - - Returns: - Tuple[int, int]: The range of correction dimensions (low, high), clamped to valid indices. - """ - low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len)) - high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len)) - return max(low, 0), min(high, dim-1) - - def linear_ramp_factor(min, max, dim): - """ - Computes a linear ramp function used to smooth values between a minimum and maximum range. - - Args: - min (float): Minimum value for the ramp function. - max (float): Maximum value for the ramp function. - dim (int): Dimensionality of the ramp tensor. - - Returns: - torch.Tensor: A tensor of shape (dim,) with values linearly interpolated between 0 and 1, - clamped to the range [0, 1]. - """ - if min == max: - max += 0.001 - linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) - ramp_func = torch.clamp(linear_func, 0, 1) - return ramp_func - - freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - if seqlen > args.original_seq_len: - low, high = find_correction_range(beta_fast, beta_slow, dim, base, args.original_seq_len) - smooth = 1 - linear_ramp_factor(low, high, dim // 2) - freqs = freqs / factor * (1 - smooth) + freqs * smooth - - t = torch.arange(seqlen) - freqs = torch.outer(t, freqs) - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) - return freqs_cis + return torch.polar(torch.ones((seqlen, len(freqs))), torch.outer(torch.arange(seqlen), freqs)) def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: - """ - Applies rotary positional embeddings to the input tensor. - - Args: - x (torch.Tensor): Input tensor with positional embeddings to be applied. - freqs_cis (torch.Tensor): Precomputed complex exponential values for positional embeddings. - - Returns: - torch.Tensor: Tensor with rotary embeddings applied. - """ + """Applies rotary positional embeddings to the input tensor.""" dtype = x.dtype - x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2)) - freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1)) - y = torch.view_as_real(x * freqs_cis).flatten(3) - return y.to(dtype) + x = torch.view_as_complex(x.float().reshape_as(x[..., :-1:2])) # Reshape for complex conversion + return torch.view_as_real(x * freqs_cis.expand_as(x)).flatten(-2).to(dtype) class MLA(nn.Module): """ - Multi-Headed Attention Layer (MLA). - - Attributes: - dim (int): Dimensionality of the input features. - n_heads (int): Number of attention heads. - n_local_heads (int): Number of local attention heads for distributed systems. - q_lora_rank (int): Rank for low-rank query projection. - kv_lora_rank (int): Rank for low-rank key/value projection. - qk_nope_head_dim (int): Dimensionality of non-positional query/key projections. - qk_rope_head_dim (int): Dimensionality of rotary-positional query/key projections. - qk_head_dim (int): Total dimensionality of query/key projections. - v_head_dim (int): Dimensionality of value projections. - softmax_scale (float): Scaling factor for softmax in attention computation. + Multi-Headed Attention Layer (MLA) with optional low-rank LoRA projections. """ def __init__(self, args: ModelArgs): super().__init__() - self.dim = args.dim - self.n_heads = args.n_heads - self.n_local_heads = args.n_heads // world_size - self.q_lora_rank = args.q_lora_rank - self.kv_lora_rank = args.kv_lora_rank + self.dim = args.dim + self.n_heads = args.n_heads + self.n_local_heads = args.n_heads // world_size + self.q_lora_rank = args.q_lora_rank + self.kv_lora_rank = args.kv_lora_rank self.qk_nope_head_dim = args.qk_nope_head_dim self.qk_rope_head_dim = args.qk_rope_head_dim - self.qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim - self.v_head_dim = args.v_head_dim + self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim + self.v_head_dim = args.v_head_dim - if self.q_lora_rank == 0: - self.wq = ColumnParallelLinear(self.dim, self.n_heads * self.qk_head_dim) - else: - self.wq_a = Linear(self.dim, self.q_lora_rank) + # Query Projection + if self.q_lora_rank > 0: + self.wq_a = nn.Linear(self.dim, self.q_lora_rank) self.q_norm = RMSNorm(self.q_lora_rank) - self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.qk_head_dim) - self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim) + self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.qk_head_dim) + else: + self.wq = ColumnParallelLinear(self.dim, self.n_heads * self.qk_head_dim) + + # Key/Value Projection + self.wkv_a = nn.Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim) self.kv_norm = RMSNorm(self.kv_lora_rank) - self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim)) + self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim)) + + # Output Projection self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim) + + # Softmax Scaling self.softmax_scale = self.qk_head_dim ** -0.5 if args.max_seq_len > args.original_seq_len: mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0 - self.softmax_scale = self.softmax_scale * mscale * mscale + self.softmax_scale *= mscale ** 2 + # Caching Buffers if attn_impl == "naive": self.register_buffer("k_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.qk_head_dim), persistent=False) self.register_buffer("v_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.v_head_dim), persistent=False) @@ -442,57 +371,72 @@ class MLA(nn.Module): def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]): """ - Forward pass for the Multi-Headed Attention Layer (MLA). + Forward pass for the MLA. Args: - x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim). - start_pos (int): Starting position in the sequence for caching. - freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings. - mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions from attention. + x (torch.Tensor): Input tensor (batch_size, seq_len, dim). + start_pos (int): Starting position for caching. + freqs_cis (torch.Tensor): Precomputed rotary embeddings. + mask (Optional[torch.Tensor]): Attention mask. Returns: - torch.Tensor: Output tensor with the same shape as the input. + torch.Tensor: Output tensor (batch_size, seq_len, dim). """ bsz, seqlen, _ = x.size() - end_pos = start_pos + seqlen - if self.q_lora_rank == 0: - q = self.wq(x) - else: - q = self.wq_b(self.q_norm(self.wq_a(x))) - q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim) + end_pos = start_pos + seqlen + + # Compute Queries + q = self.wq(x) if self.q_lora_rank == 0 else self.wq_b(self.q_norm(self.wq_a(x))) + q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim) q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - q_pe = apply_rotary_emb(q_pe, freqs_cis) - kv = self.wkv_a(x) + q_pe = apply_rotary_emb(q_pe, freqs_cis) + + # Compute Keys and Values + kv = self.wkv_a(x) kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis) + k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis) + if attn_impl == "naive": - q = torch.cat([q_nope, q_pe], dim=-1) - kv = self.wkv_b(self.kv_norm(kv)) - kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim) + q = torch.cat([q_nope, q_pe], dim=-1) + kv = self.wkv_b(self.kv_norm(kv)).view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1) + k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1) + + # Cache Keys and Values self.k_cache[:bsz, start_pos:end_pos] = k self.v_cache[:bsz, start_pos:end_pos] = v + + # Compute Attention Scores scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale else: - wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size) + wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size) wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank) + q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim]) + + # Cache KV and PE self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv) self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2) - scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) + - torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale + + # Compute Attention Scores + scores = ( + torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) + + torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos]) + ) * self.softmax_scale + + # Apply Mask and Compute Softmax if mask is not None: scores += mask.unsqueeze(1) scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x) + + # Compute Final Output if attn_impl == "naive": x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos]) else: x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos]) x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:]) - x = self.wo(x.flatten(2)) - return x + return self.wo(x.flatten(2)) class MLP(nn.Module): """ @@ -543,6 +487,7 @@ class Gate(nn.Module): weight (torch.nn.Parameter): Learnable weights for the gate. bias (Optional[torch.nn.Parameter]): Optional bias term for the gate. """ + def __init__(self, args: ModelArgs): """ Initializes the Gate module. @@ -551,14 +496,15 @@ class Gate(nn.Module): args (ModelArgs): Model arguments containing gating parameters. """ super().__init__() - self.dim = args.dim - self.topk = args.n_activated_experts - self.n_groups = args.n_expert_groups + self.dim = args.dim + self.topk = args.n_activated_experts + self.n_groups = args.n_expert_groups self.topk_groups = args.n_limited_groups - self.score_func = args.score_func + self.score_func = args.score_func self.route_scale = args.route_scale + self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim)) - self.bias = nn.Parameter(torch.empty(args.n_routed_experts)) if self.dim == 7168 else None + self.bias = nn.Parameter(torch.empty(args.n_routed_experts)) if self.dim == 7168 else None def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ @@ -570,30 +516,92 @@ class Gate(nn.Module): Returns: Tuple[torch.Tensor, torch.Tensor]: Routing weights and selected expert indices. """ - scores = linear(x, self.weight) + scores = torch.matmul(x, self.weight.T) # More efficient than `linear(x, self.weight)` + if self.score_func == "softmax": scores = scores.softmax(dim=-1, dtype=torch.float32) else: scores = scores.sigmoid() - original_scores = scores + if self.bias is not None: - scores = scores + self.bias + scores += self.bias + if self.n_groups > 1: - scores = scores.view(x.size(0), self.n_groups, -1) - if self.bias is None: - group_scores = scores.amax(dim=-1) - else: - group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1) + scores = scores.view(x.size(0), self.n_groups, -1) + group_scores = scores.amax(dim=-1) if self.bias is None else scores.topk(2, dim=-1)[0].sum(dim=-1) + indices = group_scores.topk(self.topk_groups, dim=-1)[1] - mask = torch.zeros_like(scores[..., 0]).scatter_(1, indices, True) + mask = torch.zeros_like(scores[..., 0]).scatter_(1, indices, 1) + scores = (scores * mask.unsqueeze(-1)).flatten(1) - indices = torch.topk(scores, self.topk, dim=-1)[1] - weights = original_scores.gather(1, indices) + + indices = scores.topk(self.topk, dim=-1)[1] + weights = scores.gather(1, indices) + if self.score_func == "sigmoid": weights /= weights.sum(dim=-1, keepdim=True) - weights *= self.route_scale - return weights.type_as(x), indices + return (weights * self.route_scale).type_as(x), indices + +# Reduces weight matrix size from š‘‚(š‘›^2) to O(nš‘Ÿ) where š‘Ÿā‰Ŗš‘› +class LowRankLinear(nn.Module): + def __init__(self, in_features, out_features, rank): + super().__init__() + self.U = nn.Parameter(torch.randn(in_features, rank)) + self.V = nn.Parameter(torch.randn(rank, out_features)) + + def forward(self, x): + return x @ self.U @ self.V # Approximates full weight multiplication + +# Reduces memory and compute cost. +# Helps when handling long sequences in Transformers. +class LinformerSelfAttention(nn.Module): + def __init__(self, dim, seq_len, k=256): + super().__init__() + self.proj = nn.Linear(seq_len, k) # Low-rank projection + self.Wq = nn.Linear(dim, dim) + self.Wk = nn.Linear(dim, dim) + self.Wv = nn.Linear(dim, dim) + + def forward(self, x): + Q, K, V = self.Wq(x), self.Wk(x), self.Wv(x) + K_proj = self.proj(K) # Reduce dimension + attn = (Q @ K_proj.transpose(-2, -1)) / (K_proj.size(-1) ** 0.5) + return attn @ V + +# Sparse computation: Avoids unnecessary multiplications. +# Faster execution: Great for hardware acceleration. +class ButterflyLinear(nn.Module): + def __init__(self, in_features, out_features): + super().__init__() + self.W = nn.Parameter(torch.randn(out_features, in_features)) + self.mask = self.create_butterfly_mask(in_features, out_features) + + def create_butterfly_mask(self, in_dim, out_dim): + mask = torch.zeros(out_dim, in_dim) + stride = max(1, in_dim // out_dim) + for i in range(out_dim): + mask[i, i * stride: (i + 1) * stride] = 1 + return mask + + def forward(self, x): + return (x @ self.W) * self.mask + +# Avoids redundant memory reads/writes. +# Speeds up training on long sequences (4K+ tokens). +class FlashSelfAttention(nn.Module): + def __init__(self, dim, num_heads): + super().__init__() + self.num_heads = num_heads + self.Wqkv = nn.Linear(dim, 3 * dim) # Merge Q, K, V + self.out_proj = nn.Linear(dim, dim) + + def forward(self, x): + Q, K, V = self.Wqkv(x).chunk(3, dim=-1) + Q, K, V = [t.view(t.size(0), -1, self.num_heads, t.size(-1) // self.num_heads).transpose(1, 2) for t in [Q, K, V]] + attn_out = flash_attn_func(Q, K, V) # Efficient Attention Computation + attn_out = attn_out.transpose(1, 2).contiguous().view(x.size(0), -1, x.size(-1)) + return self.out_proj(attn_out) class Expert(nn.Module): """ @@ -653,13 +661,13 @@ class MoE(nn.Module): super().__init__() self.dim = args.dim assert args.n_routed_experts % world_size == 0 - self.n_routed_experts = args.n_routed_experts - self.n_local_experts = args.n_routed_experts // world_size + self.n_routed_experts = args.n_routed_experts + self.n_local_experts = args.n_routed_experts // world_size self.n_activated_experts = args.n_activated_experts - self.experts_start_idx = rank * self.n_local_experts - self.experts_end_idx = self.experts_start_idx + self.n_local_experts - self.gate = Gate(args) - self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim) if self.experts_start_idx <= i < self.experts_end_idx else None + self.experts_start_idx = rank * self.n_local_experts + self.experts_end_idx = self.experts_start_idx + self.n_local_experts + self.gate = Gate(args) + self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim) if self.experts_start_idx <= i < self.experts_end_idx else None for i in range(self.n_routed_experts)]) self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim) @@ -673,15 +681,16 @@ class MoE(nn.Module): Returns: torch.Tensor: Output tensor after expert routing and computation. """ - shape = x.size() - x = x.view(-1, self.dim) + shape = x.size() + x = x.view(-1, self.dim) weights, indices = self.gate(x) - y = torch.zeros_like(x) + y = torch.zeros_like(x) + counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist() for i in range(self.experts_start_idx, self.experts_end_idx): if counts[i] == 0: continue - expert = self.experts[i] + expert = self.experts[i] idx, top = torch.where(indices == i) y[idx] += expert(x[idx]) * weights[idx, top, None] z = self.shared_experts(x) @@ -709,10 +718,10 @@ class Block(nn.Module): args (ModelArgs): Model arguments containing block parameters. """ super().__init__() - self.attn = MLA(args) - self.ffn = MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args) + self.attn = MLA(args) + self.ffn = MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args) self.attn_norm = RMSNorm(args.dim) - self.ffn_norm = RMSNorm(args.dim) + self.ffn_norm = RMSNorm(args.dim) def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor: """ @@ -739,11 +748,12 @@ class Transformer(nn.Module): Attributes: max_seq_len (int): Maximum sequence length for the transformer. embed (nn.Module): Embedding layer for input tokens. - layers (torch.nn.ModuleList): List of transformer blocks. + layers (nn.ModuleList): List of transformer blocks. norm (nn.Module): Layer normalization applied after all blocks. head (nn.Module): Output projection layer mapping to vocabulary size. freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings. """ + def __init__(self, args: ModelArgs): """ Initializes the Transformer model. @@ -751,22 +761,27 @@ class Transformer(nn.Module): Args: args (ModelArgs): Model arguments containing transformer parameters. """ - global world_size, rank - world_size = dist.get_world_size() if dist.is_initialized() else 1 - rank = dist.get_rank() if dist.is_initialized() else 0 - Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16 super().__init__() self.max_seq_len = args.max_seq_len - self.embed = ParallelEmbedding(args.vocab_size, args.dim) - self.layers = torch.nn.ModuleList() - for layer_id in range(args.n_layers): - self.layers.append(Block(layer_id, args)) - self.norm = RMSNorm(args.dim) - self.head = ColumnParallelLinear(args.dim, args.vocab_size, dtype=torch.get_default_dtype()) + self.embed = ParallelEmbedding(args.vocab_size, args.dim) + self.layers = nn.ModuleList([Block(layer_id, args) for layer_id in range(args.n_layers)]) + self.norm = RMSNorm(args.dim) + self.head = ColumnParallelLinear(args.dim, args.vocab_size, dtype=torch.get_default_dtype()) self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False) + # Distributed setup + if dist.is_initialized(): + self.world_size = dist.get_world_size() + self.rank = dist.get_rank() + else: + self.world_size = 1 + self.rank = 0 + + # Set dtype for Linear layers based on model arguments + Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16 + @torch.inference_mode() - def forward(self, tokens: torch.Tensor, start_pos: int = 0): + def forward(self, tokens: torch.Tensor, start_pos: int = 0) -> torch.Tensor: """ Forward pass for the Transformer model. @@ -777,20 +792,24 @@ class Transformer(nn.Module): Returns: torch.Tensor: Logits tensor of shape (batch_size, vocab_size). """ - seqlen = tokens.size(1) - h = self.embed(tokens) - freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen] + seqlen = tokens.shape[1] + h = self.embed(tokens) + freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen] + mask = None if seqlen > 1: mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1) + for layer in self.layers: h = layer(h, start_pos, freqs_cis, mask) - h = self.norm(h)[:, -1] - logits = self.head(h) - if world_size > 1: - all_logits = [torch.empty_like(logits) for _ in range(world_size)] - dist.all_gather(all_logits, logits) - logits = torch.cat(all_logits, dim=-1) + + logits = self.head(self.norm(h)[:, -1]) + + if self.world_size > 1: + gathered_logits = [torch.empty_like(logits) for _ in range(self.world_size)] + dist.all_gather(gathered_logits, logits) + logits = torch.cat(gathered_logits, dim=-1) + return logits