mirror of
https://github.com/deepseek-ai/DeepSeek-V3.git
synced 2025-07-05 07:51:38 -04:00
254 lines
9.2 KiB
Python
254 lines
9.2 KiB
Python
import json
|
|
import os
|
|
import re
|
|
from argparse import ArgumentParser
|
|
from glob import glob
|
|
|
|
import torch
|
|
from auto_fp8 import BaseQuantizeConfig
|
|
from kernel import fp8_weight_block_wise_quant
|
|
from safetensors.torch import load_file, save_file
|
|
from tqdm import tqdm
|
|
from transformers import AutoConfig
|
|
|
|
|
|
# Helper function to get tensor from the correct file
|
|
def has_tensor(weight_map, loaded_files, fp8_path, tensor_name):
    """
    Look up ``tensor_name`` in its safetensor shard, reading that shard from disk at most once.

    Despite its name, this returns the tensor itself rather than a boolean.

    Args:
        weight_map (dict): Maps each tensor name to the file name of the shard holding it.
        loaded_files (dict): Cache of already-loaded shards keyed by file name. A shard
            that has to be read from disk is added to it in place.
        fp8_path (str): Directory that contains the shard files.
        tensor_name (str): The name of the tensor to retrieve.

    Returns:
        torch.Tensor: The requested tensor (shards read here are loaded onto ``"cuda"``).

    Raises:
        KeyError: If ``tensor_name`` is absent from ``weight_map`` or from its shard.
    """
    shard_name = weight_map[tensor_name]
    try:
        shard = loaded_files[shard_name]
    except KeyError:
        # Cache miss: read the whole shard once and remember it for later lookups.
        shard = load_file(os.path.join(fp8_path, shard_name), device="cuda")
        loaded_files[shard_name] = shard
    return shard[tensor_name]
|
|
|
|
|
|
def find_ignored(regex_pat, weight_name):
    """
    Return ``weight_name`` if ``regex_pat`` matches anywhere in it, otherwise ``None``.

    Args:
        regex_pat (str | re.Pattern): Pattern to search for. Both plain pattern strings
            (such as the ``".*lm_head"`` entry of ``quantize_config.ignore_patterns``)
            and pre-compiled patterns are accepted.
        weight_name (str): Fully-qualified tensor name to test.

    Returns:
        str | None: The full ``weight_name`` (not just the matched substring) on a hit.
    """
    # re.search accepts both str and compiled patterns; calling ``.search`` directly
    # on a plain string pattern would raise AttributeError.
    searched = re.search(regex_pat, weight_name)
    if searched is None:
        return None
    return searched.string
|
|
|
|
|
|
def find_one_ignored(regex_pat_list, weight_name):
    """
    Test ``weight_name`` against each pattern in ``regex_pat_list``, in order.

    Patterns are checked lazily through ``find_ignored`` and the scan stops at the
    first hit.

    Args:
        regex_pat_list (Iterable): Patterns understood by ``find_ignored``.
        weight_name (str): Tensor name to test.

    Returns:
        The first non-``None`` result of ``find_ignored`` (the matching weight name),
        or ``None`` if no pattern matches or the list is empty.
    """
    hits = (find_ignored(regex_pat, weight_name) for regex_pat in regex_pat_list)
    return next((hit for hit in hits if hit is not None), None)
|
|
|
|
|
|
# Module-wide FP8 quantization settings. Tensors whose names match an entry of
# ``ignore_patterns`` are copied through unquantized by ``main``, and
# ``activation_scheme`` is copied into config.json by ``update_quant_model_config``.
quantize_config = BaseQuantizeConfig(
    ignore_patterns=[".*lm_head"],
    activation_scheme="dynamic",
    quant_method="fp8",
)
|
|
|
|
|
|
def _should_quantize(weight_name, weight, ref_weights_scale_inv_map):
    """Decide whether ``weight`` gets block-wise FP8 quantization (criteria listed in ``main``)."""
    if find_one_ignored(quantize_config.ignore_patterns, weight_name) is not None:
        return False
    # Block-wise quantization tiles a 2-D matrix (config.json declares 128x128 weight
    # blocks); already-FP8, integer and 1-D tensors (e.g. norm scales) are copied as-is.
    if (
        not weight.is_floating_point()
        or weight.element_size() < 2
        or weight.dim() != 2
    ):
        return False
    # With a reference checkpoint, mirror its layout: quantize only weights that
    # carry a scale_inv there.
    if (
        ref_weights_scale_inv_map is not None
        and ref_weights_scale_inv_map.get(weight_name, None) is None
    ):
        return False
    return True


