commit 94590b9924
Yiakwy 2025-07-01 12:04:58 +00:00, committed by GitHub
5 changed files with 494 additions and 23 deletions

inference/README.md
@ -0,0 +1,73 @@
# DeepSeek-V3 Weight File Documentation
## BF16 SFT to DeepSeek block-scale weight quantization for inference
The vLLM community project `vllm-project/llm-compressor` is working on integrating block-scale quantization support ([PR#1475](https://github.com/vllm-project/llm-compressor/issues/1475)). vLLM and SGLang already support the DeepSeek FP8 dynamic (activation) quantization format.
The DeepSeek weight format ([README_WEIGHTS.md](../README_WEIGHTS.md)) requires weights to be quantized statically with a block-wise max-reduction kernel, and we found it handy to add BF16 quantization support on top of Hugging Face safetensors.
To produce an FP8 (OCP E4M3 format) quantized model, unlike Mixtral, we need to ignore `lm_head` while casting the up, down, and gate projection weights of all 61 MoE layers to FP8:
```
from auto_fp8 import BaseQuantizeConfig

quantize_config = BaseQuantizeConfig(
    quant_method="fp8",
    activation_scheme="dynamic",
    ignore_patterns=[".*lm_head"],
)
```
The SGLang and vLLM projects already take care of dynamic activation quantization, which for activations is essentially equivalent to 128-group (or 4 x E8M0 32-group) quantization. Block-scale quantization of weights, however, also requires tiling the inputs and outputs along the non-K dimensions.
For an MxN weight we produce `ceil(M / BLOCK_SIZE_M) x ceil(N / BLOCK_SIZE_N)` blocks during compute, each with an inverse scale that can later be persisted and reused by the kernel. This significantly reduces the amount of scale metadata that has to be stored alongside the weights.
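For reference, the per-block computation can be sketched in plain PyTorch. This is only a minimal illustration of the block-wise max reduction with a hypothetical helper name (`block_quant_reference`); the actual conversion uses the Triton kernel `fp8_weight_block_wise_quant` in [kernel.py](./kernel.py):
```
import torch
import torch.nn.functional as F

def block_quant_reference(w: torch.Tensor, block: int = 128):
    """Reference block-wise FP8 quantization: one scale per (block x block) tile."""
    M, N = w.shape
    # Pad to a multiple of the block size so every tile is full.
    w_pad = F.pad(w.float(), (0, (-N) % block, 0, (-M) % block))
    Mb, Nb = w_pad.shape[0] // block, w_pad.shape[1] // block
    tiles = w_pad.view(Mb, block, Nb, block)
    # Per-tile max reduction, divided by the OCP E4M3 max representable value.
    scale_inv = tiles.abs().amax(dim=(1, 3)) / 448.0
    q = w_pad / scale_inv.repeat_interleave(block, 0).repeat_interleave(block, 1)
    # Dequantization multiplies each block back by its scale_inv.
    return q[:M, :N].to(torch.float8_e4m3fn), scale_inv
```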
#### Step by step
###### 1. Perform quantization
```
python bf16_cast_fp8.py \
--input-bf16-hf-path "${DIST_FS_ROOT}/DeepSeek-sft-bf16/" \
--output-fp8-hf-path "${DIST_FS_ROOT}/DeepSeek-sft-bf16-FP8E4M3_block128x128-fp8-gate" \
--input-fp8-hf-path "${DIST_FS_ROOT}/DeepSeek-V3-0324"
```
`--input-fp8-hf-path` points at the original DeepSeek-V3 repository and is used to fetch the reference weight `scale_inv` map. That map can be generated by
```
python bf16_cast_fp8.py \
--input-fp8-hf-path "${DIST_FS_ROOT}/DeepSeek-V3-0324"
```
The script writes the FP8 safetensors into `${DIST_FS_ROOT}/DeepSeek-sft-bf16-FP8E4M3_block128x128-fp8-gate` together with an updated `${DIST_FS_ROOT}/DeepSeek-sft-bf16-FP8E4M3_block128x128-fp8-gate/model.safetensors.index.json`.
###### 2. Copy the following configs from your bf16 checkpoint to the folder
```
DEST=${DIST_FS_ROOT}/DeepSeek-sft-bf16-FP8E4M3_block128x128-fp8-gate
cp $BF16_CHECK_POINT/config.json ${DEST}/
cp $BF16_CHECK_POINT/configuration_deepseek.py ${DEST}/
cp $BF16_CHECK_POINT/modeling_deepseek.py ${DEST}/
cp $BF16_CHECK_POINT/tokenizer.json ${DEST}/
cp $BF16_CHECK_POINT/tokenizer_config.json ${DEST}/
```
Make sure the following dict has been added to `config.json`:
```
"quantization_config": {
"activation_scheme": "dynamic",
"fmt": "e4m3",
"quant_method": "fp8",
"weight_block_size": [128, 128],
"ignored_layers": [".*lm_head"],
}
```
We will add a simple class to automate this process later.
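Until then, a minimal sketch of that automation (a hypothetical helper, not part of this repo, that writes the same `quantization_config` dict into an existing `config.json`; [bf16_cast_fp8.py](./bf16_cast_fp8.py) does this via `update_quant_model_config`):
```
import json
import os

def add_quantization_config(checkpoint_dir: str) -> None:
    """Insert the static quantization_config block into an existing config.json."""
    cfg_path = os.path.join(checkpoint_dir, "config.json")
    with open(cfg_path) as f:
        cfg = json.load(f)
    cfg["quantization_config"] = {
        "activation_scheme": "dynamic",
        "fmt": "e4m3",
        "quant_method": "fp8",
        "weight_block_size": [128, 128],
        "ignored_layers": [".*lm_head"],
    }
    with open(cfg_path, "w") as f:
        json.dump(cfg, f, indent=2)
```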
## BF16 upgrade training/inference
This was originally created for inference on non-FP8-capable chips. See [fp8_cast_bf16.py](./fp8_cast_bf16.py) for details.

inference/auto_fp8/__init__.py

@ -0,0 +1,5 @@
from .config import BaseQuantizeConfig
__all__ = [
"BaseQuantizeConfig",
]

inference/auto_fp8/config.py

@ -0,0 +1,47 @@
# Adapted from AutoFP8 (deprecated project)
import re
from typing import List, Optional, Tuple
class BaseQuantizeConfig:
"""Configuration for model quantization.
Args:
quant_method: Type/precision of quantization method to use.
At the moment, this is just "fp8" which specifically means
the fp8_e4m3 format in pytorch.
activation_scheme: Choice of either "dynamic" or "static" quantization
of activations. If "static", then calibration samples are required
during quantization to produce accurate per-tensor scales for
activations of Linear modules.
ignore_patterns: List of patterns used to ignore layers. If a string
starts with "re:", then everything afterwards is used as python
regex style matching i.e. re.search(), for each Linear layer.
By default, "re:.*lm_head" is included to ignore the embedding
Linear layer usually at the end of decoder LLMs
kv_cache_quant_targets: Tuple of Linear module names to target for
calibration of the output scales for KV cache quantization.
Usually, these should be `("k_proj", "v_proj")`.
"""
def __init__(
self,
quant_method: str = "fp8",
activation_scheme: str = "static",
ignore_patterns: List[str] = ["re:.*lm_head"],
kv_cache_quant_targets: Optional[Tuple[str]] = None,
):
if quant_method != "fp8":
raise ValueError("Only FP8 quantization is supported.")
if activation_scheme not in ["static", "dynamic"]:
raise ValueError(
"Invalid activation_scheme. Choose either 'static' or 'dynamic'."
)
self.quant_method = quant_method
self.activation_scheme = activation_scheme
self.re_ignore_patterns = ignore_patterns
        # Strip the optional "re:" prefix (see docstring) before compiling as a regex.
        self.ignore_patterns = [
            re.compile(pat[len("re:"):] if pat.startswith("re:") else pat, re.VERBOSE)
            for pat in ignore_patterns
        ]
self.kv_cache_quant_targets = kv_cache_quant_targets
self.ignored_layers = []
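
if __name__ == "__main__":
    # Minimal usage sketch (not part of the original AutoFP8 API): compile the
    # ignore patterns and check which layer names would be skipped.
    cfg = BaseQuantizeConfig(
        quant_method="fp8",
        activation_scheme="dynamic",
        ignore_patterns=[".*lm_head"],
    )
    for name in ("lm_head.weight", "model.layers.0.mlp.gate_proj.weight"):
        ignored = any(p.search(name) for p in cfg.ignore_patterns)
        print(f"{name}: {'ignored' if ignored else 'quantized'}")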

inference/bf16_cast_fp8.py

@ -0,0 +1,253 @@
import json
import os
import re
from argparse import ArgumentParser
from glob import glob
import torch
from auto_fp8 import BaseQuantizeConfig
from kernel import fp8_weight_block_wise_quant
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from transformers import AutoConfig
# Helper function to get tensor from the correct file
def has_tensor(weight_map, loaded_files, fp8_path, tensor_name):
"""
Retrieves a tensor from the cached safetensor files or loads it from disk if not cached.
    Args:
        weight_map (dict): Mapping from tensor name to the safetensor file containing it.
        loaded_files (dict): Cache of already-loaded safetensor state dicts, keyed by file name.
        fp8_path (str): Directory containing the safetensor files.
        tensor_name (str): The name of the tensor to retrieve.
Returns:
torch.Tensor: The retrieved tensor.
Raises:
KeyError: If the tensor does not exist in the safetensor file.
"""
file_name = weight_map[tensor_name]
if file_name not in loaded_files:
file_path = os.path.join(fp8_path, file_name)
loaded_files[file_name] = load_file(file_path, device="cuda")
return loaded_files[file_name][tensor_name]
def find_ignored(regex_pat, weight_name):
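    """Return `weight_name` if it matches `regex_pat` (via re.search), otherwise None."""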
searched = regex_pat.search(weight_name)
if searched is not None:
# print(f"find : {searched.string}")
return searched.string
return None
def find_one_ignored(regex_pat_list, weight_name):
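    """Return `weight_name` if it matches any pattern in `regex_pat_list`, otherwise None."""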
for regex_pat in regex_pat_list:
searched = find_ignored(regex_pat, weight_name)
if searched is not None:
return searched
return None
quantize_config = BaseQuantizeConfig(
quant_method="fp8",
activation_scheme="dynamic",
ignore_patterns=[".*lm_head"],
)
def main(bf16_path, fp8_path, ref_weights_scale_inv_map=None):
"""
    Quantizes BF16 weights to FP8 (OCP E4M3) and saves the converted weights.
This function reads BF16 weights from the specified directory, converts them to FP8 (OCP E4M3),
and saves the converted weights to another specified directory. It also updates the
model index file to reflect the changes.
    Args:
        bf16_path (str): The path to the directory containing the BF16 weights and model index file.
        fp8_path (str): The path to the directory where the converted FP8 (OCP E4M3) weights will be saved.
        ref_weights_scale_inv_map (dict, optional): Mapping of weight names to the files holding their
            scale_inv tensors in a reference FP8 checkpoint; weights absent from this map are not quantized.
Raises:
KeyError: If a required scale_inv tensor is missing for a weight.
Notes:
- The function assumes that the BF16 weights are stored in safetensor files.
        - The function updates the model index file to add references to scale_inv tensors.
"""
# torch.set_default_dtype(torch.bfloat16)
os.makedirs(fp8_path, exist_ok=True)
model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
with open(model_index_file, "r") as f:
model_index = json.load(f)
weight_map = model_index["weight_map"]
# Cache for loaded safetensor files
loaded_files = {}
bf16_weight_names = []
bf16_weight_scale_inv = {}
safetensor_files = list(glob(os.path.join(bf16_path, "*.safetensors")))
safetensor_files.sort()
for safetensor_file in tqdm(safetensor_files):
file_name = os.path.basename(safetensor_file)
current_state_dict = load_file(safetensor_file, device="cuda")
loaded_files[file_name] = current_state_dict
new_state_dict = {}
for weight_name, weight in current_state_dict.items():
if (
find_one_ignored(quantize_config.ignore_patterns, weight_name)
is not None
):
# print(f"skipping {weight_name} dtype={weight.dtype}...")
new_state_dict[weight_name] = weight
continue
elif weight.element_size() >= 2: # BF16 / Float weight
if (
ref_weights_scale_inv_map is not None
and ref_weights_scale_inv_map.get(weight_name, None) is None
):
# print(f"skipping {weight_name} dtype={weight.dtype}...")
new_state_dict[weight_name] = weight
continue
scale_inv_name = f"{weight_name}_scale_inv"
bf16_weight_names.append(weight_name)
bf16_weight_scale_inv[scale_inv_name] = file_name
fp8_weight, scale_inv = fp8_weight_block_wise_quant(weight)
new_state_dict[weight_name] = fp8_weight
new_state_dict[scale_inv_name] = scale_inv
else:
# print(f"skipping {weight_name} dtype={weight.dtype} ...")
new_state_dict[weight_name] = weight
new_safetensor_file = os.path.join(fp8_path, file_name)
save_file(new_state_dict, new_safetensor_file)
# Memory management: keep only the 2 most recently used files
if len(loaded_files) > 2:
oldest_file = next(iter(loaded_files))
del loaded_files[oldest_file]
torch.cuda.empty_cache()
# Update model index
new_model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
# TODO (yiakwy) : rewrite with dict.update
for weight_name in bf16_weight_names:
scale_inv_name = f"{weight_name}_scale_inv"
weight_map[scale_inv_name] = bf16_weight_scale_inv[scale_inv_name]
with open(new_model_index_file, "w") as f:
json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)
# NOTE (yiakwy) : the huggingface library adds some parameters that differ from DeepSeek V3; we will modify this later.
# For now we recommend updating config.json manually.
def update_quant_model_config(bf16_cast_fp8_path):
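    """Injects the static `quantization_config` block into the config.json of a converted FP8 checkpoint."""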
cfg = AutoConfig.from_pretrained(bf16_cast_fp8_path)
static_q_dict = {
"quantization_config": {
"activation_scheme": quantize_config.activation_scheme,
"fmt": "e4m3",
"quant_method": "fp8",
"weight_block_size": [128, 128],
"ignored_layers": quantize_config.re_ignore_patterns,
}
}
cfg.update(static_q_dict)
cfg.to_json_file(os.path.join(bf16_cast_fp8_path, "config.json"))
def read_weight_inv_list(fp8_path):
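    """
    Scans an existing FP8 checkpoint and writes `weight_with_scale_inv_map.index.json`,
    mapping each FP8 weight name to the safetensor file that holds its `scale_inv` tensor.
    """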
model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
with open(model_index_file, "r") as f:
model_index = json.load(f)
weight_map = model_index["weight_map"]
weights_with_scale_inv_map = {}
loaded_files = {}
fp8_weight_names = []
safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
safetensor_files.sort()
for safetensor_file in tqdm(safetensor_files):
file_name = os.path.basename(safetensor_file)
current_state_dict = load_file(safetensor_file, device="cuda")
loaded_files[file_name] = current_state_dict
new_state_dict = {}
for weight_name, weight in current_state_dict.items():
if weight_name.endswith("_scale_inv"):
continue
elif weight.element_size() == 1: # FP8 weight:
scale_inv_name = f"{weight_name}_scale_inv"
try:
# Get scale_inv from the correct file
scale_inv = has_tensor(
weight_map, loaded_files, fp8_path, scale_inv_name
)
fp8_weight_names.append(weight_name)
weights_with_scale_inv_map[weight_name] = weight_map[scale_inv_name]
except KeyError:
print(
f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion"
)
new_state_dict[weight_name] = weight
# Memory management: keep only the 2 most recently used files
if len(loaded_files) > 2:
oldest_file = next(iter(loaded_files))
del loaded_files[oldest_file]
torch.cuda.empty_cache()
weights_with_scale_inv = os.path.join(
fp8_path, "weight_with_scale_inv_map.index.json"
)
with open(weights_with_scale_inv, "w") as f:
json.dump(
{"metadata": {}, "weight_with_scale_inv_map": weights_with_scale_inv_map},
f,
indent=2,
)
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--input-bf16-hf-path", type=str, required=False)
parser.add_argument("--output-fp8-hf-path", type=str, required=False)
parser.add_argument("--input-fp8-hf-path", type=str, required=False)
parser.add_argument("--input-new-fp8-hf-path", type=str, required=False)
args = parser.parse_args()
if (
args.input_fp8_hf_path is not None
and args.input_bf16_hf_path is None
and args.output_fp8_hf_path is None
):
read_weight_inv_list(args.input_fp8_hf_path)
elif args.input_new_fp8_hf_path is not None:
update_quant_model_config(args.input_new_fp8_hf_path)
else:
assert (
args.input_bf16_hf_path is not None and args.output_fp8_hf_path is not None
)
if args.input_fp8_hf_path is not None:
weights_with_scale_inv = os.path.join(
args.input_fp8_hf_path, "weight_with_scale_inv_map.index.json"
)
with open(weights_with_scale_inv, "r") as f:
model_index = json.load(f)
weight_with_scale_inv_map = model_index["weight_with_scale_inv_map"]
main(
args.input_bf16_hf_path,
args.output_fp8_hf_path,
ref_weights_scale_inv_map=weight_with_scale_inv_map,
)

inference/kernel.py

@ -5,9 +5,74 @@ import triton
import triton.language as tl
from triton import Config
OCP_FP8E4M3_MAX = 448.0
FP8_MAX = OCP_FP8E4M3_MAX
@triton.jit
def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
def fp8_weight_block_wise_quant_kernel(
x_ptr,
y_ptr,
s_ptr,
M,
N,
SCALED_BLOCK_SIZE_M: tl.constexpr,
SCALED_BLOCK_SIZE_N: tl.constexpr,
FP8_MAX: tl.constexpr,
):
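    """
    Quantizes one (SCALED_BLOCK_SIZE_M x SCALED_BLOCK_SIZE_N) tile of `x_ptr` to FP8 in `y_ptr`
    and stores the per-tile scale, max(|x|) / FP8_MAX, at the tile's position in `s_ptr`.
    """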
pid_m = tl.program_id(axis=0)
pid_n = tl.program_id(axis=1)
    n = tl.cdiv(N, SCALED_BLOCK_SIZE_N)
offs_m = pid_m * SCALED_BLOCK_SIZE_M + tl.arange(0, SCALED_BLOCK_SIZE_M)
offs_n = pid_n * SCALED_BLOCK_SIZE_N + tl.arange(0, SCALED_BLOCK_SIZE_N)
offs = offs_m[:, None] * N + offs_n[None, :]
mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
s = tl.max(tl.abs(x)) / FP8_MAX
y = x / s
y = y.to(y_ptr.dtype.element_ty)
tl.store(y_ptr + offs, y, mask=mask)
tl.store(s_ptr + pid_m * n + pid_n, s)
def fp8_weight_block_wise_quant(
x: torch.Tensor, scaled_block_size_m: int = 128, scaled_block_size_n: int = 128
):
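    """
    Block-wise quantization of a 2-D weight tensor to FP8 (OCP E4M3).
    Returns the FP8 tensor and a float32 tensor of per-block scales (`scale_inv`) with shape
    (ceil(M / scaled_block_size_m), ceil(N / scaled_block_size_n)).
    """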
assert x.is_contiguous(), "Input tensor must be contiguous"
assert x.dim() == 2, "Input tensor must have 2 dimensions"
# assert x.size(0) % scaled_block_size_m == 0 and x.size(1) % scaled_block_size_n == 0, \
# f"Dimensions of x must be divisible by scaled block_size (scale_block_size_m={scaled_block_size_m}x{scaled_block_size_n})"
M, N = x.size()
y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
s = x.new_empty(
triton.cdiv(M, scaled_block_size_m),
triton.cdiv(N, scaled_block_size_n),
dtype=torch.float32,
)
grid = lambda meta: (
triton.cdiv(M, meta["SCALED_BLOCK_SIZE_M"]),
triton.cdiv(N, meta["SCALED_BLOCK_SIZE_N"]),
)
fp8_weight_block_wise_quant_kernel[grid](
x,
y,
s,
M,
N,
SCALED_BLOCK_SIZE_M=scaled_block_size_m,
SCALED_BLOCK_SIZE_N=scaled_block_size_n,
FP8_MAX=FP8_MAX,
)
return y, s
@triton.jit
def act_quant_kernel(
x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr, FP8_MAX: tl.constexpr
):
"""
Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`.
@ -23,14 +88,16 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(axis=0)
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
x = tl.load(x_ptr + offs).to(tl.float32)
s = tl.max(tl.abs(x)) / 448.
s = tl.max(tl.abs(x)) / FP8_MAX
y = x / s
y = y.to(y_ptr.dtype.element_ty)
tl.store(y_ptr + offs, y)
tl.store(s_ptr + pid, s)
def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
def act_quant(
x: torch.Tensor, block_size: int = 128
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Quantizes the input tensor `x` using block-wise quantization.
@ -43,12 +110,14 @@ def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, tor
- The quantized tensor with dtype `torch.float8_e4m3fn`.
- A tensor of scaling factors with dtype `torch.float32`.
"""
assert x.is_contiguous(), 'Input tensor must be contiguous'
assert x.size(-1) % block_size == 0, f'Last dimension size must be divisible by block_size (block_size={block_size})'
assert x.is_contiguous(), "Input tensor must be contiguous"
assert (
x.size(-1) % block_size == 0
), f"Last dimension size must be divisible by block_size (block_size={block_size})"
y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)
grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), )
act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size)
grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),)
act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size, FP8_MAX=FP8_MAX)
return y, s
@ -81,7 +150,9 @@ def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
tl.store(y_ptr + offs, y, mask=mask)
def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
def weight_dequant(
x: torch.Tensor, s: torch.Tensor, block_size: int = 128
) -> torch.Tensor:
"""
Dequantizes the given weight tensor using the provided scale tensor.
@ -96,28 +167,45 @@ def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> t
Raises:
AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2.
"""
assert x.is_contiguous() and s.is_contiguous(), 'Input tensors must be contiguous'
assert x.dim() == 2 and s.dim() == 2, 'Input tensors must have 2 dimensions'
assert x.is_contiguous() and s.is_contiguous(), "Input tensors must be contiguous"
assert x.dim() == 2 and s.dim() == 2, "Input tensors must have 2 dimensions"
M, N = x.size()
y = torch.empty_like(x, dtype=torch.get_default_dtype())
grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
grid = lambda meta: (
triton.cdiv(M, meta["BLOCK_SIZE"]),
triton.cdiv(N, meta["BLOCK_SIZE"]),
)
weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
return y
fp8_gemm_configs = [
Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128}, num_stages=num_stages, num_warps=8)
for block_m in [16, 32, 64] for block_n in [32, 64, 128] for num_stages in [3, 4, 5, 6]
Config(
{"BLOCK_SIZE_M": block_m, "BLOCK_SIZE_N": block_n, "BLOCK_SIZE_K": 128},
num_stages=num_stages,
num_warps=8,
)
for block_m in [16, 32, 64]
for block_n in [32, 64, 128]
for num_stages in [3, 4, 5, 6]
]
@triton.autotune(configs=fp8_gemm_configs, key=['N', 'K'])
@triton.autotune(configs=fp8_gemm_configs, key=["N", "K"])
@triton.jit
def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
a_s_ptr, b_s_ptr,
M, N: tl.constexpr, K: tl.constexpr,
def fp8_gemm_kernel(
a_ptr,
b_ptr,
c_ptr,
a_s_ptr,
b_s_ptr,
M,
N: tl.constexpr,
K: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr):
BLOCK_SIZE_K: tl.constexpr,
):
"""
Performs a matrix multiplication operation on FP8 matrices with scaling factors.
@ -180,12 +268,17 @@ def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Ten
Returns:
torch.Tensor: The result of the matrix multiplication.
"""
assert a.is_contiguous() and b.is_contiguous(), 'Input tensors must be contiguous'
assert a_s.is_contiguous() and b_s.is_contiguous(), 'Scaling factor tensors must be contiguous'
assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous"
assert (
a_s.is_contiguous() and b_s.is_contiguous()
), "Scaling factor tensors must be contiguous"
K = a.size(-1)
M = a.numel() // K
N = b.size(0)
c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N']))
grid = lambda META: (
triton.cdiv(M, META["BLOCK_SIZE_M"]),
triton.cdiv(N, META["BLOCK_SIZE_N"]),
)
fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)
return c