diff --git a/inference/convert.py b/inference/convert.py index c606ce8..12a9ba1 100644 --- a/inference/convert.py +++ b/inference/convert.py @@ -2,7 +2,8 @@ import os import shutil from argparse import ArgumentParser from glob import glob -from tqdm import tqdm, trange +from tqdm import tqdm +from multiprocessing import Pool import torch from safetensors.torch import safe_open, save_file @@ -30,7 +31,7 @@ mapping = { } -def main(hf_ckpt_path, save_path, n_experts, mp): +def main(hf_ckpt_path, save_path, n_experts, mp, these_mps): """ Converts and saves model checkpoint files into a specified format. @@ -39,7 +40,7 @@ def main(hf_ckpt_path, save_path, n_experts, mp): save_path (str): Path to the directory where the converted checkpoint files will be saved. n_experts (int): Total number of experts in the model. mp (int): Model parallelism factor. - + Returns: None """ @@ -64,6 +65,8 @@ def main(hf_ckpt_path, save_path, n_experts, mp): new_key, dim = mapping[key] name = name.replace(key, new_key) for i in range(mp): + if i not in these_mps: + continue new_param = param if "experts" in name and "shared_experts" not in name: idx = int(name.split(".")[-3]) @@ -77,8 +80,16 @@ def main(hf_ckpt_path, save_path, n_experts, mp): os.makedirs(save_path, exist_ok=True) - for i in trange(mp): - save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors")) + for i in range(mp): + if i not in these_mps: + continue + p = os.path.join(save_path, f"model{i}-mp{mp}.safetensors") + if os.path.exists(p): + print(f"{p=}: already exists, skipping") + continue + print(f"{p=}: saving") + save_file(state_dicts[i], p) + print(f"{p=}: done") for file_path in glob(os.path.join(hf_ckpt_path, "*token*")): new_file_path = os.path.join(save_path, os.path.basename(file_path)) @@ -91,6 +102,11 @@ if __name__ == "__main__": parser.add_argument("--save-path", type=str, required=True) parser.add_argument("--n-experts", type=int, required=True) parser.add_argument("--model-parallel", type=int, required=True) + parser.add_argument("--num-procs", type=int, default=4) args = parser.parse_args() assert args.n_experts % args.model_parallel == 0 + with Pool(args.num_procs) as pool: + proc_args = [(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel, range(args.model_parallel)[i::args.num_procs]) for i in range(args.num_procs)] + pool.starmap(main, proc_args) + main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)