From f2636dd3666f08805b8501bbcc3bee5d05ac0892 Mon Sep 17 00:00:00 2001 From: CodingParadigm1 <124928232+CodingParadigm1@users.noreply.github.com> Date: Sun, 2 Feb 2025 03:55:47 -0700 Subject: [PATCH] updated async file reading --- inference/convert.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/inference/convert.py b/inference/convert.py index 6b818db..212e2e4 100644 --- a/inference/convert.py +++ b/inference/convert.py @@ -74,7 +74,8 @@ def main(hf_ckpt_path, save_path, n_experts, mp): state_dicts = [{} for _ in range(mp)] tasks = [] for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))): - async with sync.to_thread(safe_open, file_path, framework="pt", device="cpu") as f: + cm = await sync.to_thread(safe_open, file_path, framework="pt", device="cpu") + async with cm as f: await sync.gather(*(inner_safe_open(name, f, state_dicts, mp, n_local_experts) for name in f.keys())) os.makedirs(save_path, exist_ok=True)