From 31bdaf1112272ef55f4b69a06b19b8a289b35515 Mon Sep 17 00:00:00 2001
From: yiakwy-xpu-ml-framework-team <961186938@qq.com>
Date: Tue, 1 Jul 2025 17:40:04 +0800
Subject: [PATCH] update script and verified correctness

---
 inference/bf16_cast_fp8.py | 37 ++++++++++++++++++-------------------
 1 file changed, 18 insertions(+), 19 deletions(-)

diff --git a/inference/bf16_cast_fp8.py b/inference/bf16_cast_fp8.py
index 5a842c5..07bf706 100644
--- a/inference/bf16_cast_fp8.py
+++ b/inference/bf16_cast_fp8.py
@@ -36,7 +36,7 @@ def has_tensor(weight_map, loaded_files, fp8_path, tensor_name):
 def find_ignored(regex_pat, weight_name):
     searched = regex_pat.search(weight_name)
     if searched is not None:
-        print(f"find : {searched.string}")
+        # print(f"find : {searched.string}")
         return searched.string
     return None
 
@@ -52,7 +52,7 @@ def find_one_ignored(regex_pat_list, weight_name):
 quantize_config = BaseQuantizeConfig(
     quant_method="fp8",
     activation_scheme="dynamic",
-    ignore_patterns=[".*lm_head", ".*gate"],
+    ignore_patterns=[".*lm_head"],
 )
 
 
@@ -86,6 +86,7 @@ def main(bf16_path, fp8_path, ref_weights_scale_inv_map=None):
     # Cache for loaded safetensor files
     loaded_files = {}
     bf16_weight_names = []
+    bf16_weight_scale_inv = {}
 
     safetensor_files = list(glob(os.path.join(bf16_path, "*.safetensors")))
     safetensor_files.sort()
@@ -100,25 +101,30 @@ def main(bf16_path, fp8_path, ref_weights_scale_inv_map=None):
                 find_one_ignored(quantize_config.ignore_patterns, weight_name)
                 is not None
             ):
+                # print(f"skipping {weight_name} dtype={weight.dtype}...")
+                new_state_dict[weight_name] = weight
                 continue
-            elif weight.element_size() == 2:  # BF16 weight
+            elif weight.element_size() >= 2:  # BF16 / Float weight
                 if (
                     ref_weights_scale_inv_map is not None
                     and ref_weights_scale_inv_map.get(weight_name, None) is None
                 ):
-                    print(f"skipping {weight_name} ...")
+                    # print(f"skipping {weight_name} dtype={weight.dtype}...")
+                    new_state_dict[weight_name] = weight
                     continue
-                pass
 
 
                 scale_inv_name = f"{weight_name}_scale_inv"
+
+                bf16_weight_names.append(weight_name)
+                bf16_weight_scale_inv[scale_inv_name] = file_name
                 fp8_weight, scale_inv = fp8_weight_block_wise_quant(weight)
                 new_state_dict[weight_name] = fp8_weight
                 new_state_dict[scale_inv_name] = scale_inv
             else:
+                # print(f"skipping {weight_name} dtype={weight.dtype} ...")
                 new_state_dict[weight_name] = weight
-                pass
 
 
         new_safetensor_file = os.path.join(fp8_path, file_name)
         save_file(new_state_dict, new_safetensor_file)
@@ -131,17 +137,18 @@ def main(bf16_path, fp8_path, ref_weights_scale_inv_map=None):
 
     # Update model index
     new_model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
+
+    # TODO (yiakwy) : rewrite with dict.update
     for weight_name in bf16_weight_names:
         scale_inv_name = f"{weight_name}_scale_inv"
-        if scale_inv_name in weight_map:
-            weight_map.insert(scale_inv_name)
-        pass
+        weight_map[scale_inv_name] = bf16_weight_scale_inv[scale_inv_name]
 
     with open(new_model_index_file, "w") as f:
         json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)
-    pass
 
 
+# NOTE (yiakwy) : huggingface library will add some parameters different from deepseek V3, we will modify this later, currently
+# we recommend to update config.json manually
 def update_quant_model_config(bf16_cast_fp8_path):
     cfg = AutoConfig.from_pretrained(bf16_cast_fp8_path)
 
@@ -156,8 +163,7 @@ def update_quant_model_config(bf16_cast_fp8_path):
     }
 
     cfg.update(static_q_dict)
-    cfg.to_json_file(os.path.join(bf16_cast_fp8_path, "config.json.bak"))
-    pass
+    cfg.to_json_file(os.path.join(bf16_cast_fp8_path, "config.json"))
 
 
 def read_weight_inv_list(fp8_path):
@@ -195,15 +201,12 @@ def read_weight_inv_list(fp8_path):
                     f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion"
                 )
                 new_state_dict[weight_name] = weight
-                pass
-            pass
 
         # Memory management: keep only the 2 most recently used files
         if len(loaded_files) > 2:
             oldest_file = next(iter(loaded_files))
             del loaded_files[oldest_file]
             torch.cuda.empty_cache()
-            pass
 
     weights_with_scale_inv = os.path.join(
         fp8_path, "weight_with_scale_inv_map.index.json"
@@ -214,7 +217,6 @@ def read_weight_inv_list(fp8_path):
             f,
             indent=2,
         )
-    pass
 
 
 if __name__ == "__main__":
@@ -233,7 +235,6 @@ if __name__ == "__main__":
         read_weight_inv_list(args.input_fp8_hf_path)
     elif args.input_new_fp8_hf_path is not None:
         update_quant_model_config(args.input_new_fp8_hf_path)
-        pass
     else:
         assert (
             args.input_bf16_hf_path is not None and args.output_fp8_hf_path is not None
@@ -244,9 +245,7 @@ if __name__ == "__main__":
         )
         with open(weights_with_scale_inv, "r") as f:
             model_index = json.load(f)
-            pass
         weight_with_scale_inv_map = model_index["weight_with_scale_inv_map"]
-        pass
         main(
             args.input_bf16_hf_path,
             args.output_fp8_hf_path,