commit 2a47abbf2c
Author: Furkan KARAKUZ
Date: 2025-04-13 17:28:44 +00:00 (committed by GitHub)

5 changed files with 23 additions and 23 deletions

View File

@@ -30,7 +30,7 @@ mapping = {
 }
-def main(hf_ckpt_path, save_path, n_experts, mp):
+def main(hf_ckpt_path: str, save_path: str, n_experts: str, mp: int) -> None:
     """
     Converts and saves model checkpoint files into a specified format.
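
For context, a typed signature like the new one above is normally driven by a small CLI wrapper. The sketch below is illustrative only (flag names and parsing are assumptions, not part of this commit); note that the diff annotates `n_experts` as `str`, although `int` is the more natural type if the value is parsed from the command line as shown.

# Hypothetical CLI wrapper; flag names are assumed for illustration, not taken from the commit.
import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--hf-ckpt-path", type=str, required=True)
    parser.add_argument("--save-path", type=str, required=True)
    parser.add_argument("--n-experts", type=int, required=True)   # int here, even though the new annotation says str
    parser.add_argument("--model-parallel", type=int, required=True)
    args = parser.parse_args()
    main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)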

View File

@@ -9,7 +9,7 @@ from safetensors.torch import load_file, save_file
 from kernel import weight_dequant
-def main(fp8_path, bf16_path):
+def main(fp8_path: str, bf16_path: str) -> None:
     """
     Converts FP8 weights to BF16 and saves the converted weights.
@@ -41,7 +41,7 @@ def main(fp8_path, bf16_path):
     fp8_weight_names = []
     # Helper function to get tensor from the correct file
-    def get_tensor(tensor_name):
+    def get_tensor(tensor_name: str) -> torch.Tensor:
         """
         Retrieves a tensor from the cached safetensor files or loads it from disk if not cached.
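
A plain-Python sketch of the caching behaviour that docstring describes, with an assumed weight_map (tensor name -> file name) and a simple dict cache; the names below are illustrative, not taken from the commit.

# Sketch only: loaded_files and weight_map are assumed names.
import os
from safetensors.torch import load_file

loaded_files = {}  # cache: file name -> {tensor name: tensor}

def get_tensor_sketch(tensor_name, fp8_path, weight_map):
    file_name = weight_map[tensor_name]          # which safetensors file holds this tensor
    if file_name not in loaded_files:            # hit the disk only on a cache miss
        loaded_files[file_name] = load_file(os.path.join(fp8_path, file_name))
    return loaded_files[file_name][tensor_name]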

View File

@@ -11,7 +11,7 @@ from safetensors.torch import load_model
 from model import Transformer, ModelArgs
-def sample(logits, temperature: float = 1.0):
+def sample(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
     """
     Samples a token from the logits using temperature scaling.
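
For reference, temperature-scaled sampling as described by that docstring can be written as below; this is a generic sketch, not necessarily the implementation in this file.

import torch

def sample_sketch(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    # Scale the logits by the temperature, convert to probabilities, draw one token per row.
    probs = torch.softmax(logits / max(temperature, 1e-5), dim=-1)
    return torch.multinomial(probs, num_samples=1)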

View File

@@ -7,7 +7,7 @@ from triton import Config
 @triton.jit
-def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
+def act_quant_kernel(x_ptr: triton.Pointer, y_ptr: triton.Pointer, s_ptr: triton.Pointer, BLOCK_SIZE: tl.constexpr) -> None:
     """
     Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`.
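
A plain-PyTorch sketch of the block-wise quantization that docstring describes; the constant 448 assumes the FP8 E4M3 maximum, and the result is left in the input dtype rather than cast to FP8.

import torch

def act_quant_sketch(x: torch.Tensor, block_size: int = 128):
    # Split the last dimension into blocks, compute one scale per block,
    # and divide so each block fits the assumed FP8 E4M3 range (max magnitude 448).
    assert x.size(-1) % block_size == 0
    blocks = x.view(*x.shape[:-1], -1, block_size)
    s = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / 448.0
    y = (blocks / s).view_as(x)
    return y, s.squeeze(-1)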
@@ -53,7 +53,7 @@ def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
 @triton.jit
-def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
+def weight_dequant_kernel(x_ptr: tl.pointer, s_ptr: tl.pointer, y_ptr: tl.pointer, M: int, N: int, BLOCK_SIZE: tl.constexpr) -> None:
     """
     Dequantizes weights using the provided scaling factors and stores the result.
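
The inverse operation, expanding each per-block scale over its block and multiplying, can be sketched in plain PyTorch as follows; square 128 x 128 blocks are assumed.

import torch

def weight_dequant_sketch(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    # x: (M, N) quantized weight; s: one scale per (block_size x block_size) tile.
    M, N = x.shape
    s_full = s.repeat_interleave(block_size, dim=0).repeat_interleave(block_size, dim=1)
    return x.float() * s_full[:M, :N]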
@@ -112,12 +112,12 @@ fp8_gemm_configs = [
 @triton.autotune(configs=fp8_gemm_configs, key=['N', 'K'])
 @triton.jit
-def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
-                    a_s_ptr, b_s_ptr,
+def fp8_gemm_kernel(a_ptr: tl.tensor, b_ptr: tl.tensor, c_ptr:tl.tensor,
+                    a_s_ptr: tl.tensor, b_s_ptr: tl.tensor,
                     M, N: tl.constexpr, K: tl.constexpr,
                     BLOCK_SIZE_M: tl.constexpr,
                     BLOCK_SIZE_N: tl.constexpr,
-                    BLOCK_SIZE_K: tl.constexpr):
+                    BLOCK_SIZE_K: tl.constexpr) -> None:
     """
     Performs a matrix multiplication operation on FP8 matrices with scaling factors.
@@ -167,7 +167,7 @@ def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
     tl.store(c_ptrs, c, mask=mask)
-def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor):
+def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor) -> torch.Tensor:
     """
     Perform a matrix multiplication using FP8 precision.
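
Numerically, the wrapper above computes a scaled matrix product. A rough reference in plain PyTorch, assuming the per-block scales have already been expanded to element-wise scales and that b uses the usual (out_features, in_features) weight layout:

import torch

def fp8_gemm_reference(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor) -> torch.Tensor:
    # Reference only: the real kernel applies per-block scales inside its K-loop;
    # here a_s and b_s are assumed to already match a and b element-wise.
    return (a.float() * a_s.float()) @ (b.float() * b_s.float()).t()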

View File

@@ -92,7 +92,7 @@ class ParallelEmbedding(nn.Module):
         vocab_size (int): Vocabulary size.
         dim (int): Embedding dimension.
     """
-    def __init__(self, vocab_size: int, dim: int):
+    def __init__(self, vocab_size: int, dim: int) -> None:
         super().__init__()
         self.vocab_size = vocab_size
         self.dim = dim
@@ -173,7 +173,7 @@ class Linear(nn.Module):
     """
     dtype = torch.bfloat16
-    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
+    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None) -> None:
         super().__init__()
         self.in_features = in_features
         self.out_features = out_features
@@ -212,7 +212,7 @@ class ColumnParallelLinear(Linear):
         bias (bool): Whether to include a bias term. Defaults to False.
         dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
     """
-    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
+    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None) -> None:
         assert out_features % world_size == 0, f"Output features must be divisible by world size (world_size={world_size})"
         self.part_out_features = out_features // world_size
         super().__init__(in_features, self.part_out_features, bias, dtype)
@@ -241,7 +241,7 @@ class RowParallelLinear(Linear):
         bias (bool): Whether to include a bias term. Defaults to False.
         dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
     """
-    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
+    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None) -> None:
         assert in_features % world_size == 0, f"Input features must be divisible by world size (world_size={world_size})"
         self.part_in_features = in_features // world_size
         super().__init__(self.part_in_features, out_features, bias, dtype)
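
The asserts in these two hunks encode the tensor-parallel partitioning rule: each rank owns an equal slice of the output features (ColumnParallelLinear) or input features (RowParallelLinear). A small worked example of the arithmetic, with an illustrative world_size:

# Illustration of the partitioning checked by the asserts above; values are arbitrary.
world_size = 8
out_features = 4096
in_features = 4096

assert out_features % world_size == 0
part_out_features = out_features // world_size   # ColumnParallelLinear: 512 output features per rank
assert in_features % world_size == 0
part_in_features = in_features // world_size     # RowParallelLinear: 512 input features per rank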
@@ -406,7 +406,7 @@ class MLA(nn.Module):
         v_head_dim (int): Dimensionality of value projections.
         softmax_scale (float): Scaling factor for softmax in attention computation.
     """
-    def __init__(self, args: ModelArgs):
+    def __init__(self, args: ModelArgs) -> None:
         super().__init__()
         self.dim = args.dim
         self.n_heads = args.n_heads
@@ -440,7 +440,7 @@ class MLA(nn.Module):
             self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)
             self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)
-    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
+    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
         """
         Forward pass for the Multi-Head Latent Attention (MLA) Layer.
@@ -503,7 +503,7 @@ class MLP(nn.Module):
         w2 (nn.Module): Linear layer for hidden-to-output transformation.
         w3 (nn.Module): Additional linear layer for feature transformation.
     """
-    def __init__(self, dim: int, inter_dim: int):
+    def __init__(self, dim: int, inter_dim: int) -> None:
         """
         Initializes the MLP layer.
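
Given the three projections listed in the attributes above (w1, w2, w3), this is a gated MLP. A minimal sketch of the usual forward wiring under that assumption (SiLU gating assumed, not confirmed by the diff):

import torch
import torch.nn.functional as F
from torch import nn

class MLPSketch(nn.Module):
    # Same three projections as named in the docstring; plain nn.Linear stands in
    # for the repository's parallel linear layers.
    def __init__(self, dim: int, inter_dim: int) -> None:
        super().__init__()
        self.w1 = nn.Linear(dim, inter_dim, bias=False)   # input -> hidden
        self.w2 = nn.Linear(inter_dim, dim, bias=False)   # hidden -> output
        self.w3 = nn.Linear(dim, inter_dim, bias=False)   # gating branch

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))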
@@ -543,7 +543,7 @@ class Gate(nn.Module):
         weight (torch.nn.Parameter): Learnable weights for the gate.
         bias (Optional[torch.nn.Parameter]): Optional bias term for the gate.
     """
-    def __init__(self, args: ModelArgs):
+    def __init__(self, args: ModelArgs) -> None:
         """
         Initializes the Gate module.
@@ -604,7 +604,7 @@ class Expert(nn.Module):
         w2 (nn.Module): Linear layer for hidden-to-output transformation.
         w3 (nn.Module): Additional linear layer for feature transformation.
     """
-    def __init__(self, dim: int, inter_dim: int):
+    def __init__(self, dim: int, inter_dim: int) -> None:
         """
         Initializes the Expert layer.
@@ -643,7 +643,7 @@ class MoE(nn.Module):
         experts (nn.ModuleList): List of expert modules.
         shared_experts (nn.Module): Shared experts applied to all inputs.
     """
-    def __init__(self, args: ModelArgs):
+    def __init__(self, args: ModelArgs) -> None:
         """
         Initializes the MoE module.
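
The attributes listed above (a gate, a list of experts, shared experts applied to every input) describe the standard MoE dispatch. A simplified dense-loop sketch of that pattern; the routing details (softmax scores, top-k of 2) are assumptions for illustration, not taken from the diff.

import torch

def moe_forward_sketch(x, gate, experts, shared_experts, top_k: int = 2):
    # x: (tokens, dim). Score every expert, keep the top_k per token,
    # run only the selected experts, then add the shared experts' output.
    scores = gate(x)                                     # (tokens, n_experts)
    weights, idx = scores.softmax(dim=-1).topk(top_k, dim=-1)
    out = torch.zeros_like(x)
    for i, expert in enumerate(experts):                 # dense loop is fine for a sketch
        mask = idx == i
        if mask.any():
            token_ids, slot = mask.nonzero(as_tuple=True)
            out[token_ids] += weights[token_ids, slot].unsqueeze(-1) * expert(x[token_ids])
    return out + shared_experts(x)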
@@ -700,7 +700,7 @@ class Block(nn.Module):
         attn_norm (nn.Module): Layer normalization for attention.
         ffn_norm (nn.Module): Layer normalization for feed-forward network.
     """
-    def __init__(self, layer_id: int, args: ModelArgs):
+    def __init__(self, layer_id: int, args: ModelArgs) -> None:
         """
         Initializes the Transformer block.
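
With the two normalization modules named above (attn_norm, ffn_norm), the block follows the usual pre-norm residual layout; a sketch of the forward pass under that assumption:

import torch

def block_forward_sketch(x, attn, ffn, attn_norm, ffn_norm, start_pos, freqs_cis, mask):
    # Pre-norm residual wiring: normalize, transform, then add back the residual.
    x = x + attn(attn_norm(x), start_pos, freqs_cis, mask)
    x = x + ffn(ffn_norm(x))
    return x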
@@ -744,7 +744,7 @@ class Transformer(nn.Module):
         head (nn.Module): Output projection layer mapping to vocabulary size.
         freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
     """
-    def __init__(self, args: ModelArgs):
+    def __init__(self, args: ModelArgs) -> None:
         """
         Initializes the Transformer model.
@@ -766,7 +766,7 @@ class Transformer(nn.Module):
         self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)
     @torch.inference_mode()
-    def forward(self, tokens: torch.Tensor, start_pos: int = 0):
+    def forward(self, tokens: torch.Tensor, start_pos: int = 0) -> torch.Tensor:
         """
         Forward pass for the Transformer model.
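
Finally, a hedged sketch of how a typed forward like the one above is typically driven during incremental decoding; `model` and `tokens` are assumed to exist, and the forward is assumed to return logits for the last position only.

import torch

@torch.inference_mode()
def decode_sketch(model, tokens: torch.Tensor, steps: int) -> torch.Tensor:
    # Greedy decode loop: each step feeds only the new tokens plus the
    # position where the cache left off (start_pos).
    start_pos = 0
    for _ in range(steps):
        logits = model(tokens[:, start_pos:], start_pos)    # assumed shape: (batch, vocab)
        next_token = logits.argmax(dim=-1, keepdim=True)    # greedy pick for the sketch
        tokens = torch.cat([tokens, next_token], dim=-1)
        start_pos = tokens.size(1) - 1
    return tokens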