def main(bf16_path, fp8_path, ref_weights_scale_inv_map=None):
    """
    Quantize BF16 weights to FP8 (OCP E4M3) and save the converted checkpoint.

    Every ``*.safetensors`` shard in ``bf16_path`` is loaded onto the GPU. Its eligible
    tensors are quantized with ``fp8_weight_block_wise_quant``, and the shard is written
    under the same file name into ``fp8_path``. For each quantized tensor ``<name>``, a
    companion ``<name>_scale_inv`` tensor is stored in the same shard. The model index
    is rewritten to reference it.

    A tensor is quantized only if all of the following hold:
      * its name matches none of ``quantize_config.ignore_patterns``;
      * it is a 2-D floating-point tensor with at least 2 bytes per element;
      * ``ref_weights_scale_inv_map`` is ``None`` or has a non-``None`` entry for it.
    Every other tensor is copied through unchanged.

    Args:
        bf16_path (str): Directory containing the BF16 shards and
            ``model.safetensors.index.json``.
        fp8_path (str): Output directory for the FP8 (OCP E4M3) shards and the updated
            index. It is created if missing.
        ref_weights_scale_inv_map (dict | None): Optional mapping keyed by the names of
            weights that have a ``_scale_inv`` in a reference FP8 checkpoint, as
            written by ``read_weight_inv_list``. When given, only those weights are
            quantized.

    Raises:
        FileNotFoundError: If ``bf16_path`` has no ``model.safetensors.index.json``.
        KeyError: If that index file has no ``weight_map`` entry.

    Notes:
        - The function assumes that the BF16 weights are stored in safetensor files.
        - The written index keeps the original ``weight_map`` plus the new scale_inv
          entries; its ``metadata`` is written as an empty dict.
    """
    os.makedirs(fp8_path, exist_ok=True)

    model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
    with open(model_index_file, "r") as f:
        model_index = json.load(f)
    weight_map = model_index["weight_map"]

    # New "<weight>_scale_inv" index entries -> shard file that stores them.
    scale_inv_file_map = {}

    safetensor_files = sorted(glob(os.path.join(bf16_path, "*.safetensors")))
    for safetensor_file in tqdm(safetensor_files):
        file_name = os.path.basename(safetensor_file)
        current_state_dict = load_file(safetensor_file, device="cuda")

        new_state_dict = {}
        for weight_name, weight in current_state_dict.items():
            if not _should_quantize(weight_name, weight, ref_weights_scale_inv_map):
                new_state_dict[weight_name] = weight
                continue

            scale_inv_name = f"{weight_name}_scale_inv"
            fp8_weight, scale_inv = fp8_weight_block_wise_quant(weight)
            new_state_dict[weight_name] = fp8_weight
            new_state_dict[scale_inv_name] = scale_inv
            # The scale is saved in the same shard as its weight.
            scale_inv_file_map[scale_inv_name] = file_name

        save_file(new_state_dict, os.path.join(fp8_path, file_name))

        # Drop this shard's tensors before loading the next one so GPU memory
        # does not accumulate across shards.
        del current_state_dict, new_state_dict
        torch.cuda.empty_cache()

    # Update model index
    weight_map.update(scale_inv_file_map)
    new_model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
    with open(new_model_index_file, "w") as f:
        json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)
|
|
|
|
|
|
# NOTE (yiakwy): the Hugging Face library adds some config parameters that differ from DeepSeek-V3's.
# Until this is addressed, we recommend updating config.json manually.
|
|
def update_quant_model_config(bf16_cast_fp8_path):
    """
    Add an FP8 block-wise ``quantization_config`` to the checkpoint's ``config.json``.

    The existing config is loaded with ``AutoConfig``. A ``quantization_config`` entry
    (E4M3 format, 128x128 weight blocks, activation scheme taken from
    ``quantize_config``) is merged in, and ``config.json`` in the same directory is
    overwritten.

    Args:
        bf16_cast_fp8_path (str): Directory of the converted FP8 checkpoint.
    """
    config = AutoConfig.from_pretrained(bf16_cast_fp8_path)

    fp8_settings = {
        "activation_scheme": quantize_config.activation_scheme,
        "fmt": "e4m3",
        "quant_method": "fp8",
        "weight_block_size": [128, 128],
        # NOTE(review): ``re_ignore_patterns`` is never set in this file; presumably
        # provided by the installed auto_fp8 BaseQuantizeConfig. Confirm.
        "ignored_layers": quantize_config.re_ignore_patterns,
    }
    config.update({"quantization_config": fp8_settings})

    output_file = os.path.join(bf16_cast_fp8_path, "config.json")
    config.to_json_file(output_file)
