From 06231633430c449a1f6a8c0b5a1d6f99860b8eb2 Mon Sep 17 00:00:00 2001
From: yzlnew <yzlnew@gmail.com>
Date: Tue, 25 Mar 2025 19:38:45 +0800
Subject: [PATCH 1/2] add bf16 to fp8

---
 inference/bf16_cast_fp8.py | 97 ++++++++++++++++++++++++++++++++++++++
 inference/kernel.py        | 29 ++++++++++++
 2 files changed, 126 insertions(+)
 create mode 100644 inference/bf16_cast_fp8.py

diff --git a/inference/bf16_cast_fp8.py b/inference/bf16_cast_fp8.py
new file mode 100644
index 0000000..86f5749
--- /dev/null
+++ b/inference/bf16_cast_fp8.py
@@ -0,0 +1,97 @@
+import os
+import json
+import re
+from argparse import ArgumentParser
+from glob import glob
+from tqdm import tqdm
+
+import torch
+from safetensors.torch import load_file, save_file
+
+from kernel import weight_quant
+
+# Layers that should not be quantized (remain in BF16)
+SKIP_QUANT_PATTERNS = [
+    r".*\.layernorm\.weight$",
+    r".*\.norm\.weight$",
+    r".*input_layernorm\.weight$",
+    r".*post_attention_layernorm\.weight$",
+    r".*\.kv_a_layernorm\.weight$",
+    r".*\.q_a_layernorm\.weight$",
+    r".*\.embed_tokens\.weight$",
+    r".*\.head\.weight$",
+    r".*lm_head\.weight$",
+    r".*\.eh_proj\.weight$",
+    r".*\.gate\.e_score_correction_bias$",
+    r".*\.gate\.weight$"
+]
+
+def should_skip_quantization(weight_name):
+    """Check if weight name matches any pattern in the skip list"""
+    return any(re.match(pattern, weight_name) for pattern in SKIP_QUANT_PATTERNS)
+
+def main(bf16_path, fp8_path):
+    torch.set_default_dtype(torch.bfloat16)
+    os.makedirs(fp8_path, exist_ok=True)
+
+    # Get list of safetensor files
+    safetensor_files = list(glob(os.path.join(bf16_path, "*.safetensors")))
+    safetensor_files.sort()
+
+    # Load model index if it exists
+    model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
+    if os.path.exists(model_index_file):
+        with open(model_index_file, "r") as f:
+            model_index = json.load(f)
+        weight_map = model_index["weight_map"]
+    else:
+        # Create a new weight map if there's no index file
+        weight_map = {}
+
+    # Cache for loaded safetensor files
+    loaded_files = {}
+    fp8_weight_names = []
+
+    for safetensor_file in tqdm(safetensor_files):
+        file_name = os.path.basename(safetensor_file)
+        current_state_dict = load_file(safetensor_file, device="cuda")
+        loaded_files[file_name] = current_state_dict
+
+        new_state_dict = {}
+        for weight_name, weight in current_state_dict.items():
+            # Skip weights that should not be quantized
+            if should_skip_quantization(weight_name) or weight.dim() != 2:
+                new_state_dict[weight_name] = weight
+            else:
+                # Quantize weights to FP8
+                fp8_weight, scale_inv = weight_quant(weight)
+                new_state_dict[weight_name] = fp8_weight
+                scale_inv_name = f"{weight_name}_scale_inv"
+                new_state_dict[scale_inv_name] = scale_inv
+                fp8_weight_names.append(weight_name)
+
+                # Update weight map
+                if weight_name in weight_map:
+                    weight_map[scale_inv_name] = file_name
+
+        new_safetensor_file = os.path.join(fp8_path, file_name)
+        save_file(new_state_dict, new_safetensor_file)
+
+        # Memory management: keep only the 2 most recently used files
+        if len(loaded_files) > 2:
+            oldest_file = next(iter(loaded_files))
+            del loaded_files[oldest_file]
+            torch.cuda.empty_cache()
+
+    # Update model index
+    new_model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
+    with open(new_model_index_file, "w") as f:
+        json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser()
+    parser.add_argument("--input-bf16-hf-path", type=str, required=True)
+    parser.add_argument("--output-fp8-hf-path", type=str, required=True)
+    args = parser.parse_args()
+    main(args.input_bf16_hf_path, args.output_fp8_hf_path)
\ No newline at end of file
diff --git a/inference/kernel.py b/inference/kernel.py
index ae907ad..b90edf8 100644
--- a/inference/kernel.py
+++ b/inference/kernel.py
@@ -105,6 +105,35 @@ def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> t
     return y
 
 
