From 8a45dade7b4d1d4a234e27c44b9f680936c16711 Mon Sep 17 00:00:00 2001 From: qpwo <10591373+qpwo@users.noreply.github.com> Date: Fri, 24 Jan 2025 01:18:42 -0800 Subject: [PATCH] simplify loop in convert.py --- inference/convert.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/inference/convert.py b/inference/convert.py index 12a9ba1..617bf4c 100644 --- a/inference/convert.py +++ b/inference/convert.py @@ -64,9 +64,7 @@ def main(hf_ckpt_path, save_path, n_experts, mp, these_mps): assert key in mapping new_key, dim = mapping[key] name = name.replace(key, new_key) - for i in range(mp): - if i not in these_mps: - continue + for i in these_mps: new_param = param if "experts" in name and "shared_experts" not in name: idx = int(name.split(".")[-3]) @@ -80,9 +78,7 @@ def main(hf_ckpt_path, save_path, n_experts, mp, these_mps): os.makedirs(save_path, exist_ok=True) - for i in range(mp): - if i not in these_mps: - continue + for i in these_mps: p = os.path.join(save_path, f"model{i}-mp{mp}.safetensors") if os.path.exists(p): print(f"{p=}: already exists, skipping")