Update script and verify correctness

This commit is contained in:
yiakwy-xpu-ml-framework-team 2025-07-01 17:40:04 +08:00
parent 44c403f0d8
commit 31bdaf1112

View File

@ -36,7 +36,7 @@ def has_tensor(weight_map, loaded_files, fp8_path, tensor_name):
# Search `weight_name` with `regex_pat` (presumably a compiled `re.Pattern`;
# only its `.search` method is used here).
# Returns the searched string (`Match.string`, i.e. the `weight_name` that was
# passed in) when the pattern matches anywhere in it, otherwise None.
def find_ignored(regex_pat, weight_name):
searched = regex_pat.search(weight_name)
if searched is not None:
# NOTE(review): the active print and its commented-out copy below look like
# the removed/added pair of a rendered diff hunk — confirm against the real file.
print(f"find : {searched.string}")
# print(f"find : {searched.string}")
return searched.string
return None
@ -52,7 +52,7 @@ def find_one_ignored(regex_pat_list, weight_name):
quantize_config = BaseQuantizeConfig(
quant_method="fp8",
activation_scheme="dynamic",
ignore_patterns=[".*lm_head", ".*gate"],
ignore_patterns=[".*lm_head"],
)
@ -86,6 +86,7 @@ def main(bf16_path, fp8_path, ref_weights_scale_inv_map=None):
# Cache for loaded safetensor files
loaded_files = {}
bf16_weight_names = []
bf16_weight_scale_inv = {}
safetensor_files = list(glob(os.path.join(bf16_path, "*.safetensors")))
safetensor_files.sort()
@ -100,25 +101,30 @@ def main(bf16_path, fp8_path, ref_weights_scale_inv_map=None):
find_one_ignored(quantize_config.ignore_patterns, weight_name)
is not None
):
# print(f"skipping {weight_name} dtype={weight.dtype}...")
new_state_dict[weight_name] = weight
continue
elif weight.element_size() == 2: # BF16 weight
elif weight.element_size() >= 2: # BF16 / Float weight
if (
ref_weights_scale_inv_map is not None
and ref_weights_scale_inv_map.get(weight_name, None) is None
):
print(f"skipping {weight_name} ...")
# print(f"skipping {weight_name} dtype={weight.dtype}...")
new_state_dict[weight_name] = weight
continue
pass
scale_inv_name = f"{weight_name}_scale_inv"
bf16_weight_names.append(weight_name)
bf16_weight_scale_inv[scale_inv_name] = file_name
fp8_weight, scale_inv = fp8_weight_block_wise_quant(weight)
new_state_dict[weight_name] = fp8_weight
new_state_dict[scale_inv_name] = scale_inv
else:
# print(f"skipping {weight_name} dtype={weight.dtype} ...")
new_state_dict[weight_name] = weight
pass
new_safetensor_file = os.path.join(fp8_path, file_name)
save_file(new_state_dict, new_safetensor_file)
@ -131,17 +137,18 @@ def main(bf16_path, fp8_path, ref_weights_scale_inv_map=None):
# Update model index
new_model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
# TODO (yiakwy) : rewrite with dict.update
for weight_name in bf16_weight_names:
scale_inv_name = f"{weight_name}_scale_inv"
if scale_inv_name in weight_map:
weight_map.insert(scale_inv_name)
pass
weight_map[scale_inv_name] = bf16_weight_scale_inv[scale_inv_name]
with open(new_model_index_file, "w") as f:
json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)
pass
# NOTE (yiakwy) : the Hugging Face library adds some parameters that differ from DeepSeek V3's config. We will fix this later;
# for now, we recommend updating config.json manually.
def update_quant_model_config(bf16_cast_fp8_path):
cfg = AutoConfig.from_pretrained(bf16_cast_fp8_path)
@ -156,8 +163,7 @@ def update_quant_model_config(bf16_cast_fp8_path):
}
cfg.update(static_q_dict)
cfg.to_json_file(os.path.join(bf16_cast_fp8_path, "config.json.bak"))
pass
cfg.to_json_file(os.path.join(bf16_cast_fp8_path, "config.json"))
def read_weight_inv_list(fp8_path):
@ -195,15 +201,12 @@ def read_weight_inv_list(fp8_path):
f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion"
)
new_state_dict[weight_name] = weight
pass
pass
# Memory management: keep only the 2 most recently used files
if len(loaded_files) > 2:
oldest_file = next(iter(loaded_files))
del loaded_files[oldest_file]
torch.cuda.empty_cache()
pass
weights_with_scale_inv = os.path.join(
fp8_path, "weight_with_scale_inv_map.index.json"
@ -214,7 +217,6 @@ def read_weight_inv_list(fp8_path):
f,
indent=2,
)
pass
if __name__ == "__main__":
@ -233,7 +235,6 @@ if __name__ == "__main__":
read_weight_inv_list(args.input_fp8_hf_path)
elif args.input_new_fp8_hf_path is not None:
update_quant_model_config(args.input_new_fp8_hf_path)
pass
else:
assert (
args.input_bf16_hf_path is not None and args.output_fp8_hf_path is not None
@ -244,9 +245,7 @@ if __name__ == "__main__":
)
with open(weights_with_scale_inv, "r") as f:
model_index = json.load(f)
pass
weight_with_scale_inv_map = model_index["weight_with_scale_inv_map"]
pass
main(
args.input_bf16_hf_path,
args.output_fp8_hf_path,