From dca08f2cfd5d451f4d0ce2279c6c2e67f712acb9 Mon Sep 17 00:00:00 2001 From: ajwise9 Date: Tue, 4 Feb 2025 16:36:07 +0000 Subject: [PATCH] fix(fp8_cast): Add robust memory management and error handling --- inference/fp8_cast_bf16.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/inference/fp8_cast_bf16.py b/inference/fp8_cast_bf16.py index 4037342..8ceda81 100644 --- a/inference/fp8_cast_bf16.py +++ b/inference/fp8_cast_bf16.py @@ -88,10 +88,14 @@ def main(fp8_path, bf16_path): save_file(new_state_dict, new_safetensor_file) # Memory management: keep only the 2 most recently used files - if len(loaded_files) > 2: - oldest_file = next(iter(loaded_files)) - del loaded_files[oldest_file] - torch.cuda.empty_cache() + try: + if len(loaded_files) > 2: + oldest_file = next(iter(loaded_files)) + del loaded_files[oldest_file] + torch.cuda.empty_cache() + except RuntimeError as e: + print(f"Memory error: {e}") + # TODO(ajwise9): add a real fallback (e.g. evict all cached files and retry) or exit with a nonzero status # Update model index new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")