Memory management update

Author: Nripesh Niketan
Date: 2025-02-02 10:41:20 +00:00
parent b6e3910fd0
commit 73efe7c631


@@ -3,6 +3,7 @@ import json
 from argparse import ArgumentParser
 from glob import glob
 from tqdm import tqdm
+import gc
 import torch
 from safetensors.torch import load_file, save_file
@@ -97,7 +98,12 @@ def main(fp8_path, bf16_path):
         if len(loaded_files) > 2:
             oldest_file = next(iter(loaded_files))
             del loaded_files[oldest_file]
-            torch.cuda.empty_cache()
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+            elif torch.mps.is_available():
+                torch.mps.empty_cache()
+            else:
+                gc.collect()
 
     # Update model index
     new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
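
For context, the script keeps at most two shard files cached in memory during conversion, and this change frees accelerator memory whenever the oldest cached shard is evicted, instead of unconditionally calling the CUDA cache API. Below is a minimal, self-contained sketch of that pattern outside the diff. The helper names empty_device_cache and evict_oldest_shard are hypothetical, and the MPS check uses torch.backends.mps.is_available(), assuming a PyTorch build where that probe exists; the commit itself calls torch.mps.is_available().

import gc

import torch


def empty_device_cache():
    # Release cached allocator blocks on whichever accelerator is active;
    # fall back to Python garbage collection on CPU-only machines.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif torch.backends.mps.is_available():
        torch.mps.empty_cache()
    else:
        gc.collect()


def evict_oldest_shard(loaded_files):
    # loaded_files: insertion-ordered dict of {filename: state_dict}.
    # Keep at most the two most recently loaded shards resident.
    if len(loaded_files) > 2:
        oldest_file = next(iter(loaded_files))
        del loaded_files[oldest_file]
        empty_device_cache()

Deleting the dict entry only drops the Python-side references; the empty_cache() call is what returns the caching allocator's freed blocks to the device, which keeps peak memory bounded when converting large shards back to back.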