diff --git a/inference/fp8_cast_bf16.py b/inference/fp8_cast_bf16.py
index c25cea1..bce5784 100644
--- a/inference/fp8_cast_bf16.py
+++ b/inference/fp8_cast_bf16.py
@@ -3,6 +3,7 @@ import json
 from argparse import ArgumentParser
 from glob import glob
 from tqdm import tqdm
+import gc
 
 import torch
 from safetensors.torch import load_file, save_file
@@ -97,7 +98,14 @@ def main(fp8_path, bf16_path):
         if len(loaded_files) > 2:
             oldest_file = next(iter(loaded_files))
             del loaded_files[oldest_file]
-            torch.cuda.empty_cache()
+            # Release cached device memory after dropping the oldest loaded file.
+            # MPS is probed via torch.backends.mps; older PyTorch lacks torch.mps.is_available().
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+            elif torch.backends.mps.is_available():
+                torch.mps.empty_cache()
+            else:
+                gc.collect()
 
     # Update model index
     new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")