mirror of https://github.com/deepseek-ai/DeepSeek-V3.git (synced 2025-07-04 23:41:37 -04:00)

update script and verified correctness

This commit is contained in:
parent 44c403f0d8
commit 31bdaf1112
@@ -36,7 +36,7 @@ def has_tensor(weight_map, loaded_files, fp8_path, tensor_name):

 def find_ignored(regex_pat, weight_name):
     searched = regex_pat.search(weight_name)
     if searched is not None:
-        print(f"find : {searched.string}")
+        # print(f"find : {searched.string}")
         return searched.string
     return None

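The hunk above silences find_ignored's debug print. Note that re.Match.string is the whole input passed to search, not the matched span, which is why find_ignored returns the full weight name. A minimal sketch (not part of the commit) of the difference:

    import re

    pat = re.compile(r".*lm_head")
    m = pat.search("model.lm_head.weight")
    print(m.string)    # "model.lm_head.weight" -- the full input, as find_ignored returns
    print(m.group(0))  # "model.lm_head" -- only the matched span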
@@ -52,7 +52,7 @@ def find_one_ignored(regex_pat_list, weight_name):

 quantize_config = BaseQuantizeConfig(
     quant_method="fp8",
     activation_scheme="dynamic",
-    ignore_patterns=[".*lm_head", ".*gate"],
+    ignore_patterns=[".*lm_head"],
 )

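Dropping ".*gate" from ignore_patterns means gate weights are no longer exempted and now go through FP8 quantization; only lm_head stays un-quantized. A quick sketch of the behavioral change, assuming the same search-based matching as find_one_ignored (the weight name is hypothetical):

    import re

    before = [re.compile(p) for p in (".*lm_head", ".*gate")]
    after = [re.compile(p) for p in (".*lm_head",)]

    name = "model.layers.3.mlp.gate.weight"  # hypothetical weight name
    print(any(p.search(name) for p in before))  # True  -> previously skipped
    print(any(p.search(name) for p in after))   # False -> now quantized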
@@ -86,6 +86,7 @@ def main(bf16_path, fp8_path, ref_weights_scale_inv_map=None):
     # Cache for loaded safetensor files
     loaded_files = {}
     bf16_weight_names = []
+    bf16_weight_scale_inv = {}

     safetensor_files = list(glob(os.path.join(bf16_path, "*.safetensors")))
     safetensor_files.sort()
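The new bf16_weight_scale_inv dict records which shard file each generated *_scale_inv tensor lands in, so the model index can be patched after conversion. A sketch of the bookkeeping it ends up holding (shard and weight names hypothetical):

    # after quantizing a weight found in shard "model-00001.safetensors":
    bf16_weight_scale_inv = {
        "model.layers.0.mlp.up_proj.weight_scale_inv": "model-00001.safetensors",
    }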
@@ -100,25 +101,30 @@ def main(bf16_path, fp8_path, ref_weights_scale_inv_map=None):
                 find_one_ignored(quantize_config.ignore_patterns, weight_name)
                 is not None
             ):
+                # print(f"skipping {weight_name} dtype={weight.dtype}...")
                 new_state_dict[weight_name] = weight
                 continue
-            elif weight.element_size() == 2:  # BF16 weight
+            elif weight.element_size() >= 2:  # BF16 / Float weight

                 if (
                     ref_weights_scale_inv_map is not None
                     and ref_weights_scale_inv_map.get(weight_name, None) is None
                 ):
-                    print(f"skipping {weight_name} ...")
+                    # print(f"skipping {weight_name} dtype={weight.dtype}...")
                     new_state_dict[weight_name] = weight
                     continue
                 pass

                 scale_inv_name = f"{weight_name}_scale_inv"

                 bf16_weight_names.append(weight_name)
+                bf16_weight_scale_inv[scale_inv_name] = file_name

                 fp8_weight, scale_inv = fp8_weight_block_wise_quant(weight)
                 new_state_dict[weight_name] = fp8_weight
                 new_state_dict[scale_inv_name] = scale_inv
             else:
                 # print(f"skipping {weight_name} dtype={weight.dtype} ...")
                 new_state_dict[weight_name] = weight
                 pass

         new_safetensor_file = os.path.join(fp8_path, file_name)
         save_file(new_state_dict, new_safetensor_file)
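Widening the dtype test from element_size() == 2 to >= 2 lets FP32 and FP16 weights take the quantization path alongside BF16, since torch reports element_size() in bytes per element. A runnable check of those sizes (requires torch >= 2.1 for the FP8 dtype):

    import torch

    # element_size() is bytes per element: bf16/fp16 -> 2, fp32 -> 4, fp8 -> 1
    print(torch.empty(1, dtype=torch.bfloat16).element_size())       # 2
    print(torch.empty(1, dtype=torch.float32).element_size())        # 4
    print(torch.empty(1, dtype=torch.float8_e4m3fn).element_size())  # 1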
@@ -131,17 +137,18 @@ def main(bf16_path, fp8_path, ref_weights_scale_inv_map=None):

     # Update model index
     new_model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")

     # TODO (yiakwy) : rewrite with dict.update
     for weight_name in bf16_weight_names:
         scale_inv_name = f"{weight_name}_scale_inv"
         if scale_inv_name in weight_map:
             weight_map.insert(scale_inv_name)
             pass
+        weight_map[scale_inv_name] = bf16_weight_scale_inv[scale_inv_name]

     with open(new_model_index_file, "w") as f:
         json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)
         pass


 # NOTE (yiakwy) : huggingface library will add some parameters different from deepseek V3, we will modify this later, currently
 # we recommend to update config.json manually
 def update_quant_model_config(bf16_cast_fp8_path):
     cfg = AutoConfig.from_pretrained(bf16_cast_fp8_path)
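Two things worth flagging in this hunk: Python dicts have no insert method, so weight_map.insert(scale_inv_name) would raise AttributeError whenever its guard is true, and the TODO's dict.update rewrite is straightforward. A self-contained sketch of that rewrite, with toy data standing in for the diff's variables:

    bf16_weight_names = ["model.layers.0.mlp.up_proj.weight"]  # toy data
    bf16_weight_scale_inv = {
        "model.layers.0.mlp.up_proj.weight_scale_inv": "model-00001.safetensors"
    }
    weight_map = {}

    # one dict.update in place of the per-name loop
    weight_map.update(
        {f"{n}_scale_inv": bf16_weight_scale_inv[f"{n}_scale_inv"]
         for n in bf16_weight_names}
    )
    print(weight_map)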
@@ -156,8 +163,7 @@ def update_quant_model_config(bf16_cast_fp8_path):
     }

     cfg.update(static_q_dict)
-    cfg.to_json_file(os.path.join(bf16_cast_fp8_path, "config.json.bak"))
-    pass
+    cfg.to_json_file(os.path.join(bf16_cast_fp8_path, "config.json"))


 def read_weight_inv_list(fp8_path):
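The write-out now targets config.json directly instead of the config.json.bak staging file. Against the Transformers config API this amounts to roughly the following (the model directory and static_q_dict keys are assumptions inferred from the surrounding diff, not confirmed by it):

    from transformers import AutoConfig

    cfg = AutoConfig.from_pretrained("bf16_cast_fp8")  # hypothetical model dir
    cfg.update({"quantization_config": {"quant_method": "fp8",
                                        "activation_scheme": "dynamic"}})  # assumed keys
    cfg.to_json_file("bf16_cast_fp8/config.json")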
@@ -195,15 +201,12 @@ def read_weight_inv_list(fp8_path):
                     f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion"
                 )
                 new_state_dict[weight_name] = weight
-                pass
-            pass

         # Memory management: keep only the 2 most recently used files
         if len(loaded_files) > 2:
             oldest_file = next(iter(loaded_files))
             del loaded_files[oldest_file]
             torch.cuda.empty_cache()
-            pass

     weights_with_scale_inv = os.path.join(
         fp8_path, "weight_with_scale_inv_map.index.json"
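The cache eviction relies on dicts preserving insertion order (Python 3.7+), so next(iter(loaded_files)) is the first-inserted shard; strictly this is FIFO eviction, which only behaves like an LRU if entries are re-inserted on access. A minimal sketch with hypothetical shard names:

    # FIFO eviction over an insertion-ordered dict (Python 3.7+)
    loaded_files = {}
    for name in ("shard-0", "shard-1", "shard-2"):  # hypothetical shard names
        loaded_files[name] = object()               # stand-in for loaded tensors
        if len(loaded_files) > 2:
            oldest = next(iter(loaded_files))       # first-inserted key
            del loaded_files[oldest]
    print(list(loaded_files))  # ['shard-1', 'shard-2']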
@@ -214,7 +217,6 @@ def read_weight_inv_list(fp8_path):
             f,
             indent=2,
         )
-        pass


 if __name__ == "__main__":
@@ -233,7 +235,6 @@ if __name__ == "__main__":
         read_weight_inv_list(args.input_fp8_hf_path)
     elif args.input_new_fp8_hf_path is not None:
         update_quant_model_config(args.input_new_fp8_hf_path)
-        pass
     else:
         assert (
             args.input_bf16_hf_path is not None and args.output_fp8_hf_path is not None
@@ -244,9 +245,7 @@ if __name__ == "__main__":
         )
         with open(weights_with_scale_inv, "r") as f:
             model_index = json.load(f)
-            pass
         weight_with_scale_inv_map = model_index["weight_with_scale_inv_map"]
-        pass
         main(
             args.input_bf16_hf_path,
             args.output_fp8_hf_path,