+@triton.jit
+def weight_quant_kernel(x_ptr, y_ptr, s_ptr, M, N, BLOCK_SIZE: tl.constexpr):
+    pid_m = tl.program_id(axis=0)
+    pid_n = tl.program_id(axis=1)
+    n = tl.cdiv(N, BLOCK_SIZE)
+    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    offs = offs_m[:, None] * N + offs_n[None, :]
+    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
+    x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
+    max_val = tl.max(tl.abs(x))
+    s = max_val / 448.0  # Same scaling as in act_quant
+    y = x / s
+    y = y.to(y_ptr.dtype.element_ty)
+    tl.store(y_ptr + offs, y, mask=mask)
+    tl.store(s_ptr + pid_m * n + pid_n, s)
+
+
+def weight_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
+    assert x.is_contiguous()
+    assert x.dim() == 2
+    M, N = x.size()
+    y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
+    s = x.new_empty(triton.cdiv(M, block_size), triton.cdiv(N, block_size), dtype=torch.float32)
+    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
+    weight_quant_kernel[grid](x, y, s, M, N, BLOCK_SIZE=block_size)
+    return y, s
+
+
 fp8_gemm_configs = [
     Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128}, num_stages=num_stages, num_warps=8)
     for block_m in [16, 32, 64] for block_n in [32, 64, 128] for num_stages in [3, 4, 5, 6]

From 630769360a6ebcf79d270e0944f3aa2dfe888693 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=BB=84=E7=9F=B3?= <yzlnew@gmail.com>
Date: Sun, 6 Apr 2025 00:31:32 +0800
Subject: [PATCH 2/2] add docstring

---
 inference/kernel.py | 29 +++++++++++++++++++++++++++++
 1 file changed, 29 insertions(+)

diff --git a/inference/kernel.py b/inference/kernel.py
index b90edf8..f2277dc 100644
--- a/inference/kernel.py
+++ b/inference/kernel.py
@@ -107,6 +107,20 @@ def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> t
 
 @triton.jit
 def weight_quant_kernel(x_ptr, y_ptr, s_ptr, M, N, BLOCK_SIZE: tl.constexpr):
+    """
+    Quantizes weights in blocks and computes scaling factors for each block.
+
+    Args:
+        x_ptr (tl.pointer): Pointer to the input weights tensor.
+        y_ptr (tl.pointer): Pointer to the output buffer for quantized weights.
+        s_ptr (tl.pointer): Pointer to the output buffer for scaling factors.
+        M (int): Number of rows in the weight matrix.
+        N (int): Number of columns in the weight matrix.
+        BLOCK_SIZE (tl.constexpr): Size of the block for tiling.
+
+    Returns:
+        None
+    """
     pid_m = tl.program_id(axis=0)
     pid_n = tl.program_id(axis=1)
     n = tl.cdiv(N, BLOCK_SIZE)
@@ -124,6 +138,21 @@ def weight_quant_kernel(x_ptr, y_ptr, s_ptr, M, N, BLOCK_SIZE: tl.constexpr):
 
 
 def weight_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantizes a weight tensor using block-wise quantization.
+
+    Args:
+        x (torch.Tensor): The input weight tensor of shape (M, N) to be quantized.
+        block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+            - The quantized tensor with dtype `torch.float8_e4m3fn`.
+            - A tensor of scaling factors with dtype `torch.float32`.
+
+    Raises:
+        AssertionError: If `x` is not contiguous or if its dimensions are not 2.
+    """
     assert x.is_contiguous()
     assert x.dim() == 2
     M, N = x.size()