Mirror of https://github.com/deepseek-ai/DeepSeek-V3.git (synced 2025-07-05 07:51:38 -04:00)

Commit 94590b9924: Merge 7813704f78 into f6e34dd267

73  inference/README.md  (new file)

@@ -0,0 +1,73 @@
# DeepSeek-V3 Weight File Documentation

## BF16 SFT to DeepSeek block-scale weight quantization for inference
The `vLLM` community project `vllm-project/llm-compressor` is working on integrating block-scale quantization support ([PR#1475](https://github.com/vllm-project/llm-compressor/issues/1475)). vLLM and SGLang already support the DeepSeek FP8 dynamic (activation) quantization format.

The DeepSeek weight format (see [README_WEIGHTS.md](../README_WEIGHTS.md)) requires weights to be quantized statically with a block-wise max-reduction kernel, so adding BF16-to-FP8 quantization support on top of Hugging Face safetensors is handy.
To produce an FP8 (OCP E4M3 format) quantized model, quite unlike Mixtral, we need to ignore `lm_head` while casting all 61 MoE layers' up, down, and gate projection weights to FP8:

```
quantize_config = BaseQuantizeConfig(
    quant_method="fp8",
    activation_scheme="dynamic",
    ignore_patterns=[".*lm_head"],
)
```
SGLang and vLLM already take care of dynamic activation quantization, which for activations is essentially equivalent to 128-group (or 4 x E8M0 32-group) quantization. Block-scale quantization of the weights, however, also requires tiling the inputs and outputs along the non-K dimension.

For an MxN weight we produce `ceil(M / BLOCK_SIZE_M) x ceil(N / BLOCK_SIZE_N)` blocks, computing one inverse scale per block that can later be persisted and reused inside the kernel; this keeps the number of stored scale values small, as sketched below.
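For intuition, here is a minimal pure-PyTorch sketch of the same block-wise scheme. It is only an illustration; the script itself calls the Triton kernel `fp8_weight_block_wise_quant` in [kernel.py](./kernel.py). The padding and the zero-guard clamp are conveniences of this sketch, not part of the kernel:

```
import torch

FP8_MAX = 448.0  # OCP E4M3 maximum representable magnitude

def block_wise_quant_ref(w: torch.Tensor, block: int = 128):
    """Reference block-wise quantization: one float32 scale per 128x128 tile."""
    M, N = w.shape
    # pad so both dims are multiples of the block size (sketch-only convenience)
    w_pad = torch.nn.functional.pad(w.float(), (0, -N % block, 0, -M % block))
    Mp, Np = w_pad.shape
    tiles = w_pad.view(Mp // block, block, Np // block, block)
    scale = tiles.abs().amax(dim=(1, 3), keepdim=True) / FP8_MAX  # per-tile max / FP8_MAX
    scale = scale.clamp(min=1e-12)  # guard all-zero tiles (not in the Triton kernel)
    q = (tiles / scale).view(Mp, Np)[:M, :N].to(torch.float8_e4m3fn)
    scale_inv = scale[:, 0, :, 0]  # shape: ceil(M/128) x ceil(N/128)
    return q, scale_inv
```

Dequantization is the inverse: multiply each 128x128 tile of the FP8 weight by its `scale_inv` entry, which is what `weight_dequant` in [kernel.py](./kernel.py) does.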
#### Step by step

###### 1. Perform quantization

```
python bf16_cast_fp8.py \
    --input-bf16-hf-path "${DIST_FS_ROOT}/DeepSeek-sft-bf16/" \
    --output-fp8-hf-path "${DIST_FS_ROOT}/DeepSeek-sft-bf16-FP8E4M3_block128x128-fp8-gate" \
    --input-fp8-hf-path "${DIST_FS_ROOT}/DeepSeek-V3-0324"
```
`--input-fp8-hf-path` points at the original DeepSeek-V3 FP8 checkpoint and is used to fetch its weight-scale (`_scale_inv`) map, `weight_with_scale_inv_map.index.json`. That file can be generated by:

```
python bf16_cast_fp8.py \
    --input-fp8-hf-path "${DIST_FS_ROOT}/DeepSeek-V3-0324"
```
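The resulting `weight_with_scale_inv_map.index.json` simply maps each quantized weight name to the safetensors shard that holds its `_scale_inv` tensor, roughly like the following (tensor and shard names here are illustrative):

```
{
  "metadata": {},
  "weight_with_scale_inv_map": {
    "model.layers.3.mlp.experts.0.gate_proj.weight": "model-00012-of-000163.safetensors",
    "model.layers.3.mlp.experts.0.up_proj.weight": "model-00012-of-000163.safetensors"
  }
}
```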
The conversion then writes the FP8 safetensors into `${DIST_FS_ROOT}/DeepSeek-sft-bf16-FP8E4M3_block128x128-fp8-gate`, together with an updated `${DIST_FS_ROOT}/DeepSeek-sft-bf16-FP8E4M3_block128x128-fp8-gate/model.safetensors.index.json`.
###### 2. Copy the following configs from your bf16 checkpoint to the folder

```
DEST=${DIST_FS_ROOT}/DeepSeek-sft-bf16-FP8E4M3_block128x128-fp8-gate

cp $BF16_CHECK_POINT/config.json ${DEST}/
cp $BF16_CHECK_POINT/configuration_deepseek.py ${DEST}/
cp $BF16_CHECK_POINT/modeling_deepseek.py ${DEST}/
cp $BF16_CHECK_POINT/tokenizer.json ${DEST}/
cp $BF16_CHECK_POINT/tokenizer_config.json ${DEST}/
```
Make sure the following dict has been added to `config.json`:

```
"quantization_config": {
    "activation_scheme": "dynamic",
    "fmt": "e4m3",
    "quant_method": "fp8",
    "weight_block_size": [128, 128],
    "ignored_layers": [".*lm_head"]
}
```
We will add a simple class to automate this step later.
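In the meantime, the script can also patch `config.json` itself via `update_quant_model_config` (note the in-source caveat that Hugging Face `AutoConfig` may add fields that differ from the original DeepSeek-V3 config, so manual editing remains the recommended path):

```
python bf16_cast_fp8.py \
    --input-new-fp8-hf-path "${DIST_FS_ROOT}/DeepSeek-sft-bf16-FP8E4M3_block128x128-fp8-gate"
```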
## BF16 upgrade training/inference

This was originally created for inference on non-FP8-capable chips. See [fp8_cast_bf16.py](./fp8_cast_bf16.py) for details.
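Assuming the flags defined in that script, a typical invocation looks like:

```
python fp8_cast_bf16.py \
    --input-fp8-hf-path /path/to/DeepSeek-V3 \
    --output-bf16-hf-path /path/to/DeepSeek-V3-bf16
```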
5  inference/auto_fp8/__init__.py  (new file)

@@ -0,0 +1,5 @@
from .config import BaseQuantizeConfig

__all__ = [
    "BaseQuantizeConfig",
]
47  inference/auto_fp8/config.py  (new file)

@@ -0,0 +1,47 @@
# Adapted from AutoFP8 (deprecated project)
import re
from typing import List, Optional, Tuple


class BaseQuantizeConfig:
    """Configuration for model quantization.

    Args:
        quant_method: Type/precision of quantization method to use.
            At the moment, this is just "fp8" which specifically means
            the fp8_e4m3 format in pytorch.
        activation_scheme: Choice of either "dynamic" or "static" quantization
            of activations. If "static", then calibration samples are required
            during quantization to produce accurate per-tensor scales for
            activations of Linear modules.
        ignore_patterns: List of patterns used to ignore layers. If a string
            starts with "re:", then everything afterwards is used as python
            regex style matching i.e. re.search(), for each Linear layer.
            By default, "re:.*lm_head" is included to ignore the embedding
            Linear layer usually at the end of decoder LLMs.
        kv_cache_quant_targets: Tuple of Linear module names to target for
            calibration of the output scales for KV cache quantization.
            Usually, these should be `("k_proj", "v_proj")`.
    """

    def __init__(
        self,
        quant_method: str = "fp8",
        activation_scheme: str = "static",
        ignore_patterns: List[str] = ["re:.*lm_head"],
        kv_cache_quant_targets: Optional[Tuple[str]] = None,
    ):
        if quant_method != "fp8":
            raise ValueError("Only FP8 quantization is supported.")
        if activation_scheme not in ["static", "dynamic"]:
            raise ValueError(
                "Invalid activation_scheme. Choose either 'static' or 'dynamic'."
            )
        self.quant_method = quant_method
        self.activation_scheme = activation_scheme
        self.re_ignore_patterns = ignore_patterns
        self.ignore_patterns = [
            re.compile(regex_pat, re.VERBOSE) for regex_pat in ignore_patterns
        ]
        self.kv_cache_quant_targets = kv_cache_quant_targets
        self.ignored_layers = []
253  inference/bf16_cast_fp8.py  (new file)

@@ -0,0 +1,253 @@
import json
import os
import re
from argparse import ArgumentParser
from glob import glob

import torch
from auto_fp8 import BaseQuantizeConfig
from kernel import fp8_weight_block_wise_quant
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from transformers import AutoConfig


# Helper function to get a tensor from the correct file
def has_tensor(weight_map, loaded_files, fp8_path, tensor_name):
    """
    Retrieves a tensor from the cached safetensor files or loads it from disk if not cached.

    Args:
        tensor_name (str): The name of the tensor to retrieve.

    Returns:
        torch.Tensor: The retrieved tensor.

    Raises:
        KeyError: If the tensor does not exist in the safetensor file.
    """
    file_name = weight_map[tensor_name]
    if file_name not in loaded_files:
        file_path = os.path.join(fp8_path, file_name)
        loaded_files[file_name] = load_file(file_path, device="cuda")
    return loaded_files[file_name][tensor_name]


def find_ignored(regex_pat, weight_name):
    searched = regex_pat.search(weight_name)
    if searched is not None:
        # print(f"find : {searched.string}")
        return searched.string
    return None


def find_one_ignored(regex_pat_list, weight_name):
    for regex_pat in regex_pat_list:
        searched = find_ignored(regex_pat, weight_name)
        if searched is not None:
            return searched
    return None


quantize_config = BaseQuantizeConfig(
    quant_method="fp8",
    activation_scheme="dynamic",
    ignore_patterns=[".*lm_head"],
)


def main(bf16_path, fp8_path, ref_weights_scale_inv_map=None):
    """
    Quantizes BF16 weights to FP8 (OCP E4M3) and saves the converted weights.

    This function reads BF16 weights from the specified directory, converts them to FP8 (OCP E4M3),
    and saves the converted weights to another specified directory. It also updates the
    model index file to reflect the changes.

    Args:
        bf16_path (str): The path to the directory containing the BF16 weights and model index file.
        fp8_path (str): The path to the directory where the converted FP8 (OCP E4M3) weights will be saved.

    Raises:
        KeyError: If a required scale_inv tensor is missing for a weight.

    Notes:
        - The function assumes that the BF16 weights are stored in safetensor files.
        - The function updates the model index file to add references to the scale_inv tensors.
    """
    # torch.set_default_dtype(torch.bfloat16)
    os.makedirs(fp8_path, exist_ok=True)

    model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
    with open(model_index_file, "r") as f:
        model_index = json.load(f)
    weight_map = model_index["weight_map"]

    # Cache for loaded safetensor files
    loaded_files = {}
    bf16_weight_names = []
    bf16_weight_scale_inv = {}

    safetensor_files = list(glob(os.path.join(bf16_path, "*.safetensors")))
    safetensor_files.sort()
    for safetensor_file in tqdm(safetensor_files):
        file_name = os.path.basename(safetensor_file)
        current_state_dict = load_file(safetensor_file, device="cuda")
        loaded_files[file_name] = current_state_dict

        new_state_dict = {}
        for weight_name, weight in current_state_dict.items():
            if (
                find_one_ignored(quantize_config.ignore_patterns, weight_name)
                is not None
            ):
                # print(f"skipping {weight_name} dtype={weight.dtype}...")
                new_state_dict[weight_name] = weight
                continue
            elif weight.element_size() >= 2:  # BF16 / Float weight
                if (
                    ref_weights_scale_inv_map is not None
                    and ref_weights_scale_inv_map.get(weight_name, None) is None
                ):
                    # print(f"skipping {weight_name} dtype={weight.dtype}...")
                    new_state_dict[weight_name] = weight
                    continue

                scale_inv_name = f"{weight_name}_scale_inv"

                bf16_weight_names.append(weight_name)
                bf16_weight_scale_inv[scale_inv_name] = file_name

                fp8_weight, scale_inv = fp8_weight_block_wise_quant(weight)
                new_state_dict[weight_name] = fp8_weight
                new_state_dict[scale_inv_name] = scale_inv
            else:
                # print(f"skipping {weight_name} dtype={weight.dtype} ...")
                new_state_dict[weight_name] = weight

        new_safetensor_file = os.path.join(fp8_path, file_name)
        save_file(new_state_dict, new_safetensor_file)

        # Memory management: keep only the 2 most recently used files
        if len(loaded_files) > 2:
            oldest_file = next(iter(loaded_files))
            del loaded_files[oldest_file]
            torch.cuda.empty_cache()

    # Update model index
    new_model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")

    # TODO (yiakwy) : rewrite with dict.update
    for weight_name in bf16_weight_names:
        scale_inv_name = f"{weight_name}_scale_inv"
        weight_map[scale_inv_name] = bf16_weight_scale_inv[scale_inv_name]

    with open(new_model_index_file, "w") as f:
        json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)


# NOTE (yiakwy) : the huggingface library adds some parameters that differ from DeepSeek V3; we will modify this later.
# Currently we recommend updating config.json manually.
def update_quant_model_config(bf16_cast_fp8_path):
    cfg = AutoConfig.from_pretrained(bf16_cast_fp8_path)

    static_q_dict = {
        "quantization_config": {
            "activation_scheme": quantize_config.activation_scheme,
            "fmt": "e4m3",
            "quant_method": "fp8",
            "weight_block_size": [128, 128],
            "ignored_layers": quantize_config.re_ignore_patterns,
        }
    }

    cfg.update(static_q_dict)
    cfg.to_json_file(os.path.join(bf16_cast_fp8_path, "config.json"))


def read_weight_inv_list(fp8_path):
    model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
    with open(model_index_file, "r") as f:
        model_index = json.load(f)
    weight_map = model_index["weight_map"]
    weights_with_scale_inv_map = {}

    loaded_files = {}
    fp8_weight_names = []

    safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
    safetensor_files.sort()
    for safetensor_file in tqdm(safetensor_files):
        file_name = os.path.basename(safetensor_file)
        current_state_dict = load_file(safetensor_file, device="cuda")
        loaded_files[file_name] = current_state_dict

        new_state_dict = {}
        for weight_name, weight in current_state_dict.items():
            if weight_name.endswith("_scale_inv"):
                continue
            elif weight.element_size() == 1:  # FP8 weight
                scale_inv_name = f"{weight_name}_scale_inv"
                try:
                    # Get scale_inv from the correct file
                    scale_inv = has_tensor(
                        weight_map, loaded_files, fp8_path, scale_inv_name
                    )
                    fp8_weight_names.append(weight_name)
                    weights_with_scale_inv_map[weight_name] = weight_map[scale_inv_name]
                except KeyError:
                    print(
                        f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion"
                    )
                    new_state_dict[weight_name] = weight

        # Memory management: keep only the 2 most recently used files
        if len(loaded_files) > 2:
            oldest_file = next(iter(loaded_files))
            del loaded_files[oldest_file]
            torch.cuda.empty_cache()

    weights_with_scale_inv = os.path.join(
        fp8_path, "weight_with_scale_inv_map.index.json"
    )
    with open(weights_with_scale_inv, "w") as f:
        json.dump(
            {"metadata": {}, "weight_with_scale_inv_map": weights_with_scale_inv_map},
            f,
            indent=2,
        )


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--input-bf16-hf-path", type=str, required=False)
    parser.add_argument("--output-fp8-hf-path", type=str, required=False)
    parser.add_argument("--input-fp8-hf-path", type=str, required=False)
    parser.add_argument("--input-new-fp8-hf-path", type=str, required=False)
    args = parser.parse_args()

    if (
        args.input_fp8_hf_path is not None
        and args.input_bf16_hf_path is None
        and args.output_fp8_hf_path is None
    ):
        read_weight_inv_list(args.input_fp8_hf_path)
    elif args.input_new_fp8_hf_path is not None:
        update_quant_model_config(args.input_new_fp8_hf_path)
    else:
        assert (
            args.input_bf16_hf_path is not None and args.output_fp8_hf_path is not None
        )
        weight_with_scale_inv_map = None  # optional reference map from the original FP8 checkpoint
        if args.input_fp8_hf_path is not None:
            weights_with_scale_inv = os.path.join(
                args.input_fp8_hf_path, "weight_with_scale_inv_map.index.json"
            )
            with open(weights_with_scale_inv, "r") as f:
                model_index = json.load(f)
            weight_with_scale_inv_map = model_index["weight_with_scale_inv_map"]
        main(
            args.input_bf16_hf_path,
            args.output_fp8_hf_path,
            ref_weights_scale_inv_map=weight_with_scale_inv_map,
        )
inference/kernel.py  (modified)

@@ -5,9 +5,74 @@ import triton
 import triton.language as tl
 from triton import Config
 
 
+OCP_FP8E4M3_MAX = 448.0
+
+FP8_MAX = OCP_FP8E4M3_MAX
+
+
 @triton.jit
-def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
+def fp8_weight_block_wise_quant_kernel(
+    x_ptr,
+    y_ptr,
+    s_ptr,
+    M,
+    N,
+    SCALED_BLOCK_SIZE_M: tl.constexpr,
+    SCALED_BLOCK_SIZE_N: tl.constexpr,
+    FP8_MAX: tl.constexpr,
+):
+    pid_m = tl.program_id(axis=0)
+    pid_n = tl.program_id(axis=1)
+    n = tl.cdiv(N, SCALED_BLOCK_SIZE_N)  # number of column blocks (row stride of the scale tensor)
+    offs_m = pid_m * SCALED_BLOCK_SIZE_M + tl.arange(0, SCALED_BLOCK_SIZE_M)
+    offs_n = pid_n * SCALED_BLOCK_SIZE_N + tl.arange(0, SCALED_BLOCK_SIZE_N)
+    offs = offs_m[:, None] * N + offs_n[None, :]
+    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
+    x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
+    s = tl.max(tl.abs(x)) / FP8_MAX
+    y = x / s
+    y = y.to(y_ptr.dtype.element_ty)
+    tl.store(y_ptr + offs, y, mask=mask)
+    tl.store(s_ptr + pid_m * n + pid_n, s)
+
+
+def fp8_weight_block_wise_quant(
+    x: torch.Tensor, scaled_block_size_m: int = 128, scaled_block_size_n: int = 128
+):
+    assert x.is_contiguous(), "Input tensor must be contiguous"
+    assert x.dim() == 2, "Input tensor must have 2 dimensions"
+    # assert x.size(0) % scaled_block_size_m == 0 and x.size(1) % scaled_block_size_n == 0, \
+    #     f"Dimensions of x must be divisible by scaled block_size (scale_block_size_m={scaled_block_size_m}x{scaled_block_size_n})"
+    M, N = x.size()
+    y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
+    s = x.new_empty(
+        triton.cdiv(M, scaled_block_size_m),
+        triton.cdiv(N, scaled_block_size_n),
+        dtype=torch.float32,
+    )
+    grid = lambda meta: (
+        triton.cdiv(M, meta["SCALED_BLOCK_SIZE_M"]),
+        triton.cdiv(N, meta["SCALED_BLOCK_SIZE_N"]),
+    )
+    fp8_weight_block_wise_quant_kernel[grid](
+        x,
+        y,
+        s,
+        M,
+        N,
+        SCALED_BLOCK_SIZE_M=scaled_block_size_m,
+        SCALED_BLOCK_SIZE_N=scaled_block_size_n,
+        FP8_MAX=FP8_MAX,
+    )
+    return y, s
+
+
+@triton.jit
+def act_quant_kernel(
+    x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr, FP8_MAX: tl.constexpr
+):
     """
     Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`.

@@ -23,14 +88,16 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
     pid = tl.program_id(axis=0)
     offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
     x = tl.load(x_ptr + offs).to(tl.float32)
-    s = tl.max(tl.abs(x)) / 448.
+    s = tl.max(tl.abs(x)) / FP8_MAX
     y = x / s
     y = y.to(y_ptr.dtype.element_ty)
     tl.store(y_ptr + offs, y)
     tl.store(s_ptr + pid, s)
 
 
-def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
+def act_quant(
+    x: torch.Tensor, block_size: int = 128
+) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Quantizes the input tensor `x` using block-wise quantization.

@@ -43,12 +110,14 @@ def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, tor
             - The quantized tensor with dtype `torch.float8_e4m3fn`.
             - A tensor of scaling factors with dtype `torch.float32`.
     """
-    assert x.is_contiguous(), 'Input tensor must be contiguous'
-    assert x.size(-1) % block_size == 0, f'Last dimension size must be divisible by block_size (block_size={block_size})'
+    assert x.is_contiguous(), "Input tensor must be contiguous"
+    assert (
+        x.size(-1) % block_size == 0
+    ), f"Last dimension size must be divisible by block_size (block_size={block_size})"
     y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
     s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)
-    grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), )
-    act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size)
+    grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),)
+    act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size, FP8_MAX=FP8_MAX)
     return y, s

@@ -81,7 +150,9 @@ def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
     tl.store(y_ptr + offs, y, mask=mask)
 
 
-def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
+def weight_dequant(
+    x: torch.Tensor, s: torch.Tensor, block_size: int = 128
+) -> torch.Tensor:
     """
     Dequantizes the given weight tensor using the provided scale tensor.

@@ -96,28 +167,45 @@ def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> t
     Raises:
         AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2.
     """
-    assert x.is_contiguous() and s.is_contiguous(), 'Input tensors must be contiguous'
-    assert x.dim() == 2 and s.dim() == 2, 'Input tensors must have 2 dimensions'
+    assert x.is_contiguous() and s.is_contiguous(), "Input tensors must be contiguous"
+    assert x.dim() == 2 and s.dim() == 2, "Input tensors must have 2 dimensions"
     M, N = x.size()
     y = torch.empty_like(x, dtype=torch.get_default_dtype())
-    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
+    grid = lambda meta: (
+        triton.cdiv(M, meta["BLOCK_SIZE"]),
+        triton.cdiv(N, meta["BLOCK_SIZE"]),
+    )
     weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
     return y
 
 
 fp8_gemm_configs = [
-    Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128}, num_stages=num_stages, num_warps=8)
-    for block_m in [16, 32, 64] for block_n in [32, 64, 128] for num_stages in [3, 4, 5, 6]
+    Config(
+        {"BLOCK_SIZE_M": block_m, "BLOCK_SIZE_N": block_n, "BLOCK_SIZE_K": 128},
+        num_stages=num_stages,
+        num_warps=8,
+    )
+    for block_m in [16, 32, 64]
+    for block_n in [32, 64, 128]
+    for num_stages in [3, 4, 5, 6]
 ]
 
-@triton.autotune(configs=fp8_gemm_configs, key=['N', 'K'])
+
+@triton.autotune(configs=fp8_gemm_configs, key=["N", "K"])
 @triton.jit
-def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
-                    a_s_ptr, b_s_ptr,
-                    M, N: tl.constexpr, K: tl.constexpr,
-                    BLOCK_SIZE_M: tl.constexpr,
-                    BLOCK_SIZE_N: tl.constexpr,
-                    BLOCK_SIZE_K: tl.constexpr):
+def fp8_gemm_kernel(
+    a_ptr,
+    b_ptr,
+    c_ptr,
+    a_s_ptr,
+    b_s_ptr,
+    M,
+    N: tl.constexpr,
+    K: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+):
     """
     Performs a matrix multiplication operation on FP8 matrices with scaling factors.

@@ -180,12 +268,17 @@ def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Ten
     Returns:
         torch.Tensor: The result of the matrix multiplication.
     """
-    assert a.is_contiguous() and b.is_contiguous(), 'Input tensors must be contiguous'
-    assert a_s.is_contiguous() and b_s.is_contiguous(), 'Scaling factor tensors must be contiguous'
+    assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous"
+    assert (
+        a_s.is_contiguous() and b_s.is_contiguous()
+    ), "Scaling factor tensors must be contiguous"
     K = a.size(-1)
     M = a.numel() // K
     N = b.size(0)
     c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
-    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N']))
+    grid = lambda META: (
+        triton.cdiv(M, META["BLOCK_SIZE_M"]),
+        triton.cdiv(N, META["BLOCK_SIZE_N"]),
+    )
     fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)
     return c