|
|
|
|
|
|
def read_weight_inv_list(fp8_path):
    """
    Record which FP8 weights of a checkpoint have a ``_scale_inv`` companion tensor.

    Every ``*.safetensors`` shard in ``fp8_path`` is scanned. For each 1-byte (FP8)
    tensor ``<name>`` whose ``<name>_scale_inv`` can be found through the index, the
    function maps ``<name>`` to the shard file holding that scale. The mapping is
    written to ``weight_with_scale_inv_map.index.json`` in ``fp8_path``. It is also
    returned, and can be passed to ``main`` as ``ref_weights_scale_inv_map``.

    Args:
        fp8_path (str): Directory containing the FP8 shards and
            ``model.safetensors.index.json``.

    Returns:
        dict: ``{weight_name: scale_inv_shard_file_name}``.
    """
    model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
    with open(model_index_file, "r") as f:
        model_index = json.load(f)
    weight_map = model_index["weight_map"]
    weights_with_scale_inv_map = {}

    # Shard cache shared with has_tensor. A scale_inv may live in a different shard
    # than its weight, so has_tensor can add extra shards while one is being scanned.
    loaded_files = {}
    max_cached_files = 2

    safetensor_files = sorted(glob(os.path.join(fp8_path, "*.safetensors")))
    for safetensor_file in tqdm(safetensor_files):
        file_name = os.path.basename(safetensor_file)
        # Reuse the shard if has_tensor already loaded it for an earlier weight.
        current_state_dict = loaded_files.get(file_name)
        if current_state_dict is None:
            current_state_dict = load_file(safetensor_file, device="cuda")
            loaded_files[file_name] = current_state_dict

        for weight_name, weight in current_state_dict.items():
            # Only 1-byte (FP8) tensors that are not themselves scales are of interest.
            if weight_name.endswith("_scale_inv") or weight.element_size() != 1:
                continue
            scale_inv_name = f"{weight_name}_scale_inv"
            try:
                # Raises KeyError if the scale is missing from the index or its shard.
                has_tensor(weight_map, loaded_files, fp8_path, scale_inv_name)
            except KeyError:
                print(
                    f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion"
                )
                continue
            weights_with_scale_inv_map[weight_name] = weight_map[scale_inv_name]

        # Memory management: several shards may have been added during this iteration,
        # so evict (oldest first) until at most max_cached_files remain.
        if len(loaded_files) > max_cached_files:
            while len(loaded_files) > max_cached_files:
                del loaded_files[next(iter(loaded_files))]
            torch.cuda.empty_cache()

    weights_with_scale_inv = os.path.join(
        fp8_path, "weight_with_scale_inv_map.index.json"
    )
    with open(weights_with_scale_inv, "w") as f:
        json.dump(
            {"metadata": {}, "weight_with_scale_inv_map": weights_with_scale_inv_map},
            f,
            indent=2,
        )
    return weights_with_scale_inv_map
|
|
|
|
|
|
if __name__ == "__main__":
    # Three modes, selected by which paths are given:
    #   1. --input-fp8-hf-path alone: list the weights of a reference FP8 checkpoint
    #      that carry a _scale_inv (writes weight_with_scale_inv_map.index.json).
    #   2. --input-new-fp8-hf-path: write quantization_config into that
    #      checkpoint's config.json.
    #   3. --input-bf16-hf-path + --output-fp8-hf-path: convert BF16 -> FP8. Adding
    #      --input-fp8-hf-path (after running mode 1 on it) restricts quantization
    #      to the weights listed in its weight_with_scale_inv_map.index.json.
    parser = ArgumentParser()
    parser.add_argument(
        "--input-bf16-hf-path",
        type=str,
        required=False,
        help="Directory of the BF16 Hugging Face checkpoint to convert.",
    )
    parser.add_argument(
        "--output-fp8-hf-path",
        type=str,
        required=False,
        help="Directory where the converted FP8 checkpoint is written.",
    )
    parser.add_argument(
        "--input-fp8-hf-path",
        type=str,
        required=False,
        help="Reference FP8 checkpoint. Alone: write its "
        "weight_with_scale_inv_map.index.json. With the BF16/output paths: "
        "quantize only the weights listed in that file.",
    )
    parser.add_argument(
        "--input-new-fp8-hf-path",
        type=str,
        required=False,
        help="Converted FP8 checkpoint whose config.json receives a quantization_config.",
    )
    args = parser.parse_args()

    if (
        args.input_fp8_hf_path is not None
        and args.input_bf16_hf_path is None
        and args.output_fp8_hf_path is None
    ):
        read_weight_inv_list(args.input_fp8_hf_path)
    elif args.input_new_fp8_hf_path is not None:
        update_quant_model_config(args.input_new_fp8_hf_path)
    else:
        # parser.error exits with a usage message; unlike assert it survives `python -O`.
        if args.input_bf16_hf_path is None or args.output_fp8_hf_path is None:
            parser.error(
                "--input-bf16-hf-path and --output-fp8-hf-path are both required "
                "for conversion"
            )
        # Without a reference checkpoint, every eligible weight is quantized.
        weight_with_scale_inv_map = None
        if args.input_fp8_hf_path is not None:
            weights_with_scale_inv = os.path.join(
                args.input_fp8_hf_path, "weight_with_scale_inv_map.index.json"
            )
            with open(weights_with_scale_inv, "r") as f:
                model_index = json.load(f)
            weight_with_scale_inv_map = model_index["weight_with_scale_inv_map"]
        main(
            args.input_bf16_hf_path,
            args.output_fp8_hf_path,
            ref_weights_scale_inv_map=weight_with_scale_inv_map,
        )
|