diff --git a/inference/fp8_cast_bf16.py b/inference/fp8_cast_bf16.py index cb7c11e..d6130ac 100644 --- a/inference/fp8_cast_bf16.py +++ b/inference/fp8_cast_bf16.py @@ -30,6 +30,7 @@ def main(fp8_path, bf16_path): return loaded_files[file_name][tensor_name] safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors"))) + safetensor_files.sort() for safetensor_file in tqdm(safetensor_files): file_name = os.path.basename(safetensor_file) current_state_dict = load_file(safetensor_file, device="cuda")