From 73efe7c63175f677f111f1df8aa84128b805fb02 Mon Sep 17 00:00:00 2001
From: Nripesh Niketan
Date: Sun, 2 Feb 2025 10:41:20 +0000
Subject: [PATCH] Memory management update

---
 inference/fp8_cast_bf16.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/inference/fp8_cast_bf16.py b/inference/fp8_cast_bf16.py
index c25cea1..bce5784 100644
--- a/inference/fp8_cast_bf16.py
+++ b/inference/fp8_cast_bf16.py
@@ -3,6 +3,7 @@ import json
 from argparse import ArgumentParser
 from glob import glob
 from tqdm import tqdm
+import gc
 
 import torch
 from safetensors.torch import load_file, save_file
@@ -97,7 +98,12 @@ def main(fp8_path, bf16_path):
         if len(loaded_files) > 2:
             oldest_file = next(iter(loaded_files))
             del loaded_files[oldest_file]
-            torch.cuda.empty_cache()
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+            elif torch.backends.mps.is_available():
+                torch.mps.empty_cache()
+            else:
+                gc.collect()
 
     # Update model index
     new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")