From 1e40f4a73c71ba2e9a5fc3dd58dafcf594c87449 Mon Sep 17 00:00:00 2001
From: furkankarakuz
Date: Sun, 13 Apr 2025 20:25:51 +0300
Subject: [PATCH] add type hints to functions

---
 inference/convert.py       |  2 +-
 inference/fp8_cast_bf16.py |  4 ++--
 inference/generate.py      |  2 +-
 inference/kernel.py        | 12 ++++++------
 inference/model.py         | 26 +++++++++++++-------------
 5 files changed, 23 insertions(+), 23 deletions(-)

diff --git a/inference/convert.py b/inference/convert.py
index 6d85ccc..ac19398 100644
--- a/inference/convert.py
+++ b/inference/convert.py
@@ -30,7 +30,7 @@ mapping = {
 }
 
 
-def main(hf_ckpt_path, save_path, n_experts, mp):
+def main(hf_ckpt_path: str, save_path: str, n_experts: int, mp: int) -> None:
     """
     Converts and saves model checkpoint files into a specified format.
 
diff --git a/inference/fp8_cast_bf16.py b/inference/fp8_cast_bf16.py
index 4037342..3b50eb8 100644
--- a/inference/fp8_cast_bf16.py
+++ b/inference/fp8_cast_bf16.py
@@ -9,7 +9,7 @@ from safetensors.torch import load_file, save_file
 from kernel import weight_dequant
 
 
-def main(fp8_path, bf16_path):
+def main(fp8_path: str, bf16_path: str) -> None:
     """
     Converts FP8 weights to BF16 and saves the converted weights.
 
@@ -41,7 +41,7 @@ def main(fp8_path, bf16_path):
     fp8_weight_names = []
 
     # Helper function to get tensor from the correct file
-    def get_tensor(tensor_name):
+    def get_tensor(tensor_name: str) -> torch.Tensor:
         """
         Retrieves a tensor from the cached safetensor files or loads it from disk if not cached.
 
diff --git a/inference/generate.py b/inference/generate.py
index 7e9bffe..16af7f6 100644
--- a/inference/generate.py
+++ b/inference/generate.py
@@ -11,7 +11,7 @@ from safetensors.torch import load_model
 from model import Transformer, ModelArgs
 
 
-def sample(logits, temperature: float = 1.0):
+def sample(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
     """
     Samples a token from the logits using temperature scaling.
 
diff --git a/inference/kernel.py b/inference/kernel.py
index ba18dca..314b3d4 100644
--- a/inference/kernel.py
+++ b/inference/kernel.py
@@ -7,7 +7,7 @@ from triton import Config
 
 
 @triton.jit
-def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
+def act_quant_kernel(x_ptr: tl.tensor, y_ptr: tl.tensor, s_ptr: tl.tensor, BLOCK_SIZE: tl.constexpr) -> None:
     """
     Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`.
 
@@ -53,7 +53,7 @@ def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, tor
 
 
 @triton.jit
-def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
+def weight_dequant_kernel(x_ptr: tl.tensor, s_ptr: tl.tensor, y_ptr: tl.tensor, M: int, N: int, BLOCK_SIZE: tl.constexpr) -> None:
     """
     Dequantizes weights using the provided scaling factors and stores the result.
 
@@ -112,12 +112,12 @@ fp8_gemm_configs = [
 
 @triton.autotune(configs=fp8_gemm_configs, key=['N', 'K'])
 @triton.jit
-def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
-                    a_s_ptr, b_s_ptr,
+def fp8_gemm_kernel(a_ptr: tl.tensor, b_ptr: tl.tensor, c_ptr: tl.tensor,
+                    a_s_ptr: tl.tensor, b_s_ptr: tl.tensor,
                     M, N: tl.constexpr, K: tl.constexpr,
                     BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
-                    BLOCK_SIZE_K: tl.constexpr):
+                    BLOCK_SIZE_K: tl.constexpr) -> None:
     """
     Performs a matrix multiplication operation on FP8 matrices with scaling factors.
 
@@ -167,7 +167,7 @@ def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
     tl.store(c_ptrs, c, mask=mask)
 
 
-def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor):
+def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor) -> torch.Tensor:
     """
     Perform a matrix multiplication using FP8 precision.
 
