diff --git a/inference/README.md b/inference/README.md
new file mode 100644
index 0000000..005f8f7
--- /dev/null
+++ b/inference/README.md
@@ -0,0 +1,73 @@
+# DeepSeek-V3 Weight File Documentation
+
+
+## BF16 SFT to DeepSeek block-scale weight quantization for inference
+
+The vLLM community project `vllm-project/llm-compressor` is working on integrating block-scale quantization support ([PR#1475](https://github.com/vllm-project/llm-compressor/issues/1475)). vLLM and SGLang already support the DeepSeek FP8 format with dynamic activation quantization.
+
+The DeepSeek weight layout (see [README_WEIGHTS.md](../README_WEIGHTS.md)) requires weights to be quantized statically with a block-wise max-reduction kernel, so adding BF16-to-FP8 quantization support on top of Hugging Face safetensors is handy.
+
+To produce an FP8 (OCP E4M3 format) quantized model, unlike Mixtral, we need to ignore `lm_head` while casting the MLP/MoE up, down, and gate projection weights of all 61 layers to FP8:
+
+
+```
+quantize_config = BaseQuantizeConfig(
+    quant_method="fp8",
+    activation_scheme="dynamic",
+    ignore_patterns=[".*lm_head"],
+)
+```
+
+SGLang and vLLM already take care of dynamic activation quantization, which for activations is basically equivalent to 128-group (or 4 x E8M0 32-group) quantization. Block-scale quantization of weights, however, additionally requires tiling the inputs and outputs along the non-K dimensions.
+
+For an MxN weight we compute `ceil(M / BLOCK_SIZE_M) x ceil(N / BLOCK_SIZE_N)` blocks and their inverse scales, which can later be persisted and consumed by the kernel. This significantly reduces the number of scale values that need to be stored. A plain-PyTorch reference sketch of the scheme is given at the end of this README.
+
+#### Step by step
+
+###### 1. Perform quantization
+
+```
+python bf16_cast_fp8.py \
+    --input-bf16-hf-path "${DIST_FS_ROOT}/DeepSeek-sft-bf16/" \
+    --output-fp8-hf-path "${DIST_FS_ROOT}/DeepSeek-sft-bf16-FP8E4M3_block128x128-fp8-gate" \
+    --input-fp8-hf-path "${DIST_FS_ROOT}/DeepSeek-V3-0324"
+```
+
+`--input-fp8-hf-path` is used to fetch the map of weights that carry `scale_inv` tensors in the original DeepSeek-V3 repo. That map file can be generated with
+
+```
+python bf16_cast_fp8.py \
+    --input-fp8-hf-path "${DIST_FS_ROOT}/DeepSeek-V3-0324"
+```
+
+The script writes FP8 safetensors into `${DIST_FS_ROOT}/DeepSeek-sft-bf16-FP8E4M3_block128x128-fp8-gate` together with `${DIST_FS_ROOT}/DeepSeek-sft-bf16-FP8E4M3_block128x128-fp8-gate/model.safetensors.index.json`.
+
+###### 2. Copy the following configs from your bf16 checkpoint to the folder
+
+```
+DEST=${DIST_FS_ROOT}/DeepSeek-sft-bf16-FP8E4M3_block128x128-fp8-gate
+
+cp $BF16_CHECK_POINT/config.json ${DEST}/
+cp $BF16_CHECK_POINT/configuration_deepseek.py ${DEST}/
+cp $BF16_CHECK_POINT/modeling_deepseek.py ${DEST}/
+cp $BF16_CHECK_POINT/tokenizer.json ${DEST}/
+cp $BF16_CHECK_POINT/tokenizer_config.json ${DEST}/
+```
+
+Make sure the following dict has been added to `config.json`:
+
+```
+"quantization_config": {
+    "activation_scheme": "dynamic",
+    "fmt": "e4m3",
+    "quant_method": "fp8",
+    "weight_block_size": [128, 128],
+    "ignored_layers": [".*lm_head"]
+}
+```
+
+We will add a simple class to automate this process later.
+
+## BF16 upgrade training/inference
+
+This was originally created for inference on non-FP8-capable chips. See [fp8_cast_bf16.py](./fp8_cast_bf16.py) for details.
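+
+## Appendix: block-wise weight quantization in plain PyTorch (reference only)
+
+The snippet below is a minimal, illustrative sketch of the 128x128 block-scale scheme described above, written in plain PyTorch for readability. It mirrors what `kernel.fp8_weight_block_wise_quant` computes with Triton (one `amax / 448` inverse scale per block, OCP E4M3 output), but it is not the code path used by `bf16_cast_fp8.py`, and the padding/reshape strategy here is just one possible implementation.
+
+```
+import torch
+
+OCP_FP8E4M3_MAX = 448.0
+
+
+def block_quant_reference(w: torch.Tensor, block: int = 128):
+    """Quantize a 2-D weight to FP8 E4M3 with one inverse scale per block x block tile."""
+    M, N = w.shape
+    # Pad so both dimensions are multiples of the block size.
+    pad_m, pad_n = (-M) % block, (-N) % block
+    w_pad = torch.nn.functional.pad(w.float(), (0, pad_n, 0, pad_m))
+    Mb, Nb = w_pad.shape[0] // block, w_pad.shape[1] // block
+    tiles = w_pad.reshape(Mb, block, Nb, block).permute(0, 2, 1, 3)
+    # One inverse scale per tile: amax / FP8_MAX (this is what the `*_scale_inv` tensors hold).
+    # NOTE: like the Triton kernel, this assumes no tile is entirely zero.
+    scale_inv = tiles.abs().amax(dim=(-1, -2)) / OCP_FP8E4M3_MAX
+    q = (tiles / scale_inv[..., None, None]).to(torch.float8_e4m3fn)
+    q = q.permute(0, 2, 1, 3).reshape(Mb * block, Nb * block)[:M, :N]
+    return q, scale_inv  # q: MxN fp8, scale_inv: ceil(M/128) x ceil(N/128) float32
+```
+
+Dequantization is simply the FP8 values multiplied by the corresponding per-tile `scale_inv`, which is what `kernel.weight_dequant` does on the GPU.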
diff --git a/inference/auto_fp8/__init__.py b/inference/auto_fp8/__init__.py
new file mode 100644
index 0000000..a75f6e9
--- /dev/null
+++ b/inference/auto_fp8/__init__.py
@@ -0,0 +1,5 @@
+from .config import BaseQuantizeConfig
+
+__all__ = [
+    "BaseQuantizeConfig",
+]
diff --git a/inference/auto_fp8/config.py b/inference/auto_fp8/config.py
new file mode 100644
index 0000000..b1da624
--- /dev/null
+++ b/inference/auto_fp8/config.py
@@ -0,0 +1,47 @@
+# Adapted from AutoFP8 (deprecated project)
+import re
+from typing import List, Optional, Tuple
+
+
+class BaseQuantizeConfig:
+    """Configuration for model quantization.
+
+    Args:
+        quant_method: Type/precision of quantization method to use.
+            At the moment, this is just "fp8", which specifically means
+            the fp8_e4m3 format in PyTorch.
+        activation_scheme: Choice of either "dynamic" or "static" quantization
+            of activations. If "static", calibration samples are required
+            during quantization to produce accurate per-tensor scales for
+            activations of Linear modules.
+        ignore_patterns: List of Python regex patterns used to ignore layers;
+            each pattern is matched against the weight name with re.search().
+            By default, ".*lm_head" is included to ignore the output
+            projection Linear layer usually found at the end of decoder LLMs.
+        kv_cache_quant_targets: Tuple of Linear module names to target for
+            calibration of the output scales for KV cache quantization.
+            Usually, these should be `("k_proj", "v_proj")`.
+    """
+
+    def __init__(
+        self,
+        quant_method: str = "fp8",
+        activation_scheme: str = "static",
+        ignore_patterns: List[str] = [".*lm_head"],
+        kv_cache_quant_targets: Optional[Tuple[str, ...]] = None,
+    ):
+        if quant_method != "fp8":
+            raise ValueError("Only FP8 quantization is supported.")
+        if activation_scheme not in ["static", "dynamic"]:
+            raise ValueError(
+                "Invalid activation_scheme. Choose either 'static' or 'dynamic'."
+            )
+        self.quant_method = quant_method
+        self.activation_scheme = activation_scheme
+        # Keep both the raw pattern strings and their compiled forms.
+        self.re_ignore_patterns = ignore_patterns
+        self.ignore_patterns = [
+            re.compile(regex_pat) for regex_pat in ignore_patterns
+        ]
+        self.kv_cache_quant_targets = kv_cache_quant_targets
+        self.ignored_layers = []
diff --git a/inference/bf16_cast_fp8.py b/inference/bf16_cast_fp8.py
new file mode 100644
index 0000000..07bf706
--- /dev/null
+++ b/inference/bf16_cast_fp8.py
@@ -0,0 +1,253 @@
+import json
+import os
+import re
+from argparse import ArgumentParser
+from glob import glob
+
+import torch
+from auto_fp8 import BaseQuantizeConfig
+from kernel import fp8_weight_block_wise_quant
+from safetensors.torch import load_file, save_file
+from tqdm import tqdm
+from transformers import AutoConfig
+
+
+# Helper function to get a tensor from the correct file
+def has_tensor(weight_map, loaded_files, fp8_path, tensor_name):
+    """
+    Retrieves a tensor from the cached safetensor files or loads it from disk if not cached.
+
+    Args:
+        weight_map (dict): Mapping from tensor names to safetensor file names.
+        loaded_files (dict): Cache of already loaded safetensor state dicts.
+        fp8_path (str): Directory containing the safetensor files.
+        tensor_name (str): The name of the tensor to retrieve.
+
+    Returns:
+        torch.Tensor: The retrieved tensor.
+
+    Raises:
+        KeyError: If the tensor does not exist in the safetensor file.
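+
+    Example (illustrative only; the path, shard name and tensor name are placeholders):
+        >>> weight_map = {"model.layers.0.mlp.gate_proj.weight_scale_inv": "model-00001-of-000163.safetensors"}
+        >>> loaded_files = {}
+        >>> scale_inv = has_tensor(weight_map, loaded_files, "/path/to/DeepSeek-V3-0324",
+        ...                        "model.layers.0.mlp.gate_proj.weight_scale_inv")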
+    """
+    file_name = weight_map[tensor_name]
+    if file_name not in loaded_files:
+        file_path = os.path.join(fp8_path, file_name)
+        loaded_files[file_name] = load_file(file_path, device="cuda")
+    return loaded_files[file_name][tensor_name]
+
+
+def find_ignored(regex_pat, weight_name):
+    searched = regex_pat.search(weight_name)
+    if searched is not None:
+        # print(f"find : {searched.string}")
+        return searched.string
+    return None
+
+
+def find_one_ignored(regex_pat_list, weight_name):
+    for regex_pat in regex_pat_list:
+        searched = find_ignored(regex_pat, weight_name)
+        if searched is not None:
+            return searched
+    return None
+
+
+quantize_config = BaseQuantizeConfig(
+    quant_method="fp8",
+    activation_scheme="dynamic",
+    ignore_patterns=[".*lm_head"],
+)
+
+
+def main(bf16_path, fp8_path, ref_weights_scale_inv_map=None):
+    """
+    Quantizes BF16 weights to FP8 (OCP E4M3) and saves the converted weights.
+
+    This function reads BF16 weights from the specified directory, converts them to FP8 (OCP E4M3),
+    and saves the converted weights to another specified directory. It also updates the
+    model index file to reflect the changes.
+
+    Args:
+        bf16_path (str): The path to the directory containing the BF16 weights and model index file.
+        fp8_path (str): The path to the directory where the converted FP8 (OCP E4M3) weights will be saved.
+        ref_weights_scale_inv_map (dict, optional): Map of weight names that carry `scale_inv` tensors
+            in the reference FP8 checkpoint; weights not present in this map are left unquantized.
+
+    Notes:
+        - The function assumes that the BF16 weights are stored in safetensor files.
+        - The function updates the model index file to add references to scale_inv tensors.
+    """
+    # torch.set_default_dtype(torch.bfloat16)
+    os.makedirs(fp8_path, exist_ok=True)
+
+    model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
+    with open(model_index_file, "r") as f:
+        model_index = json.load(f)
+    weight_map = model_index["weight_map"]
+
+    # Cache for loaded safetensor files
+    loaded_files = {}
+    bf16_weight_names = []
+    bf16_weight_scale_inv = {}
+
+    safetensor_files = list(glob(os.path.join(bf16_path, "*.safetensors")))
+    safetensor_files.sort()
+    for safetensor_file in tqdm(safetensor_files):
+        file_name = os.path.basename(safetensor_file)
+        current_state_dict = load_file(safetensor_file, device="cuda")
+        loaded_files[file_name] = current_state_dict
+
+        new_state_dict = {}
+        for weight_name, weight in current_state_dict.items():
+            if (
+                find_one_ignored(quantize_config.ignore_patterns, weight_name)
+                is not None
+            ):
+                # print(f"skipping {weight_name} dtype={weight.dtype}...")
+                new_state_dict[weight_name] = weight
+                continue
+            elif weight.element_size() >= 2:  # BF16 / Float weight
+                if (
+                    ref_weights_scale_inv_map is not None
+                    and ref_weights_scale_inv_map.get(weight_name, None) is None
+                ):
+                    # print(f"skipping {weight_name} dtype={weight.dtype}...")
+                    new_state_dict[weight_name] = weight
+                    continue
+
+                scale_inv_name = f"{weight_name}_scale_inv"
+
+                bf16_weight_names.append(weight_name)
+                bf16_weight_scale_inv[scale_inv_name] = file_name
+
+                fp8_weight, scale_inv = fp8_weight_block_wise_quant(weight)
+                new_state_dict[weight_name] = fp8_weight
+                new_state_dict[scale_inv_name] = scale_inv
+            else:
+                # print(f"skipping {weight_name} dtype={weight.dtype} ...")
+                new_state_dict[weight_name] = weight
+
+        new_safetensor_file = os.path.join(fp8_path, file_name)
+        save_file(new_state_dict, new_safetensor_file)
+
+        # Memory management: keep only the 2 most recently used files
+        if len(loaded_files) > 2:
+            oldest_file = next(iter(loaded_files))
+            del loaded_files[oldest_file]
+            torch.cuda.empty_cache()
+
+    # Update model index
+    new_model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
+
+    # TODO (yiakwy) : rewrite with dict.update
+    for weight_name in bf16_weight_names:
+        scale_inv_name = f"{weight_name}_scale_inv"
+        weight_map[scale_inv_name] = bf16_weight_scale_inv[scale_inv_name]
+
+    with open(new_model_index_file, "w") as f:
+        json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)
+
+
+# NOTE (yiakwy) : the huggingface library adds some parameters that differ from DeepSeek V3;
+# we will revisit this later. For now we recommend updating config.json manually.
+def update_quant_model_config(bf16_cast_fp8_path):
+    cfg = AutoConfig.from_pretrained(bf16_cast_fp8_path)
+
+    static_q_dict = {
+        "quantization_config": {
+            "activation_scheme": quantize_config.activation_scheme,
+            "fmt": "e4m3",
+            "quant_method": "fp8",
+            "weight_block_size": [128, 128],
+            "ignored_layers": quantize_config.re_ignore_patterns,
+        }
+    }
+
+    cfg.update(static_q_dict)
+    cfg.to_json_file(os.path.join(bf16_cast_fp8_path, "config.json"))
+
+
+def read_weight_inv_list(fp8_path):
+    model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
+    with open(model_index_file, "r") as f:
+        model_index = json.load(f)
+    weight_map = model_index["weight_map"]
+    weights_with_scale_inv_map = {}
+
+    loaded_files = {}
+    fp8_weight_names = []
+
+    safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
+    safetensor_files.sort()
+    for safetensor_file in tqdm(safetensor_files):
+        file_name = os.path.basename(safetensor_file)
+        current_state_dict = load_file(safetensor_file, device="cuda")
+        loaded_files[file_name] = current_state_dict
+
+        new_state_dict = {}
+        for weight_name, weight in current_state_dict.items():
+            if weight_name.endswith("_scale_inv"):
+                continue
+            elif weight.element_size() == 1:  # FP8 weight
+                scale_inv_name = f"{weight_name}_scale_inv"
+                try:
+                    # Get scale_inv from the correct file
+                    scale_inv = has_tensor(
+                        weight_map, loaded_files, fp8_path, scale_inv_name
+                    )
+                    fp8_weight_names.append(weight_name)
+                    weights_with_scale_inv_map[weight_name] = weight_map[scale_inv_name]
+                except KeyError:
+                    print(
+                        f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion"
+                    )
+                    new_state_dict[weight_name] = weight
+
+        # Memory management: keep only the 2 most recently used files
+        if len(loaded_files) > 2:
+            oldest_file = next(iter(loaded_files))
+            del loaded_files[oldest_file]
+            torch.cuda.empty_cache()
+
+    weights_with_scale_inv = os.path.join(
+        fp8_path, "weight_with_scale_inv_map.index.json"
+    )
+    with open(weights_with_scale_inv, "w") as f:
+        json.dump(
+            {"metadata": {}, "weight_with_scale_inv_map": weights_with_scale_inv_map},
+            f,
+            indent=2,
+        )
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser()
+    parser.add_argument("--input-bf16-hf-path", type=str, required=False)
+    parser.add_argument("--output-fp8-hf-path", type=str, required=False)
+    parser.add_argument("--input-fp8-hf-path", type=str, required=False)
+    parser.add_argument("--input-new-fp8-hf-path", type=str, required=False)
+    args = parser.parse_args()
+
+    if (
+        args.input_fp8_hf_path is not None
+        and args.input_bf16_hf_path is None
+        and args.output_fp8_hf_path is None
+    ):
+        read_weight_inv_list(args.input_fp8_hf_path)
+    elif args.input_new_fp8_hf_path is not None:
+        update_quant_model_config(args.input_new_fp8_hf_path)
+    else:
+        assert (
+            args.input_bf16_hf_path is not None and args.output_fp8_hf_path is not None
+        )
+        if args.input_fp8_hf_path is not None:
+            weights_with_scale_inv = os.path.join(
+                args.input_fp8_hf_path, "weight_with_scale_inv_map.index.json"
+            )
+            with open(weights_with_scale_inv, "r") as f:
+                model_index = json.load(f)
+                weight_with_scale_inv_map = model_index["weight_with_scale_inv_map"]
+        else:
+            # Without a reference FP8 checkpoint, quantize every eligible BF16 weight.
+            weight_with_scale_inv_map = None
+        main(
+            args.input_bf16_hf_path,
+            args.output_fp8_hf_path,
+            ref_weights_scale_inv_map=weight_with_scale_inv_map,
+        )
diff --git a/inference/kernel.py b/inference/kernel.py
index ba18dca..68c377c 100644
--- a/inference/kernel.py
+++ b/inference/kernel.py
@@ -5,9 +5,74 @@ import triton
 import triton.language as tl
 from triton import Config
 
+OCP_FP8E4M3_MAX = 448.0
+
+FP8_MAX = OCP_FP8E4M3_MAX
+
 
 @triton.jit
-def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
+def fp8_weight_block_wise_quant_kernel(
+    x_ptr,
+    y_ptr,
+    s_ptr,
+    M,
+    N,
+    SCALED_BLOCK_SIZE_M: tl.constexpr,
+    SCALED_BLOCK_SIZE_N: tl.constexpr,
+    FP8_MAX: tl.constexpr,
+):
+    # Each program quantizes one (SCALED_BLOCK_SIZE_M x SCALED_BLOCK_SIZE_N) tile of x.
+    pid_m = tl.program_id(axis=0)
+    pid_n = tl.program_id(axis=1)
+    n = tl.cdiv(N, SCALED_BLOCK_SIZE_N)  # number of scale blocks along N
+    offs_m = pid_m * SCALED_BLOCK_SIZE_M + tl.arange(0, SCALED_BLOCK_SIZE_M)
+    offs_n = pid_n * SCALED_BLOCK_SIZE_N + tl.arange(0, SCALED_BLOCK_SIZE_N)
+    offs = offs_m[:, None] * N + offs_n[None, :]
+    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
+    x = tl.load(x_ptr + offs, mask=mask, other=0.0).to(tl.float32)
+    s = tl.max(tl.abs(x)) / FP8_MAX
+    y = x / s
+    y = y.to(y_ptr.dtype.element_ty)
+    tl.store(y_ptr + offs, y, mask=mask)
+    tl.store(s_ptr + pid_m * n + pid_n, s)
+
+
+def fp8_weight_block_wise_quant(
+    x: torch.Tensor, scaled_block_size_m: int = 128, scaled_block_size_n: int = 128
+):
+    """
+    Quantizes a 2-D weight tensor to FP8 (E4M3) with one scale per
+    (scaled_block_size_m x scaled_block_size_n) block, returning the quantized
+    weight and the float32 inverse scales.
+    """
+    assert x.is_contiguous(), "Input tensor must be contiguous"
+    assert x.dim() == 2, "Input tensor must have 2 dimensions"
+    # assert x.size(0) % scaled_block_size_m == 0 and x.size(1) % scaled_block_size_n == 0, \
+    #     f"Dimensions of x must be divisible by scaled block_size (scale_block_size_m={scaled_block_size_m}x{scaled_block_size_n})"
+    M, N = x.size()
+    y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
+    s = x.new_empty(
+        triton.cdiv(M, scaled_block_size_m),
+        triton.cdiv(N, scaled_block_size_n),
+        dtype=torch.float32,
+    )
+    grid = lambda meta: (
+        triton.cdiv(M, meta["SCALED_BLOCK_SIZE_M"]),
+        triton.cdiv(N, meta["SCALED_BLOCK_SIZE_N"]),
+    )
+    fp8_weight_block_wise_quant_kernel[grid](
+        x,
+        y,
+        s,
+        M,
+        N,
+        SCALED_BLOCK_SIZE_M=scaled_block_size_m,
+        SCALED_BLOCK_SIZE_N=scaled_block_size_n,
+        FP8_MAX=FP8_MAX,
+    )
+    return y, s
+
+
+@triton.jit
+def act_quant_kernel(
+    x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr, FP8_MAX: tl.constexpr
+):
     """
     Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`.
 
@@ -23,14 +88,16 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
     pid = tl.program_id(axis=0)
     offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
     x = tl.load(x_ptr + offs).to(tl.float32)
-    s = tl.max(tl.abs(x)) / 448.
+    s = tl.max(tl.abs(x)) / FP8_MAX
     y = x / s
     y = y.to(y_ptr.dtype.element_ty)
     tl.store(y_ptr + offs, y)
     tl.store(s_ptr + pid, s)
 
 
-def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
+def act_quant(
+    x: torch.Tensor, block_size: int = 128
+) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Quantizes the input tensor `x` using block-wise quantization.
 
@@ -43,12 +110,14 @@ def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, tor
             - The quantized tensor with dtype `torch.float8_e4m3fn`.
             - A tensor of scaling factors with dtype `torch.float32`.
""" - assert x.is_contiguous(), 'Input tensor must be contiguous' - assert x.size(-1) % block_size == 0, f'Last dimension size must be divisible by block_size (block_size={block_size})' + assert x.is_contiguous(), "Input tensor must be contiguous" + assert ( + x.size(-1) % block_size == 0 + ), f"Last dimension size must be divisible by block_size (block_size={block_size})" y = torch.empty_like(x, dtype=torch.float8_e4m3fn) s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32) - grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), ) - act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size) + grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),) + act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size, FP8_MAX=FP8_MAX) return y, s @@ -81,7 +150,9 @@ def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr): tl.store(y_ptr + offs, y, mask=mask) -def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor: +def weight_dequant( + x: torch.Tensor, s: torch.Tensor, block_size: int = 128 +) -> torch.Tensor: """ Dequantizes the given weight tensor using the provided scale tensor. @@ -96,28 +167,45 @@ def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> t Raises: AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2. """ - assert x.is_contiguous() and s.is_contiguous(), 'Input tensors must be contiguous' - assert x.dim() == 2 and s.dim() == 2, 'Input tensors must have 2 dimensions' + assert x.is_contiguous() and s.is_contiguous(), "Input tensors must be contiguous" + assert x.dim() == 2 and s.dim() == 2, "Input tensors must have 2 dimensions" M, N = x.size() y = torch.empty_like(x, dtype=torch.get_default_dtype()) - grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE'])) + grid = lambda meta: ( + triton.cdiv(M, meta["BLOCK_SIZE"]), + triton.cdiv(N, meta["BLOCK_SIZE"]), + ) weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size) return y fp8_gemm_configs = [ - Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128}, num_stages=num_stages, num_warps=8) - for block_m in [16, 32, 64] for block_n in [32, 64, 128] for num_stages in [3, 4, 5, 6] + Config( + {"BLOCK_SIZE_M": block_m, "BLOCK_SIZE_N": block_n, "BLOCK_SIZE_K": 128}, + num_stages=num_stages, + num_warps=8, + ) + for block_m in [16, 32, 64] + for block_n in [32, 64, 128] + for num_stages in [3, 4, 5, 6] ] -@triton.autotune(configs=fp8_gemm_configs, key=['N', 'K']) + +@triton.autotune(configs=fp8_gemm_configs, key=["N", "K"]) @triton.jit -def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr, - a_s_ptr, b_s_ptr, - M, N: tl.constexpr, K: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr): +def fp8_gemm_kernel( + a_ptr, + b_ptr, + c_ptr, + a_s_ptr, + b_s_ptr, + M, + N: tl.constexpr, + K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): """ Performs a matrix multiplication operation on FP8 matrices with scaling factors. @@ -180,12 +268,17 @@ def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Ten Returns: torch.Tensor: The result of the matrix multiplication. 
""" - assert a.is_contiguous() and b.is_contiguous(), 'Input tensors must be contiguous' - assert a_s.is_contiguous() and b_s.is_contiguous(), 'Scaling factor tensors must be contiguous' + assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous" + assert ( + a_s.is_contiguous() and b_s.is_contiguous() + ), "Scaling factor tensors must be contiguous" K = a.size(-1) M = a.numel() // K N = b.size(0) c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype()) - grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N'])) + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_SIZE_M"]), + triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K) return c