From 897291478c05d2fbe69705bd377577dfbec4b4ca Mon Sep 17 00:00:00 2001 From: Tanmay Das <101689672+tdas3001@users.noreply.github.com> Date: Mon, 10 Feb 2025 18:40:56 -0500 Subject: [PATCH] Refactor checkpoint conversion script for improved readability and efficiency --- inference/convert.py | 54 +++++++++++++++++++++++++++----------------- 1 file changed, 33 insertions(+), 21 deletions(-) diff --git a/inference/convert.py b/inference/convert.py index 6d85ccc..3e0ac46 100644 --- a/inference/convert.py +++ b/inference/convert.py @@ -2,6 +2,7 @@ import os import shutil from argparse import ArgumentParser from glob import glob +from pathlib import Path from tqdm import tqdm, trange import torch @@ -43,46 +44,55 @@ def main(hf_ckpt_path, save_path, n_experts, mp): Returns: None """ + assert mp > 0, "Model parallelism (mp) must be greater than 0" + torch.set_num_threads(8) n_local_experts = n_experts // mp state_dicts = [{} for _ in range(mp)] + + hf_ckpt_path = Path(hf_ckpt_path) + save_path = Path(save_path) + save_path.mkdir(parents=True, exist_ok=True) - for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))): + for file_path in tqdm(hf_ckpt_path.glob("*.safetensors")): with safe_open(file_path, framework="pt", device="cpu") as f: for name in f.keys(): if "model.layers.61" in name: continue + param: torch.Tensor = f.get_tensor(name) + if name.startswith("model."): name = name[len("model."):] - name = name.replace("self_attn", "attn") - name = name.replace("mlp", "ffn") - name = name.replace("weight_scale_inv", "scale") - name = name.replace("e_score_correction_bias", "bias") + + name = ( + name.replace("self_attn", "attn") + .replace("mlp", "ffn") + .replace("weight_scale_inv", "scale") + .replace("e_score_correction_bias", "bias") + ) + key = name.split(".")[-2] assert key in mapping, f"Key {key} not found in mapping" new_key, dim = mapping[key] name = name.replace(key, new_key) - for i in range(mp): - new_param = param - if "experts" in name and "shared_experts" not in name: - idx = int(name.split(".")[-3]) - if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts: - continue - elif dim is not None: - assert param.size(dim) % mp == 0, f"Dimension {dim} must be divisible by {mp}" - shard_size = param.size(dim) // mp - new_param = param.narrow(dim, i * shard_size, shard_size).contiguous() - state_dicts[i][name] = new_param - os.makedirs(save_path, exist_ok=True) + if "experts" in name and "shared_experts" not in name: + idx = int(name.split(".")[-3]) + target_index = idx // n_local_experts + if target_index < mp: + state_dicts[target_index][name] = param + elif dim is not None: + assert param.size(dim) % mp == 0, f"Dimension {dim} must be divisible by {mp}" + shard_size = param.size(dim) // mp + for i in range(mp): + state_dicts[i][name] = param[:, i * shard_size : (i + 1) * shard_size] if dim == 1 else param[i * shard_size : (i + 1) * shard_size] for i in trange(mp): - save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors")) + save_file(state_dicts[i], save_path / f"model{i}-mp{mp}.safetensors") - for file_path in glob(os.path.join(hf_ckpt_path, "*token*")): - new_file_path = os.path.join(save_path, os.path.basename(file_path)) - shutil.copyfile(file_path, new_file_path) + for file_path in hf_ckpt_path.glob("*token*"): + shutil.copyfile(file_path, save_path / file_path.name) if __name__ == "__main__": @@ -92,5 +102,7 @@ if __name__ == "__main__": parser.add_argument("--n-experts", type=int, required=True) parser.add_argument("--model-parallel", type=int, required=True) args = parser.parse_args() + assert args.n_experts % args.model_parallel == 0, "Number of experts must be divisible by model parallelism" + main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)