diff --git a/inference/model.py b/inference/model.py
index c143e97..7c43480 100644
--- a/inference/model.py
+++ b/inference/model.py
@@ -92,7 +92,7 @@ class ParallelEmbedding(nn.Module):
         vocab_size (int): Vocabulary size.
         dim (int): Embedding dimension.
     """
-    def __init__(self, vocab_size: int, dim: int):
+    def __init__(self, vocab_size: int, dim: int) -> None:
         super().__init__()
         self.vocab_size = vocab_size
         self.dim = dim
@@ -173,7 +173,7 @@ class Linear(nn.Module):
     """
     dtype = torch.bfloat16
 
-    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
+    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None) -> None:
         super().__init__()
         self.in_features = in_features
         self.out_features = out_features
@@ -212,7 +212,7 @@ class ColumnParallelLinear(Linear):
         bias (bool): Whether to include a bias term. Defaults to False.
         dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
     """
-    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
+    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None) -> None:
         assert out_features % world_size == 0, f"Output features must be divisible by world size (world_size={world_size})"
         self.part_out_features = out_features // world_size
         super().__init__(in_features, self.part_out_features, bias, dtype)
@@ -241,7 +241,7 @@ class RowParallelLinear(Linear):
         bias (bool): Whether to include a bias term. Defaults to False.
         dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
     """
-    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
+    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None) -> None:
         assert in_features % world_size == 0, f"Input features must be divisible by world size (world_size={world_size})"
         self.part_in_features = in_features // world_size
         super().__init__(self.part_in_features, out_features, bias, dtype)
@@ -406,7 +406,7 @@ class MLA(nn.Module):
         v_head_dim (int): Dimensionality of value projections.
         softmax_scale (float): Scaling factor for softmax in attention computation.
     """
-    def __init__(self, args: ModelArgs):
+    def __init__(self, args: ModelArgs) -> None:
         super().__init__()
         self.dim = args.dim
         self.n_heads = args.n_heads
@@ -440,7 +440,7 @@
             self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)
             self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)
 
-    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
+    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
         """
         Forward pass for the Multi-Head Latent Attention (MLA) Layer.
 
@@ -503,7 +503,7 @@ class MLP(nn.Module):
         w2 (nn.Module): Linear layer for hidden-to-output transformation.
         w3 (nn.Module): Additional linear layer for feature transformation.
""" - def __init__(self, dim: int, inter_dim: int): + def __init__(self, dim: int, inter_dim: int) -> None: """ Initializes the MLP layer. @@ -543,7 +543,7 @@ class Gate(nn.Module): weight (torch.nn.Parameter): Learnable weights for the gate. bias (Optional[torch.nn.Parameter]): Optional bias term for the gate. """ - def __init__(self, args: ModelArgs): + def __init__(self, args: ModelArgs) -> None: """ Initializes the Gate module. @@ -604,7 +604,7 @@ class Expert(nn.Module): w2 (nn.Module): Linear layer for hidden-to-output transformation. w3 (nn.Module): Additional linear layer for feature transformation. """ - def __init__(self, dim: int, inter_dim: int): + def __init__(self, dim: int, inter_dim: int) -> None: """ Initializes the Expert layer. @@ -643,7 +643,7 @@ class MoE(nn.Module): experts (nn.ModuleList): List of expert modules. shared_experts (nn.Module): Shared experts applied to all inputs. """ - def __init__(self, args: ModelArgs): + def __init__(self, args: ModelArgs) -> None: """ Initializes the MoE module. @@ -700,7 +700,7 @@ class Block(nn.Module): attn_norm (nn.Module): Layer normalization for attention. ffn_norm (nn.Module): Layer normalization for feed-forward network. """ - def __init__(self, layer_id: int, args: ModelArgs): + def __init__(self, layer_id: int, args: ModelArgs) -> None: """ Initializes the Transformer block. @@ -744,7 +744,7 @@ class Transformer(nn.Module): head (nn.Module): Output projection layer mapping to vocabulary size. freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings. """ - def __init__(self, args: ModelArgs): + def __init__(self, args: ModelArgs) -> None: """ Initializes the Transformer model. @@ -766,7 +766,7 @@ class Transformer(nn.Module): self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False) @torch.inference_mode() - def forward(self, tokens: torch.Tensor, start_pos: int = 0): + def forward(self, tokens: torch.Tensor, start_pos: int = 0) -> torch.Tensor: """ Forward pass for the Transformer model.