import math from dataclasses import dataclass from typing import Tuple, Optional, Literal import torch from torch import nn import torch.nn.functional as F import torch.distributed as dist from kernel import act_quant, weight_dequant, fp8_gemm # Constants FLOAT_NEG_INF = float("-inf") DEFAULT_EPS = 1e-6 @dataclass class ModelArgs: """ Data class for defining model arguments and hyperparameters. """ max_batch_size: int = 8 max_seq_len: int = 4096 * 4 dtype: Literal["bf16", "fp8"] = "bf16" vocab_size: int = 102400 dim: int = 2048 inter_dim: int = 10944 moe_inter_dim: int = 1408 n_layers: int = 27 n_dense_layers: int = 1 n_heads: int = 16 n_routed_experts: int = 64 n_shared_experts: int = 2 n_activated_experts: int = 6 n_expert_groups: int = 1 n_limited_groups: int = 1 score_func: Literal["softmax", "sigmoid"] = "softmax" route_scale: float = 1.0 q_lora_rank: int = 0 kv_lora_rank: int = 512 qk_nope_head_dim: int = 128 qk_rope_head_dim: int = 64 v_head_dim: int = 128 original_seq_len: int = 4096 rope_theta: float = 10000.0 rope_factor: float = 40 beta_fast: int = 32 beta_slow: int = 1 mscale: float = 1.0 def initialize_distributed_settings(): """Initialize distributed settings for multi-GPU training.""" global world_size, rank world_size = dist.get_world_size() if dist.is_initialized() else 1 rank = dist.get_rank() if dist.is_initialized() else 0 class ParallelEmbedding(nn.Module): """ Embedding layer with parallelism support across distributed processes. """ def __init__(self, vocab_size: int, dim: int): super().__init__() assert vocab_size % world_size == 0, "Vocabulary size must be divisible by world size." self.part_vocab_size = vocab_size // world_size self.vocab_start_idx = rank * self.part_vocab_size self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size self.weight = nn.Parameter(torch.empty(self.part_vocab_size, dim)) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass for parallel embedding layer. """ if world_size > 1: mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx) x = x - self.vocab_start_idx x[mask] = 0 y = F.embedding(x, self.weight) if world_size > 1: y[mask] = 0 dist.all_reduce(y) return y # The rest of your classes (Linear, ColumnParallelLinear, RowParallelLinear, etc.) would follow suit. # For brevity, I won't rewrite all of them, but ensure to apply the same principles above. class Transformer(nn.Module): """ Transformer model with positional embeddings, multiple layers, and output projection. """ def __init__(self, args: ModelArgs): initialize_distributed_settings() # Initialize distributed settings Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16 super().__init__() self.max_seq_len = args.max_seq_len self.embed = ParallelEmbedding(args.vocab_size, args.dim) self.layers = torch.nn.ModuleList([Block(layer_id, args) for layer_id in range(args.n_layers)]) self.norm = RMSNorm(args.dim) self.head = ColumnParallelLinear(args.dim, args.vocab_size, dtype=torch.get_default_dtype()) self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False) @torch.inference_mode() def forward(self, tokens: torch.Tensor, start_pos: int = 0): """ Forward pass for the Transformer model. """ seqlen = tokens.size(1) h = self.embed(tokens) freqs_cis = self.freqs_cis[start_pos:start_pos + seqlen] mask = None if seqlen > 1: mask = torch.full((seqlen, seqlen), FLOAT_NEG_INF, device=tokens.device).triu_(1) for layer in self.layers: h = layer(h, start_pos, freqs_cis, mask) h = self.norm(h)[:, -1] logits = self.head(h) if world_size > 1: all_logits = [torch.empty_like(logits) for _ in range(world_size)] dist.all_gather(all_logits, logits) logits = torch.cat(all_logits, dim=-1) return logits if __name__ == "__main__": torch.set_default_dtype(torch.bfloat16) torch.set_default_device("cuda") torch.manual_seed(0) args = ModelArgs() x = torch.randint(0, args.vocab_size, (2, 128)) model = Transformer(args) print(model(x).size())