From 1e40f4a73c71ba2e9a5fc3dd58dafcf594c87449 Mon Sep 17 00:00:00 2001
From: furkankarakuz <karakuzfurkan.98@gmail.com>
Date: Sun, 13 Apr 2025 20:25:51 +0300
Subject: [PATCH] Add type hints to inference functions

---
 inference/convert.py       |  2 +-
 inference/fp8_cast_bf16.py |  4 ++--
 inference/generate.py      |  2 +-
 inference/kernel.py        | 12 ++++++------
 inference/model.py         | 26 +++++++++++++-------------
 5 files changed, 23 insertions(+), 23 deletions(-)

diff --git a/inference/convert.py b/inference/convert.py
index 6d85ccc..ac19398 100644
--- a/inference/convert.py
+++ b/inference/convert.py
@@ -30,7 +30,7 @@ mapping = {
 }
 
 
-def main(hf_ckpt_path, save_path, n_experts, mp):
+def main(hf_ckpt_path: str, save_path: str, n_experts: int, mp: int) -> None:
     """
     Converts and saves model checkpoint files into a specified format.
 
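[Reviewer note] The n_experts hint above is int, not str: the value is an expert count, and the CLI layer presumably parses it with type=int. A minimal sketch of that wiring (hypothetical flag names; the file's actual parser is outside this hunk):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--hf-ckpt-path", type=str, required=True)    # source HF checkpoint dir
    parser.add_argument("--save-path", type=str, required=True)       # converted output dir
    parser.add_argument("--n-experts", type=int, required=True)       # expert count: an int, not a str
    parser.add_argument("--model-parallel", type=int, required=True)  # mp degree
    args = parser.parse_args()
    main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)
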
diff --git a/inference/fp8_cast_bf16.py b/inference/fp8_cast_bf16.py
index 4037342..3b50eb8 100644
--- a/inference/fp8_cast_bf16.py
+++ b/inference/fp8_cast_bf16.py
@@ -9,7 +9,7 @@ from safetensors.torch import load_file, save_file
 
 from kernel import weight_dequant
 
-def main(fp8_path, bf16_path):
+def main(fp8_path: str, bf16_path: str) -> None:
     """
     Converts FP8 weights to BF16 and saves the converted weights.
 
@@ -41,7 +41,7 @@ def main(fp8_path, bf16_path):
     fp8_weight_names = []
 
     # Helper function to get tensor from the correct file
-    def get_tensor(tensor_name):
+    def get_tensor(tensor_name: str) -> torch.Tensor:
         """
         Retrieves a tensor from the cached safetensor files or loads it from disk if not cached.
 
diff --git a/inference/generate.py b/inference/generate.py
index 7e9bffe..16af7f6 100644
--- a/inference/generate.py
+++ b/inference/generate.py
@@ -11,7 +11,7 @@ from safetensors.torch import load_model
 from model import Transformer, ModelArgs
 
 
-def sample(logits, temperature: float = 1.0):
+def sample(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
     """
     Samples a token from the logits using temperature scaling.
 
diff --git a/inference/kernel.py b/inference/kernel.py
index ba18dca..314b3d4 100644
--- a/inference/kernel.py
+++ b/inference/kernel.py
@@ -7,7 +7,7 @@ from triton import Config
 
 
 @triton.jit
-def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
+def act_quant_kernel(x_ptr: tl.tensor, y_ptr: tl.tensor, s_ptr: tl.tensor, BLOCK_SIZE: tl.constexpr) -> None:
     """
     Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`.
 
@@ -53,7 +53,7 @@ def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, tor
 
 
 @triton.jit
-def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
+def weight_dequant_kernel(x_ptr: tl.tensor, s_ptr: tl.tensor, y_ptr: tl.tensor, M: int, N: int, BLOCK_SIZE: tl.constexpr) -> None:
     """
     Dequantizes weights using the provided scaling factors and stores the result.
 
@@ -112,12 +112,12 @@ fp8_gemm_configs = [
 
 @triton.autotune(configs=fp8_gemm_configs, key=['N', 'K'])
 @triton.jit
-def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
-                    a_s_ptr, b_s_ptr,
+def fp8_gemm_kernel(a_ptr: tl.tensor, b_ptr: tl.tensor, c_ptr: tl.tensor,
+                    a_s_ptr: tl.tensor, b_s_ptr: tl.tensor,
                     M, N: tl.constexpr, K: tl.constexpr,
                     BLOCK_SIZE_M: tl.constexpr,
                     BLOCK_SIZE_N: tl.constexpr,
-                    BLOCK_SIZE_K: tl.constexpr):
+                    BLOCK_SIZE_K: tl.constexpr) -> None:
     """
     Performs a matrix multiplication operation on FP8 matrices with scaling factors.
 
@@ -167,7 +167,7 @@ def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
     tl.store(c_ptrs, c, mask=mask)
 
 
-def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor):
+def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor) -> torch.Tensor:
     """
     Perform a matrix multiplication using FP8 precision.
 
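[Reviewer note] A caveat on the kernel annotations above: Python evaluates annotations at definition time, so they must name real objects. The draft's triton.Pointer and tl.pointer do not exist in current Triton releases and would raise AttributeError on import, which is why all pointer arguments are now annotated tl.tensor. Only tl.constexpr is semantically meaningful to the JIT (its value is baked into the compiled kernel); in the Triton versions I have checked, other annotations on @triton.jit parameters are documentation only. A minimal sketch of the convention:

    import triton
    import triton.language as tl

    @triton.jit
    def copy_kernel(src_ptr: tl.tensor, dst_ptr: tl.tensor, n, BLOCK: tl.constexpr) -> None:
        # BLOCK is a compile-time constant; the tl.tensor annotations are cosmetic.
        pid = tl.program_id(axis=0)
        offs = pid * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n
        tl.store(dst_ptr + offs, tl.load(src_ptr + offs, mask=mask), mask=mask)
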
diff --git a/inference/model.py b/inference/model.py
index c143e97..7c43480 100644
--- a/inference/model.py
+++ b/inference/model.py
@@ -92,7 +92,7 @@ class ParallelEmbedding(nn.Module):
         vocab_size (int): Vocabulary size.
         dim (int): Embedding dimension.
     """
-    def __init__(self, vocab_size: int, dim: int):
+    def __init__(self, vocab_size: int, dim: int) -> None:
         super().__init__()
         self.vocab_size = vocab_size
         self.dim = dim
@@ -173,7 +173,7 @@ class Linear(nn.Module):
     """
     dtype = torch.bfloat16
 
-    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
+    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype: Optional[torch.dtype] = None) -> None:
         super().__init__()
         self.in_features = in_features
         self.out_features = out_features
@@ -212,7 +212,7 @@ class ColumnParallelLinear(Linear):
         bias (bool): Whether to include a bias term. Defaults to False.
         dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
     """
-    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
+    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype: Optional[torch.dtype] = None) -> None:
         assert out_features % world_size == 0, f"Output features must be divisible by world size (world_size={world_size})"
         self.part_out_features = out_features // world_size
         super().__init__(in_features, self.part_out_features, bias, dtype)
@@ -241,7 +241,7 @@ class RowParallelLinear(Linear):
         bias (bool): Whether to include a bias term. Defaults to False.
         dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
     """
-    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
+    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype: Optional[torch.dtype] = None) -> None:
         assert in_features % world_size == 0, f"Input features must be divisible by world size (world_size={world_size})"
         self.part_in_features = in_features // world_size
         super().__init__(self.part_in_features, out_features, bias, dtype)
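[Reviewer note] On the dtype parameters in Linear and its two parallel subclasses: since Optional is already imported in model.py (it appears in MLA.forward's signature below), the hints tighten dtype = None to dtype: Optional[torch.dtype] = None. A sketch of the fallback this implies, given the class-level dtype = torch.bfloat16 default visible in the Linear hunk (assumed behavior; the full __init__ body is outside this patch):

    from typing import Optional
    import torch
    from torch import nn

    class Linear(nn.Module):
        dtype = torch.bfloat16  # class-wide default

        def __init__(self, in_features: int, out_features: int,
                     bias: bool = False, dtype: Optional[torch.dtype] = None) -> None:
            super().__init__()
            # Fall back to the class default when no dtype is passed.
            self.weight = nn.Parameter(
                torch.empty(out_features, in_features, dtype=dtype or Linear.dtype))
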
@@ -406,7 +406,7 @@ class MLA(nn.Module):
         v_head_dim (int): Dimensionality of value projections.
         softmax_scale (float): Scaling factor for softmax in attention computation.
     """
-    def __init__(self, args: ModelArgs):
+    def __init__(self, args: ModelArgs) -> None:
         super().__init__()
         self.dim = args.dim
         self.n_heads = args.n_heads
@@ -440,7 +440,7 @@ class MLA(nn.Module):
             self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)
             self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)
 
-    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
+    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
         """
         Forward pass for the Multi-Head Latent Attention (MLA) Layer.
 
@@ -503,7 +503,7 @@ class MLP(nn.Module):
         w2 (nn.Module): Linear layer for hidden-to-output transformation.
         w3 (nn.Module): Additional linear layer for feature transformation.
     """
-    def __init__(self, dim: int, inter_dim: int):
+    def __init__(self, dim: int, inter_dim: int) -> None:
         """
         Initializes the MLP layer.
 
@@ -543,7 +543,7 @@ class Gate(nn.Module):
         weight (torch.nn.Parameter): Learnable weights for the gate.
         bias (Optional[torch.nn.Parameter]): Optional bias term for the gate.
     """
-    def __init__(self, args: ModelArgs):
+    def __init__(self, args: ModelArgs) -> None:
         """
         Initializes the Gate module.
 
@@ -604,7 +604,7 @@ class Expert(nn.Module):
         w2 (nn.Module): Linear layer for hidden-to-output transformation.
         w3 (nn.Module): Additional linear layer for feature transformation.
     """
-    def __init__(self, dim: int, inter_dim: int):
+    def __init__(self, dim: int, inter_dim: int) -> None:
         """
         Initializes the Expert layer.
 
@@ -643,7 +643,7 @@ class MoE(nn.Module):
         experts (nn.ModuleList): List of expert modules.
         shared_experts (nn.Module): Shared experts applied to all inputs.
     """
-    def __init__(self, args: ModelArgs):
+    def __init__(self, args: ModelArgs) -> None:
         """
         Initializes the MoE module.
 
@@ -700,7 +700,7 @@ class Block(nn.Module):
         attn_norm (nn.Module): Layer normalization for attention.
         ffn_norm (nn.Module): Layer normalization for feed-forward network.
     """
-    def __init__(self, layer_id: int, args: ModelArgs):
+    def __init__(self, layer_id: int, args: ModelArgs) -> None:
         """
         Initializes the Transformer block.
 
@@ -744,7 +744,7 @@ class Transformer(nn.Module):
         head (nn.Module): Output projection layer mapping to vocabulary size.
         freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
     """
-    def __init__(self, args: ModelArgs):
+    def __init__(self, args: ModelArgs) -> None:
         """
         Initializes the Transformer model.
 
@@ -766,7 +766,7 @@ class Transformer(nn.Module):
         self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)
 
     @torch.inference_mode()
-    def forward(self, tokens: torch.Tensor, start_pos: int = 0):
+    def forward(self, tokens: torch.Tensor, start_pos: int = 0) -> torch.Tensor:
         """
         Forward pass for the Transformer model.