diff --git a/inference/fp8_cast_bf16.py b/inference/fp8_cast_bf16.py index 4037342..ffb5f23 100644 --- a/inference/fp8_cast_bf16.py +++ b/inference/fp8_cast_bf16.py @@ -85,7 +85,7 @@ def main(fp8_path, bf16_path): new_state_dict[weight_name] = weight new_safetensor_file = os.path.join(bf16_path, file_name) - save_file(new_state_dict, new_safetensor_file) + save_file(new_state_dict, new_safetensor_file, metadata={"format": "pt"}) # Memory management: keep only the 2 most recently used files if len(loaded_files) > 2: