Memory management update

This commit is contained in:
Nripesh Niketan 2025-02-02 10:41:20 +00:00
parent b6e3910fd0
commit 73efe7c631

View File

@@ -3,6 +3,7 @@ import json
from argparse import ArgumentParser
from glob import glob
from tqdm import tqdm
import gc
import torch
from safetensors.torch import load_file, save_file
@@ -97,7 +98,12 @@ def main(fp8_path, bf16_path):
        if len(loaded_files) > 2:
            oldest_file = next(iter(loaded_files))
            del loaded_files[oldest_file]
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            elif torch.mps.is_available():
                torch.mps.empty_cache()
            else:
                gc.collect()

    # Update model index
    new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")