From 35244be39f41b5607c887d62acfd8fc1666d1597 Mon Sep 17 00:00:00 2001 From: CodingParadigm1 <124928232+CodingParadigm1@users.noreply.github.com> Date: Sun, 2 Feb 2025 06:12:46 -0700 Subject: [PATCH] moved dir calls removed dir calls from for loops --- inference/convert.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/inference/convert.py b/inference/convert.py index 212e2e4..0310cb2 100644 --- a/inference/convert.py +++ b/inference/convert.py @@ -73,7 +73,9 @@ def main(hf_ckpt_path, save_path, n_experts, mp): n_local_experts = n_experts // mp state_dicts = [{} for _ in range(mp)] tasks = [] - for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))): + tensor_dir = glob(os.path.join(hf_ckpt_path, "*.safetensors")) + token_dir = glob(os.path.join(hf_ckpt_path, "*token*")) + for file_path in tqdm(tensor_dir): cm = await sync.to_thread(safe_open, file_path, framework="pt", device="cpu") async with cm as f: await sync.gather(*(inner_safe_open(name, f, state_dicts, mp, n_local_experts) for name in f.keys())) @@ -85,7 +87,7 @@ def main(hf_ckpt_path, save_path, n_experts, mp): async def set_file_path(file_path): await sync.to_thread(shutil.copyfile, file_path, os.path.join(save_path, os.path.basename(file_path))) - await sync.gather(*(set_file_path(file_path) for file_path in glob(os.path.join(hf_ckpt_path, "*token*")))) + await sync.gather(*(set_file_path(file_path) for file_path in token_dir))