simplify loop in convert.py

This commit is contained in:
qpwo 2025-01-24 01:18:42 -08:00 committed by GitHub
parent 8c40067fb2
commit 8a45dade7b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -64,9 +64,7 @@ def main(hf_ckpt_path, save_path, n_experts, mp, these_mps):
assert key in mapping assert key in mapping
new_key, dim = mapping[key] new_key, dim = mapping[key]
name = name.replace(key, new_key) name = name.replace(key, new_key)
for i in range(mp): for i in these_mps:
if i not in these_mps:
continue
new_param = param new_param = param
if "experts" in name and "shared_experts" not in name: if "experts" in name and "shared_experts" not in name:
idx = int(name.split(".")[-3]) idx = int(name.split(".")[-3])
@ -80,9 +78,7 @@ def main(hf_ckpt_path, save_path, n_experts, mp, these_mps):
os.makedirs(save_path, exist_ok=True) os.makedirs(save_path, exist_ok=True)
for i in range(mp): for i in these_mps:
if i not in these_mps:
continue
p = os.path.join(save_path, f"model{i}-mp{mp}.safetensors") p = os.path.join(save_path, f"model{i}-mp{mp}.safetensors")
if os.path.exists(p): if os.path.exists(p):
print(f"{p=}: already exists, skipping") print(f"{p=}: already exists, skipping")