diff --git a/inference/fp8_cast_bf16.py b/inference/fp8_cast_bf16.py
index 4037342..8ceda81 100644
--- a/inference/fp8_cast_bf16.py
+++ b/inference/fp8_cast_bf16.py
@@ -88,10 +88,14 @@ def main(fp8_path, bf16_path):
         save_file(new_state_dict, new_safetensor_file)
 
         # Memory management: keep only the 2 most recently used files
-        if len(loaded_files) > 2:
-            oldest_file = next(iter(loaded_files))
-            del loaded_files[oldest_file]
-            torch.cuda.empty_cache()
+        if len(loaded_files) > 2:
+            oldest_file = next(iter(loaded_files))
+            del loaded_files[oldest_file]
+            # empty_cache() may raise RuntimeError if the CUDA context is unhealthy;
+            try:
+                torch.cuda.empty_cache()
+            except RuntimeError as e:
+                print(f"Memory error: {e}")
 
     # Update model index
     new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")