add support for block-wise quant from bf16

yiakwy-xpu-ml-framework-team 2025-07-01 07:59:03 +08:00
parent f6e34dd267
commit 44c403f0d8
4 changed files with 422 additions and 23 deletions

@@ -0,0 +1,5 @@
from .config import BaseQuantizeConfig
__all__ = [
"BaseQuantizeConfig",
]

@@ -0,0 +1,47 @@
# Adapted from AutoFP8 (deprecated project)
import re
from typing import List, Optional, Tuple
class BaseQuantizeConfig:
"""Configuration for model quantization.
Args:
        quant_method: Type/precision of quantization method to use.
            At the moment this is just "fp8", which specifically means
            the fp8_e4m3 format in PyTorch.
        activation_scheme: Choice of either "dynamic" or "static" quantization
            of activations. If "static", then calibration samples are required
            during quantization to produce accurate per-tensor scales for
            activations of Linear modules.
        ignore_patterns: List of regex patterns used to ignore layers. Each
            pattern is compiled with re.compile and matched against Linear
            layer names with re.search(). By default, ".*lm_head" is included
            to ignore the output Linear layer usually found at the end of
            decoder LLMs.
        kv_cache_quant_targets: Tuple of Linear module names to target for
            calibration of the output scales for KV cache quantization.
            Usually, these should be `("k_proj", "v_proj")`.
    """
def __init__(
self,
quant_method: str = "fp8",
activation_scheme: str = "static",
        ignore_patterns: List[str] = [".*lm_head"],
        kv_cache_quant_targets: Optional[Tuple[str, ...]] = None,
):
if quant_method != "fp8":
raise ValueError("Only FP8 quantization is supported.")
if activation_scheme not in ["static", "dynamic"]:
raise ValueError(
"Invalid activation_scheme. Choose either 'static' or 'dynamic'."
)
self.quant_method = quant_method
self.activation_scheme = activation_scheme
self.re_ignore_patterns = ignore_patterns
self.ignore_patterns = [
re.compile(regex_pat, re.VERBOSE) for regex_pat in ignore_patterns
]
self.kv_cache_quant_targets = kv_cache_quant_targets
self.ignored_layers = []
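For reference, a minimal sketch of how the compiled ignore_patterns are meant to be consumed downstream (the layer names here are illustrative, not from the commit):

from auto_fp8 import BaseQuantizeConfig

cfg = BaseQuantizeConfig(
    activation_scheme="dynamic",
    ignore_patterns=[".*lm_head", ".*gate"],
)
for name in [
    "lm_head.weight",
    "model.layers.0.mlp.gate.weight",
    "model.layers.0.self_attn.q_proj.weight",
]:
    # mirrors find_one_ignored() in inference/bf16_cast_fp8.py below
    ignored = any(p.search(name) for p in cfg.ignore_patterns)
    print(name, "-> ignored" if ignored else "-> quantize")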

inference/bf16_cast_fp8.py (new file, 254 lines)
@@ -0,0 +1,254 @@
import json
import os
import re
from argparse import ArgumentParser
from glob import glob
import torch
from auto_fp8 import BaseQuantizeConfig
from kernel import fp8_weight_block_wise_quant
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from transformers import AutoConfig
# Helper function to get tensor from the correct file
def has_tensor(weight_map, loaded_files, fp8_path, tensor_name):
"""
Retrieves a tensor from the cached safetensor files or loads it from disk if not cached.
    Args:
        weight_map (dict): Mapping from tensor name to safetensor shard file name.
        loaded_files (dict): Cache of already loaded safetensor state dicts.
        fp8_path (str): Directory containing the safetensor shards.
        tensor_name (str): The name of the tensor to retrieve.
Returns:
torch.Tensor: The retrieved tensor.
Raises:
KeyError: If the tensor does not exist in the safetensor file.
"""
file_name = weight_map[tensor_name]
if file_name not in loaded_files:
file_path = os.path.join(fp8_path, file_name)
loaded_files[file_name] = load_file(file_path, device="cuda")
return loaded_files[file_name][tensor_name]
def find_ignored(regex_pat, weight_name):
searched = regex_pat.search(weight_name)
if searched is not None:
print(f"find : {searched.string}")
return searched.string
return None
def find_one_ignored(regex_pat_list, weight_name):
for regex_pat in regex_pat_list:
searched = find_ignored(regex_pat, weight_name)
if searched is not None:
return searched
return None
quantize_config = BaseQuantizeConfig(
quant_method="fp8",
activation_scheme="dynamic",
ignore_patterns=[".*lm_head", ".*gate"],
)
def main(bf16_path, fp8_path, ref_weights_scale_inv_map=None):
"""
    Quantizes BF16 weights to FP8 (OCP E4M3) and saves the converted weights.
This function reads BF16 weights from the specified directory, converts them to FP8 (OCP E4M3),
and saves the converted weights to another specified directory. It also updates the
model index file to reflect the changes.
Args:
bf16_path (str): The path to the directory containing the BF16 weights and model index file.
fp8_path (str): The path to the directory where the converted FP8 (OCP E4M3) weights will be saved.
Raises:
KeyError: If a required scale_inv tensor is missing for a weight.
Notes:
- The function assumes that the BF16 weights are stored in safetensor files.
        - The function updates the model index file to add references to scale_inv tensors.
"""
# torch.set_default_dtype(torch.bfloat16)
os.makedirs(fp8_path, exist_ok=True)
model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
with open(model_index_file, "r") as f:
model_index = json.load(f)
weight_map = model_index["weight_map"]
# Cache for loaded safetensor files
loaded_files = {}
bf16_weight_names = []
safetensor_files = list(glob(os.path.join(bf16_path, "*.safetensors")))
safetensor_files.sort()
for safetensor_file in tqdm(safetensor_files):
file_name = os.path.basename(safetensor_file)
current_state_dict = load_file(safetensor_file, device="cuda")
loaded_files[file_name] = current_state_dict
new_state_dict = {}
for weight_name, weight in current_state_dict.items():
            if (
                find_one_ignored(quantize_config.ignore_patterns, weight_name)
                is not None
            ):
                # keep ignored layers (e.g. lm_head, MoE gate) in their original dtype
                new_state_dict[weight_name] = weight
                continue
elif weight.element_size() == 2: # BF16 weight
if (
ref_weights_scale_inv_map is not None
and ref_weights_scale_inv_map.get(weight_name, None) is None
):
print(f"skipping {weight_name} ...")
continue
pass
scale_inv_name = f"{weight_name}_scale_inv"
bf16_weight_names.append(weight_name)
fp8_weight, scale_inv = fp8_weight_block_wise_quant(weight)
new_state_dict[weight_name] = fp8_weight
new_state_dict[scale_inv_name] = scale_inv
else:
new_state_dict[weight_name] = weight
pass
new_safetensor_file = os.path.join(fp8_path, file_name)
save_file(new_state_dict, new_safetensor_file)
# Memory management: keep only the 2 most recently used files
if len(loaded_files) > 2:
oldest_file = next(iter(loaded_files))
del loaded_files[oldest_file]
torch.cuda.empty_cache()
# Update model index
new_model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
for weight_name in bf16_weight_names:
scale_inv_name = f"{weight_name}_scale_inv"
        if scale_inv_name not in weight_map:
            # register the new scale_inv tensor in the same shard as its weight
            weight_map[scale_inv_name] = weight_map[weight_name]
pass
with open(new_model_index_file, "w") as f:
json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)
pass
def update_quant_model_config(bf16_cast_fp8_path):
cfg = AutoConfig.from_pretrained(bf16_cast_fp8_path)
static_q_dict = {
"quantization_config": {
"activation_scheme": quantize_config.activation_scheme,
"fmt": "e4m3",
"quant_method": "fp8",
"weight_block_size": [128, 128],
"ignored_layers": quantize_config.re_ignore_patterns,
}
}
cfg.update(static_q_dict)
cfg.to_json_file(os.path.join(bf16_cast_fp8_path, "config.json.bak"))
pass
def read_weight_inv_list(fp8_path):
model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
with open(model_index_file, "r") as f:
model_index = json.load(f)
weight_map = model_index["weight_map"]
weights_with_scale_inv_map = {}
loaded_files = {}
fp8_weight_names = []
safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
safetensor_files.sort()
for safetensor_file in tqdm(safetensor_files):
file_name = os.path.basename(safetensor_file)
current_state_dict = load_file(safetensor_file, device="cuda")
loaded_files[file_name] = current_state_dict
new_state_dict = {}
for weight_name, weight in current_state_dict.items():
if weight_name.endswith("_scale_inv"):
continue
elif weight.element_size() == 1: # FP8 weight:
scale_inv_name = f"{weight_name}_scale_inv"
try:
# Get scale_inv from the correct file
scale_inv = has_tensor(
weight_map, loaded_files, fp8_path, scale_inv_name
)
fp8_weight_names.append(weight_name)
weights_with_scale_inv_map[weight_name] = weight_map[scale_inv_name]
except KeyError:
print(
f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion"
)
new_state_dict[weight_name] = weight
pass
pass
# Memory management: keep only the 2 most recently used files
if len(loaded_files) > 2:
oldest_file = next(iter(loaded_files))
del loaded_files[oldest_file]
torch.cuda.empty_cache()
pass
weights_with_scale_inv = os.path.join(
fp8_path, "weight_with_scale_inv_map.index.json"
)
with open(weights_with_scale_inv, "w") as f:
json.dump(
{"metadata": {}, "weight_with_scale_inv_map": weights_with_scale_inv_map},
f,
indent=2,
)
pass
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--input-bf16-hf-path", type=str, required=False)
parser.add_argument("--output-fp8-hf-path", type=str, required=False)
parser.add_argument("--input-fp8-hf-path", type=str, required=False)
parser.add_argument("--input-new-fp8-hf-path", type=str, required=False)
args = parser.parse_args()
if (
args.input_fp8_hf_path is not None
and args.input_bf16_hf_path is None
and args.output_fp8_hf_path is None
):
read_weight_inv_list(args.input_fp8_hf_path)
elif args.input_new_fp8_hf_path is not None:
update_quant_model_config(args.input_new_fp8_hf_path)
pass
else:
assert (
args.input_bf16_hf_path is not None and args.output_fp8_hf_path is not None
)
        weight_with_scale_inv_map = None
        if args.input_fp8_hf_path is not None:
weights_with_scale_inv = os.path.join(
args.input_fp8_hf_path, "weight_with_scale_inv_map.index.json"
)
with open(weights_with_scale_inv, "r") as f:
model_index = json.load(f)
pass
weight_with_scale_inv_map = model_index["weight_with_scale_inv_map"]
pass
main(
args.input_bf16_hf_path,
args.output_fp8_hf_path,
ref_weights_scale_inv_map=weight_with_scale_inv_map,
)
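To make the three argparse branches above concrete, here is a hedged programmatic sketch of the same workflow (the paths are placeholders, not from the commit):

# 1) derive the weight -> scale_inv map from an existing FP8 reference checkpoint
read_weight_inv_list("/path/to/reference-fp8")

# 2) block-wise quantize the BF16 checkpoint, reusing that map as a filter
with open("/path/to/reference-fp8/weight_with_scale_inv_map.index.json") as f:
    ref_map = json.load(f)["weight_with_scale_inv_map"]
main("/path/to/model-bf16", "/path/to/model-fp8", ref_weights_scale_inv_map=ref_map)

# 3) record the quantization settings in the converted checkpoint's config
update_quant_model_config("/path/to/model-fp8")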

@@ -5,9 +5,74 @@ import triton
import triton.language as tl
from triton import Config
OCP_FP8E4M3_MAX = 448.0
FP8_MAX = OCP_FP8E4M3_MAX
@triton.jit
def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
def fp8_weight_block_wise_quant_kernel(
x_ptr,
y_ptr,
s_ptr,
M,
N,
SCALED_BLOCK_SIZE_M: tl.constexpr,
SCALED_BLOCK_SIZE_N: tl.constexpr,
FP8_MAX: tl.constexpr,
):
pid_m = tl.program_id(axis=0)
pid_n = tl.program_id(axis=1)
    n = tl.cdiv(N, SCALED_BLOCK_SIZE_N)  # number of scale blocks along N
offs_m = pid_m * SCALED_BLOCK_SIZE_M + tl.arange(0, SCALED_BLOCK_SIZE_M)
offs_n = pid_n * SCALED_BLOCK_SIZE_N + tl.arange(0, SCALED_BLOCK_SIZE_N)
offs = offs_m[:, None] * N + offs_n[None, :]
mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    x = tl.load(x_ptr + offs, mask=mask, other=0.0).to(tl.float32)
s = tl.max(tl.abs(x)) / FP8_MAX
y = x / s
y = y.to(y_ptr.dtype.element_ty)
tl.store(y_ptr + offs, y, mask=mask)
tl.store(s_ptr + pid_m * n + pid_n, s)
pass
def fp8_weight_block_wise_quant(
    x: torch.Tensor, scaled_block_size_m: int = 128, scaled_block_size_n: int = 128
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Block-wise quantizes a 2D weight tensor to FP8 (e4m3); returns (fp8_weight, scale_inv)."""
assert x.is_contiguous(), "Input tensor must be contiguous"
assert x.dim() == 2, "Input tensor must have 2 dimensions"
# assert x.size(0) % scaled_block_size_m == 0 and x.size(1) % scaled_block_size_n == 0, \
# f"Dimensions of x must be divisible by scaled block_size (scale_block_size_m={scaled_block_size_m}x{scaled_block_size_n})"
M, N = x.size()
y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
s = x.new_empty(
triton.cdiv(M, scaled_block_size_m),
triton.cdiv(N, scaled_block_size_n),
dtype=torch.float32,
)
grid = lambda meta: (
triton.cdiv(M, meta["SCALED_BLOCK_SIZE_M"]),
triton.cdiv(N, meta["SCALED_BLOCK_SIZE_N"]),
)
fp8_weight_block_wise_quant_kernel[grid](
x,
y,
s,
M,
N,
SCALED_BLOCK_SIZE_M=scaled_block_size_m,
SCALED_BLOCK_SIZE_N=scaled_block_size_n,
FP8_MAX=FP8_MAX,
)
return y, s
pass
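# Quick round-trip sanity check (a sketch; assumes a CUDA device and uses
# weight_dequant defined below; shapes and tolerance are illustrative):
#
#   w = torch.randn(512, 1024, dtype=torch.bfloat16, device="cuda")
#   w_fp8, scale_inv = fp8_weight_block_wise_quant(w)  # fp8 weight + (4, 8) fp32 scales
#   w_rec = weight_dequant(w_fp8, scale_inv)            # dequantize with matching 128x128 blocks
#   rel_err = (w_rec - w.float()).abs().mean() / w.float().abs().mean()  # expect a few percent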
@triton.jit
def act_quant_kernel(
x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr, FP8_MAX: tl.constexpr
):
"""
Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`.
@@ -23,14 +88,16 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(axis=0)
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
x = tl.load(x_ptr + offs).to(tl.float32)
s = tl.max(tl.abs(x)) / 448.
s = tl.max(tl.abs(x)) / FP8_MAX
y = x / s
y = y.to(y_ptr.dtype.element_ty)
tl.store(y_ptr + offs, y)
tl.store(s_ptr + pid, s)
def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
def act_quant(
x: torch.Tensor, block_size: int = 128
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Quantizes the input tensor `x` using block-wise quantization.
@@ -43,12 +110,14 @@ def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, tor
- The quantized tensor with dtype `torch.float8_e4m3fn`.
- A tensor of scaling factors with dtype `torch.float32`.
"""
assert x.is_contiguous(), 'Input tensor must be contiguous'
assert x.size(-1) % block_size == 0, f'Last dimension size must be divisible by block_size (block_size={block_size})'
assert x.is_contiguous(), "Input tensor must be contiguous"
assert (
x.size(-1) % block_size == 0
), f"Last dimension size must be divisible by block_size (block_size={block_size})"
y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)
grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), )
act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size)
grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),)
act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size, FP8_MAX=FP8_MAX)
return y, s
@@ -81,7 +150,9 @@ def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
tl.store(y_ptr + offs, y, mask=mask)
def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
def weight_dequant(
x: torch.Tensor, s: torch.Tensor, block_size: int = 128
) -> torch.Tensor:
"""
Dequantizes the given weight tensor using the provided scale tensor.
@@ -96,28 +167,45 @@ def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> t
Raises:
AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2.
"""
assert x.is_contiguous() and s.is_contiguous(), 'Input tensors must be contiguous'
assert x.dim() == 2 and s.dim() == 2, 'Input tensors must have 2 dimensions'
assert x.is_contiguous() and s.is_contiguous(), "Input tensors must be contiguous"
assert x.dim() == 2 and s.dim() == 2, "Input tensors must have 2 dimensions"
M, N = x.size()
y = torch.empty_like(x, dtype=torch.get_default_dtype())
grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
grid = lambda meta: (
triton.cdiv(M, meta["BLOCK_SIZE"]),
triton.cdiv(N, meta["BLOCK_SIZE"]),
)
weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
return y
fp8_gemm_configs = [
Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128}, num_stages=num_stages, num_warps=8)
for block_m in [16, 32, 64] for block_n in [32, 64, 128] for num_stages in [3, 4, 5, 6]
Config(
{"BLOCK_SIZE_M": block_m, "BLOCK_SIZE_N": block_n, "BLOCK_SIZE_K": 128},
num_stages=num_stages,
num_warps=8,
)
for block_m in [16, 32, 64]
for block_n in [32, 64, 128]
for num_stages in [3, 4, 5, 6]
]
@triton.autotune(configs=fp8_gemm_configs, key=['N', 'K'])
@triton.autotune(configs=fp8_gemm_configs, key=["N", "K"])
@triton.jit
def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
a_s_ptr, b_s_ptr,
M, N: tl.constexpr, K: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr):
def fp8_gemm_kernel(
a_ptr,
b_ptr,
c_ptr,
a_s_ptr,
b_s_ptr,
M,
N: tl.constexpr,
K: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
):
"""
Performs a matrix multiplication operation on FP8 matrices with scaling factors.
@@ -180,12 +268,17 @@ def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Ten
Returns:
torch.Tensor: The result of the matrix multiplication.
"""
assert a.is_contiguous() and b.is_contiguous(), 'Input tensors must be contiguous'
assert a_s.is_contiguous() and b_s.is_contiguous(), 'Scaling factor tensors must be contiguous'
assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous"
assert (
a_s.is_contiguous() and b_s.is_contiguous()
), "Scaling factor tensors must be contiguous"
K = a.size(-1)
M = a.numel() // K
N = b.size(0)
c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N']))
grid = lambda META: (
triton.cdiv(M, META["BLOCK_SIZE_M"]),
triton.cdiv(N, META["BLOCK_SIZE_N"]),
)
fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)
return c
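Taken together with the new block-wise weight quantizer, a hedged end-to-end sketch of an FP8 GEMM built from the kernels above (shapes and the CUDA device are assumptions; K must be a multiple of the 128-element activation block):

M, K, N = 64, 4096, 2048
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")   # activations
w = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")   # weight, stored as (N, K)

x_fp8, x_scale = act_quant(x)                        # per-128-element activation scales
w_fp8, w_scale_inv = fp8_weight_block_wise_quant(w)  # per-128x128-block weight scales

y = fp8_gemm(x_fp8, x_scale, w_fp8, w_scale_inv)     # (M, N) in the default dtype
ref = x.float() @ w.float().t()
rel_err = (y - ref).abs().mean() / ref.abs().mean()  # expect roughly a few percent from FP8 rounding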