make convert.py use multiple processes

This commit is contained in:
qpwo 2025-01-24 01:12:19 -08:00 committed by GitHub
parent ee4c4ea32b
commit 8c40067fb2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2,7 +2,8 @@ import os
import shutil
from argparse import ArgumentParser
from glob import glob
from tqdm import tqdm, trange
from tqdm import tqdm
from multiprocessing import Pool
import torch
from safetensors.torch import safe_open, save_file
@ -30,7 +31,7 @@ mapping = {
}
def main(hf_ckpt_path, save_path, n_experts, mp):
def main(hf_ckpt_path, save_path, n_experts, mp, these_mps):
"""
Converts and saves model checkpoint files into a specified format.
@ -64,6 +65,8 @@ def main(hf_ckpt_path, save_path, n_experts, mp):
new_key, dim = mapping[key]
name = name.replace(key, new_key)
for i in range(mp):
if i not in these_mps:
continue
new_param = param
if "experts" in name and "shared_experts" not in name:
idx = int(name.split(".")[-3])
@ -77,8 +80,16 @@ def main(hf_ckpt_path, save_path, n_experts, mp):
os.makedirs(save_path, exist_ok=True)
for i in trange(mp):
save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))
for i in range(mp):
if i not in these_mps:
continue
p = os.path.join(save_path, f"model{i}-mp{mp}.safetensors")
if os.path.exists(p):
print(f"{p=}: already exists, skipping")
continue
print(f"{p=}: saving")
save_file(state_dicts[i], p)
print(f"{p=}: done")
for file_path in glob(os.path.join(hf_ckpt_path, "*token*")):
new_file_path = os.path.join(save_path, os.path.basename(file_path))
@ -91,6 +102,11 @@ if __name__ == "__main__":
parser.add_argument("--save-path", type=str, required=True)
parser.add_argument("--n-experts", type=int, required=True)
parser.add_argument("--model-parallel", type=int, required=True)
parser.add_argument("--num-procs", type=int, default=4)
args = parser.parse_args()
assert args.n_experts % args.model_parallel == 0
with Pool(args.num_procs) as pool:
proc_args = [(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel, range(args.model_parallel)[i::args.num_procs]) for i in range(args.num_procs)]
pool.starmap(main, proc_args)
main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)