diff --git a/inference/convert.py b/inference/convert.py index 6b818db..212e2e4 100644 --- a/inference/convert.py +++ b/inference/convert.py @@ -74,7 +74,8 @@ def main(hf_ckpt_path, save_path, n_experts, mp): state_dicts = [{} for _ in range(mp)] tasks = [] for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))): - async with sync.to_thread(safe_open, file_path, framework="pt", device="cpu") as f: + cm = await sync.to_thread(safe_open, file_path, framework="pt", device="cpu") + async with cm as f: await sync.gather(*(inner_safe_open(name, f, state_dicts, mp, n_local_experts) for name in f.keys())) os.makedirs(save_path, exist_ok=True)