diff --git a/.gitignore b/.gitignore index 9cc511e..68f1d27 100644 --- a/.gitignore +++ b/.gitignore @@ -165,4 +165,8 @@ cython_debug/ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ \ No newline at end of file +#.idea/ + +.vscode/* + +.DS_Store \ No newline at end of file diff --git a/inference/convert.py b/inference/convert.py index f6fb5e2..c606ce8 100644 --- a/inference/convert.py +++ b/inference/convert.py @@ -31,6 +31,18 @@ mapping = { def main(hf_ckpt_path, save_path, n_experts, mp): + """ + Converts and saves model checkpoint files into a specified format. + + Args: + hf_ckpt_path (str): Path to the directory containing the input checkpoint files. + save_path (str): Path to the directory where the converted checkpoint files will be saved. + n_experts (int): Total number of experts in the model. + mp (int): Model parallelism factor. + + Returns: + None + """ torch.set_num_threads(8) n_local_experts = n_experts // mp state_dicts = [{} for _ in range(mp)] diff --git a/inference/fp8_cast_bf16.py b/inference/fp8_cast_bf16.py index 1b9735a..4037342 100644 --- a/inference/fp8_cast_bf16.py +++ b/inference/fp8_cast_bf16.py @@ -10,6 +10,25 @@ from safetensors.torch import load_file, save_file from kernel import weight_dequant def main(fp8_path, bf16_path): + """ + Converts FP8 weights to BF16 and saves the converted weights. + + This function reads FP8 weights from the specified directory, converts them to BF16, + and saves the converted weights to another specified directory. It also updates the + model index file to reflect the changes. + + Args: + fp8_path (str): The path to the directory containing the FP8 weights and model index file. + bf16_path (str): The path to the directory where the converted BF16 weights will be saved. + + Raises: + KeyError: If a required scale_inv tensor is missing for a weight. + + Notes: + - The function assumes that the FP8 weights are stored in safetensor files. + - The function caches loaded safetensor files to optimize memory usage. + - The function updates the model index file to remove references to scale_inv tensors. + """ torch.set_default_dtype(torch.bfloat16) os.makedirs(bf16_path, exist_ok=True) model_index_file = os.path.join(fp8_path, "model.safetensors.index.json") @@ -23,6 +42,18 @@ def main(fp8_path, bf16_path): # Helper function to get tensor from the correct file def get_tensor(tensor_name): + """ + Retrieves a tensor from the cached safetensor files or loads it from disk if not cached. + + Args: + tensor_name (str): The name of the tensor to retrieve. + + Returns: + torch.Tensor: The retrieved tensor. + + Raises: + KeyError: If the tensor does not exist in the safetensor file. + """ file_name = weight_map[tensor_name] if file_name not in loaded_files: file_path = os.path.join(fp8_path, file_name) diff --git a/inference/generate.py b/inference/generate.py index a08c7bd..fbf3ab8 100644 --- a/inference/generate.py +++ b/inference/generate.py @@ -12,6 +12,16 @@ from model import Transformer, ModelArgs def sample(logits, temperature: float = 1.0): + """ + Samples a token from the logits using temperature scaling. + + Args: + logits (torch.Tensor): The logits tensor for token predictions. + temperature (float, optional): Temperature for scaling logits. Defaults to 1.0. + + Returns: + torch.Tensor: The sampled token. + """ logits = logits / max(temperature, 1e-5) probs = torch.softmax(logits, dim=-1) return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1) @@ -25,6 +35,19 @@ def generate( eos_id: int, temperature: float = 1.0 ) -> List[List[int]]: + """ + Generates new tokens based on the given prompt tokens using the specified model. + + Args: + model (Transformer): The transformer model used for token generation. + prompt_tokens (List[List[int]]): A list of lists containing the prompt tokens for each sequence. + max_new_tokens (int): The maximum number of new tokens to generate. + eos_id (int): The end-of-sequence token ID. + temperature (float, optional): The temperature value for sampling. Defaults to 1.0. + + Returns: + List[List[int]]: A list of lists containing the generated tokens for each sequence. + """ prompt_lens = [len(t) for t in prompt_tokens] assert max(prompt_lens) <= model.max_seq_len total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens)) @@ -63,6 +86,17 @@ def main( max_new_tokens: int = 100, temperature: float = 1.0, ) -> None: + """ + Main function to load the model and perform interactive or batch text generation. + + Args: + ckpt_path (str): Path to the model checkpoint directory. + config (str): Path to the model configuration file. + input_file (str, optional): Path to a file containing input prompts. Defaults to "". + interactive (bool, optional): Whether to run in interactive mode. Defaults to True. + max_new_tokens (int, optional): Maximum number of new tokens to generate. Defaults to 100. + temperature (float, optional): Temperature for sampling. Defaults to 1.0. + """ world_size = int(os.getenv("WORLD_SIZE", "1")) rank = int(os.getenv("RANK", "0")) local_rank = int(os.getenv("LOCAL_RANK", "0")) @@ -125,6 +159,20 @@ def main( if __name__ == "__main__": + """ + Command-line interface for distributed text generation. + + Arguments: + --ckpt-path (str): Path to the model checkpoint directory. + --config (str): Path to the model configuration file. + --input-file (str, optional): File containing prompts for batch processing. + --interactive (bool, optional): Enable interactive mode for generating text. + --max-new-tokens (int, optional): Maximum number of new tokens to generate. Defaults to 200. + --temperature (float, optional): Temperature for sampling. Defaults to 0.2. + + Raises: + AssertionError: If neither input-file nor interactive mode is specified. + """ parser = ArgumentParser() parser.add_argument("--ckpt-path", type=str, required=True) parser.add_argument("--config", type=str, required=True) diff --git a/inference/kernel.py b/inference/kernel.py index 19e8f8f..dec8639 100644 --- a/inference/kernel.py +++ b/inference/kernel.py @@ -8,6 +8,18 @@ from triton import Config @triton.jit def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr): + """ + Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`. + + Args: + x_ptr (triton.Pointer): Pointer to the input tensor. + y_ptr (triton.Pointer): Pointer to the output tensor where quantized values will be stored. + s_ptr (triton.Pointer): Pointer to the output tensor where scaling factors will be stored. + BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each program instance. + + Returns: + None + """ pid = tl.program_id(axis=0) offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) x = tl.load(x_ptr + offs).to(tl.float32) @@ -19,6 +31,18 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr): def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantizes the input tensor `x` using block-wise quantization. + + Args: + x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`. + block_size (int, optional): The size of the blocks to be used for quantization. Default is 128. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - The quantized tensor with dtype `torch.float8_e4m3fn`. + - A tensor of scaling factors with dtype `torch.float32`. + """ assert x.is_contiguous() assert x.size(-1) % block_size == 0 y = torch.empty_like(x, dtype=torch.float8_e4m3fn) @@ -30,6 +54,20 @@ def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, tor @triton.jit def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr): + """ + Dequantizes weights using the provided scaling factors and stores the result. + + Args: + x_ptr (tl.pointer): Pointer to the quantized weights. + s_ptr (tl.pointer): Pointer to the scaling factors. + y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights. + M (int): Number of rows in the weight matrix. + N (int): Number of columns in the weight matrix. + BLOCK_SIZE (tl.constexpr): Size of the block for tiling. + + Returns: + None + """ pid_m = tl.program_id(axis=0) pid_n = tl.program_id(axis=1) n = tl.cdiv(N, BLOCK_SIZE) @@ -44,6 +82,20 @@ def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr): def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor: + """ + Dequantizes the given weight tensor using the provided scale tensor. + + Args: + x (torch.Tensor): The quantized weight tensor of shape (M, N). + s (torch.Tensor): The scale tensor of shape (M, N). + block_size (int, optional): The block size to use for dequantization. Defaults to 128. + + Returns: + torch.Tensor: The dequantized weight tensor of the same shape as `x`. + + Raises: + AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2. + """ assert x.is_contiguous() and s.is_contiguous() assert x.dim() == 2 and s.dim() == 2 M, N = x.size() @@ -66,6 +118,25 @@ def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr): + """ + Performs a matrix multiplication operation on FP8 matrices with scaling factors. + + Args: + a_ptr (tl.tensor): Pointer to the first input matrix A. + b_ptr (tl.tensor): Pointer to the second input matrix B. + c_ptr (tl.tensor): Pointer to the output matrix C. + a_s_ptr (tl.tensor): Pointer to the scaling factors for matrix A. + b_s_ptr (tl.tensor): Pointer to the scaling factors for matrix B. + M (int): Number of rows in matrix A and C. + N (tl.constexpr): Number of columns in matrix B and C. + K (tl.constexpr): Number of columns in matrix A and rows in matrix B. + BLOCK_SIZE_M (tl.constexpr): Block size for the M dimension. + BLOCK_SIZE_N (tl.constexpr): Block size for the N dimension. + BLOCK_SIZE_K (tl.constexpr): Block size for the K dimension. + + Returns: + None + """ pid_m = tl.program_id(axis=0) pid_n = tl.program_id(axis=1) k = tl.cdiv(K, BLOCK_SIZE_K) @@ -97,6 +168,18 @@ def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr, def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor): + """ + Perform a matrix multiplication using FP8 precision. + + Args: + a (torch.Tensor): The first input matrix, must be contiguous. + a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous. + b (torch.Tensor): The second input matrix, must be contiguous. + b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous. + + Returns: + torch.Tensor: The result of the matrix multiplication. + """ assert a.is_contiguous() and b.is_contiguous() assert a_s.is_contiguous() and b_s.is_contiguous() K = a.size(-1) diff --git a/inference/model.py b/inference/model.py index 11004a0..9ea60c9 100644 --- a/inference/model.py +++ b/inference/model.py @@ -18,6 +18,39 @@ attn_impl: Literal["naive", "absorb"] = "absorb" @dataclass class ModelArgs: + """ + Data class for defining model arguments and hyperparameters. + + Attributes: + max_batch_size (int): Maximum batch size. + max_seq_len (int): Maximum sequence length. + dtype (Literal["bf16", "fp8"]): Data type for computations. + vocab_size (int): Vocabulary size. + dim (int): Model dimension. + inter_dim (int): Intermediate dimension for MLP layers. + moe_inter_dim (int): Intermediate dimension for MoE layers. + n_layers (int): Number of transformer layers. + n_dense_layers (int): Number of dense layers in the model. + n_heads (int): Number of attention heads. + n_routed_experts (int): Number of routed experts for MoE layers. + n_shared_experts (int): Number of shared experts for MoE layers. + n_activated_experts (int): Number of activated experts in MoE layers. + n_expert_groups (int): Number of expert groups. + n_limited_groups (int): Number of limited groups for MoE routing. + score_func (Literal["softmax", "sigmoid"]): Scoring function for MoE routing. + route_scale (float): Scaling factor for routing scores. + q_lora_rank (int): LoRA rank for query projections. + kv_lora_rank (int): LoRA rank for key-value projections. + qk_nope_head_dim (int): Dimension for query-key projections without positional embeddings. + qk_rope_head_dim (int): Dimension for query-key projections with rotary embeddings. + v_head_dim (int): Dimension for value projections. + original_seq_len (int): Original sequence length. + rope_theta (float): Base for rotary positional encoding. + rope_factor (float): Scaling factor for extended sequence lengths. + beta_fast (int): Fast beta correction factor. + beta_slow (int): Slow beta correction factor. + mscale (float): Scaling factor for extended attention. + """ max_batch_size: int = 8 max_seq_len: int = 4096 * 4 dtype: Literal["bf16", "fp8"] = "bf16" @@ -52,6 +85,13 @@ class ModelArgs: class ParallelEmbedding(nn.Module): + """ + Embedding layer with parallelism support across distributed processes. + + Args: + vocab_size (int): Vocabulary size. + dim (int): Embedding dimension. + """ def __init__(self, vocab_size: int, dim: int): super().__init__() self.vocab_size = vocab_size @@ -63,6 +103,18 @@ class ParallelEmbedding(nn.Module): self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim)) def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass for parallel embedding layer. + + Args: + x (torch.Tensor): Input tensor containing token indices. + + Returns: + torch.Tensor: Embedded representations. + + Raises: + ValueError: If `world_size` is not defined. + """ if world_size > 1: mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx) x = x - self.vocab_start_idx @@ -75,6 +127,27 @@ class ParallelEmbedding(nn.Module): def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Applies a linear transformation to the incoming data: y = xA^T + b. + This function supports specialized implementations based on quantization + and tensor formats. + + Args: + x (torch.Tensor): The input tensor. + weight (torch.Tensor): The weight tensor. It may be quantized and + requires dequantization for certain cases. + bias (Optional[torch.Tensor]): The bias tensor to be added. Default is None. + + Returns: + torch.Tensor: The result of the linear transformation, which may involve + quantization-aware computations depending on the input parameters. + + Notes: + - If `weight` is quantized (e.g., `element_size() > 1`), a dequantized version + is used for computation. + - If `gemm_impl == "bf16"`, dequantization and a `bf16` GEMM operation are applied. + - For other cases, the function applies quantization to `x` and uses `fp8_gemm` for computation. + """ if weight.element_size() > 1: return F.linear(x, weight, bias) elif gemm_impl == "bf16": @@ -89,6 +162,15 @@ def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = class Linear(nn.Module): + """ + Custom linear layer with support for quantized weights and optional bias. + + Args: + in_features (int): Number of input features. + out_features (int): Number of output features. + bias (bool): Whether to include a bias term. Defaults to False. + dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`. + """ dtype = torch.bfloat16 def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None): @@ -108,27 +190,72 @@ class Linear(nn.Module): self.register_parameter("bias", None) def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass for the custom linear layer. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Transformed tensor after linear computation. + """ return linear(x, self.weight, self.bias) class ColumnParallelLinear(Linear): + """ + Linear layer with column parallelism, splitting output features across distributed processes. + + Args: + in_features (int): Number of input features. + out_features (int): Total number of output features. + bias (bool): Whether to include a bias term. Defaults to False. + dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`. + """ def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None): assert out_features % world_size == 0 self.part_out_features = out_features // world_size super().__init__(in_features, self.part_out_features, bias, dtype) def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass for column parallel linear layer. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Transformed tensor with column-parallel computation. + """ y = linear(x, self.weight, self.bias) return y class RowParallelLinear(Linear): + """ + Linear layer with row parallelism, splitting input features across distributed processes. + + Args: + in_features (int): Total number of input features. + out_features (int): Number of output features. + bias (bool): Whether to include a bias term. Defaults to False. + dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`. + """ def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None): assert in_features % world_size == 0 self.part_in_features = in_features // world_size super().__init__(self.part_in_features, out_features, bias, dtype) def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass for row parallel linear layer. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Transformed tensor with row-parallel computation. + """ y = linear(x, self.weight) if world_size > 1: dist.all_reduce(y) @@ -138,6 +265,13 @@ class RowParallelLinear(Linear): class RMSNorm(nn.Module): + """ + Root Mean Square Layer Normalization (RMSNorm). + + Args: + dim (int): Dimension of the input tensor. + eps (float): Epsilon value for numerical stability. Defaults to 1e-6. + """ def __init__(self, dim: int, eps: float = 1e-6): super().__init__() self.dim = dim @@ -145,10 +279,28 @@ class RMSNorm(nn.Module): self.weight = nn.Parameter(torch.ones(dim)) def forward(self, x: torch.Tensor): + """ + Forward pass for RMSNorm. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Normalized tensor with the same shape as input. + """ return F.rms_norm(x, (self.dim,), self.weight, self.eps) def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor: + """ + Precomputes frequency-based complex exponential values for rotary positional embeddings. + + Args: + args (ModelArgs): Model arguments containing positional embedding parameters. + + Returns: + torch.Tensor: Precomputed complex exponential values for positional embeddings. + """ dim = args.qk_rope_head_dim seqlen = args.max_seq_len beta_fast = args.beta_fast @@ -157,14 +309,51 @@ def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor: factor = args.rope_factor def find_correction_dim(num_rotations, dim, base, max_seq_len): + """ + Computes the correction dimension for a given number of rotations in the rotary positional embedding. + + Args: + num_rotations (float): Number of rotations to compute the correction for. + dim (int): Dimensionality of the embedding space. + base (float): Base value for the exponential computation. + max_seq_len (int): Maximum sequence length. + + Returns: + float: The correction dimension based on the input parameters. + """ return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base)) def find_correction_range(low_rot, high_rot, dim, base, max_seq_len): + """ + Computes the range of correction dimensions for rotary positional embeddings. + + Args: + low_rot (float): Lower bound for the number of rotations. + high_rot (float): Upper bound for the number of rotations. + dim (int): Dimensionality of the embedding space. + base (float): Base value for the exponential computation. + max_seq_len (int): Maximum sequence length. + + Returns: + Tuple[int, int]: The range of correction dimensions (low, high), clamped to valid indices. + """ low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len)) high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len)) return max(low, 0), min(high, dim-1) def linear_ramp_factor(min, max, dim): + """ + Computes a linear ramp function used to smooth values between a minimum and maximum range. + + Args: + min (float): Minimum value for the ramp function. + max (float): Maximum value for the ramp function. + dim (int): Dimensionality of the ramp tensor. + + Returns: + torch.Tensor: A tensor of shape (dim,) with values linearly interpolated between 0 and 1, + clamped to the range [0, 1]. + """ if min == max: max += 0.001 linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) @@ -184,6 +373,16 @@ def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor: def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + """ + Applies rotary positional embeddings to the input tensor. + + Args: + x (torch.Tensor): Input tensor with positional embeddings to be applied. + freqs_cis (torch.Tensor): Precomputed complex exponential values for positional embeddings. + + Returns: + torch.Tensor: Tensor with rotary embeddings applied. + """ dtype = x.dtype x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2)) freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1)) @@ -192,6 +391,21 @@ def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: class MLA(nn.Module): + """ + Multi-Headed Attention Layer (MLA). + + Attributes: + dim (int): Dimensionality of the input features. + n_heads (int): Number of attention heads. + n_local_heads (int): Number of local attention heads for distributed systems. + q_lora_rank (int): Rank for low-rank query projection. + kv_lora_rank (int): Rank for low-rank key/value projection. + qk_nope_head_dim (int): Dimensionality of non-positional query/key projections. + qk_rope_head_dim (int): Dimensionality of rotary-positional query/key projections. + qk_head_dim (int): Total dimensionality of query/key projections. + v_head_dim (int): Dimensionality of value projections. + softmax_scale (float): Scaling factor for softmax in attention computation. + """ def __init__(self, args: ModelArgs): super().__init__() self.dim = args.dim @@ -227,6 +441,18 @@ class MLA(nn.Module): self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False) def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]): + """ + Forward pass for the Multi-Headed Attention Layer (MLA). + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim). + start_pos (int): Starting position in the sequence for caching. + freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings. + mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions from attention. + + Returns: + torch.Tensor: Output tensor with the same shape as the input. + """ bsz, seqlen, _ = x.size() end_pos = start_pos + seqlen if self.q_lora_rank == 0: @@ -269,18 +495,61 @@ class MLA(nn.Module): class MLP(nn.Module): + """ + Multi-Layer Perceptron (MLP) used as a feed-forward layer. + + Attributes: + w1 (nn.Module): Linear layer for input-to-hidden transformation. + w2 (nn.Module): Linear layer for hidden-to-output transformation. + w3 (nn.Module): Additional linear layer for feature transformation. + """ def __init__(self, dim: int, inter_dim: int): + """ + Initializes the MLP layer. + + Args: + dim (int): Input and output dimensionality. + inter_dim (int): Hidden layer dimensionality. + """ super().__init__() self.w1 = ColumnParallelLinear(dim, inter_dim) self.w2 = RowParallelLinear(inter_dim, dim) self.w3 = ColumnParallelLinear(dim, inter_dim) def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass for the MLP layer. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor after MLP computation. + """ return self.w2(F.silu(self.w1(x)) * self.w3(x)) class Gate(nn.Module): + """ + Gating mechanism for routing inputs in a mixture-of-experts (MoE) model. + + Attributes: + dim (int): Dimensionality of input features. + topk (int): Number of top experts activated for each input. + n_groups (int): Number of groups for routing. + topk_groups (int): Number of groups to route inputs to. + score_func (str): Scoring function ('softmax' or 'sigmoid'). + route_scale (float): Scaling factor for routing weights. + weight (torch.nn.Parameter): Learnable weights for the gate. + bias (Optional[torch.nn.Parameter]): Optional bias term for the gate. + """ def __init__(self, args: ModelArgs): + """ + Initializes the Gate module. + + Args: + args (ModelArgs): Model arguments containing gating parameters. + """ super().__init__() self.dim = args.dim self.topk = args.n_activated_experts @@ -292,6 +561,15 @@ class Gate(nn.Module): self.bias = nn.Parameter(torch.empty(args.n_routed_experts)) if self.dim == 7168 else None def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass for the gating mechanism. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Routing weights and selected expert indices. + """ scores = linear(x, self.weight) if self.score_func == "softmax": scores = scores.softmax(dim=-1, dtype=torch.float32) @@ -318,18 +596,60 @@ class Gate(nn.Module): class Expert(nn.Module): + """ + Expert layer for Mixture-of-Experts (MoE) models. + + Attributes: + w1 (nn.Module): Linear layer for input-to-hidden transformation. + w2 (nn.Module): Linear layer for hidden-to-output transformation. + w3 (nn.Module): Additional linear layer for feature transformation. + """ def __init__(self, dim: int, inter_dim: int): + """ + Initializes the Expert layer. + + Args: + dim (int): Input and output dimensionality. + inter_dim (int): Hidden layer dimensionality. + """ super().__init__() self.w1 = Linear(dim, inter_dim) self.w2 = Linear(inter_dim, dim) self.w3 = Linear(dim, inter_dim) def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass for the Expert layer. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor after expert computation. + """ return self.w2(F.silu(self.w1(x)) * self.w3(x)) class MoE(nn.Module): + """ + Mixture-of-Experts (MoE) module. + + Attributes: + dim (int): Dimensionality of input features. + n_routed_experts (int): Total number of experts in the model. + n_local_experts (int): Number of experts handled locally in distributed systems. + n_activated_experts (int): Number of experts activated for each input. + gate (nn.Module): Gating mechanism to route inputs to experts. + experts (nn.ModuleList): List of expert modules. + shared_experts (nn.Module): Shared experts applied to all inputs. + """ def __init__(self, args: ModelArgs): + """ + Initializes the MoE module. + + Args: + args (ModelArgs): Model arguments containing MoE parameters. + """ super().__init__() self.dim = args.dim assert args.n_routed_experts % world_size == 0 @@ -344,6 +664,15 @@ class MoE(nn.Module): self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim) def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass for the MoE module. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor after expert routing and computation. + """ shape = x.size() x = x.view(-1, self.dim) weights, indices = self.gate(x) @@ -362,7 +691,23 @@ class MoE(nn.Module): class Block(nn.Module): + """ + Transformer block combining attention and feed-forward layers. + + Attributes: + attn (nn.Module): Attention layer (MLA). + ffn (nn.Module): Feed-forward network (MLP or MoE). + attn_norm (nn.Module): Layer normalization for attention. + ffn_norm (nn.Module): Layer normalization for feed-forward network. + """ def __init__(self, layer_id: int, args: ModelArgs): + """ + Initializes the Transformer block. + + Args: + layer_id (int): Layer index in the transformer. + args (ModelArgs): Model arguments containing block parameters. + """ super().__init__() self.attn = MLA(args) self.ffn = MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args) @@ -370,13 +715,42 @@ class Block(nn.Module): self.ffn_norm = RMSNorm(args.dim) def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor: + """ + Forward pass for the Transformer block. + + Args: + x (torch.Tensor): Input tensor. + start_pos (int): Starting position in the sequence. + freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings. + mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions from attention. + + Returns: + torch.Tensor: Output tensor after block computation. + """ x = x + self.attn(self.attn_norm(x), start_pos, freqs_cis, mask) x = x + self.ffn(self.ffn_norm(x)) return x class Transformer(nn.Module): + """ + Transformer model with positional embeddings, multiple layers, and output projection. + + Attributes: + max_seq_len (int): Maximum sequence length for the transformer. + embed (nn.Module): Embedding layer for input tokens. + layers (torch.nn.ModuleList): List of transformer blocks. + norm (nn.Module): Layer normalization applied after all blocks. + head (nn.Module): Output projection layer mapping to vocabulary size. + freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings. + """ def __init__(self, args: ModelArgs): + """ + Initializes the Transformer model. + + Args: + args (ModelArgs): Model arguments containing transformer parameters. + """ global world_size, rank world_size = dist.get_world_size() if dist.is_initialized() else 1 rank = dist.get_rank() if dist.is_initialized() else 0 @@ -393,6 +767,16 @@ class Transformer(nn.Module): @torch.inference_mode() def forward(self, tokens: torch.Tensor, start_pos: int = 0): + """ + Forward pass for the Transformer model. + + Args: + tokens (torch.Tensor): Input tensor of token IDs with shape (batch_size, seq_len). + start_pos (int, optional): Starting position in the sequence for rotary embeddings. Defaults to 0. + + Returns: + torch.Tensor: Logits tensor of shape (batch_size, vocab_size). + """ seqlen = tokens.size(1) h = self.embed(tokens) freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]