diff --git a/inference/fp8_cast_bf16.py b/inference/fp8_cast_bf16.py index d6130ac..1b9735a 100644 --- a/inference/fp8_cast_bf16.py +++ b/inference/fp8_cast_bf16.py @@ -60,6 +60,7 @@ def main(fp8_path, bf16_path): if len(loaded_files) > 2: oldest_file = next(iter(loaded_files)) del loaded_files[oldest_file] + torch.cuda.empty_cache() # Update model index new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")