Refactor checkpoint conversion script for improved readability and efficiency

Author: Tanmay Das, 2025-02-10 18:40:56 -05:00 (committed by GitHub)
parent 2f7b80eece
commit 897291478c


@@ -2,6 +2,7 @@ import os
 import shutil
 from argparse import ArgumentParser
 from glob import glob
+from pathlib import Path
 from tqdm import tqdm, trange
 import torch
@@ -43,46 +44,55 @@ def main(hf_ckpt_path, save_path, n_experts, mp):
     Returns:
         None
     """
+    assert mp > 0, "Model parallelism (mp) must be greater than 0"
     torch.set_num_threads(8)
     n_local_experts = n_experts // mp
     state_dicts = [{} for _ in range(mp)]
-    for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))):
+    hf_ckpt_path = Path(hf_ckpt_path)
+    save_path = Path(save_path)
+    save_path.mkdir(parents=True, exist_ok=True)
+    for file_path in tqdm(hf_ckpt_path.glob("*.safetensors")):
         with safe_open(file_path, framework="pt", device="cpu") as f:
             for name in f.keys():
                 if "model.layers.61" in name:
                     continue
                 param: torch.Tensor = f.get_tensor(name)
                 if name.startswith("model."):
                     name = name[len("model."):]
-                name = name.replace("self_attn", "attn")
-                name = name.replace("mlp", "ffn")
-                name = name.replace("weight_scale_inv", "scale")
-                name = name.replace("e_score_correction_bias", "bias")
+                name = (
+                    name.replace("self_attn", "attn")
+                    .replace("mlp", "ffn")
+                    .replace("weight_scale_inv", "scale")
+                    .replace("e_score_correction_bias", "bias")
+                )
                 key = name.split(".")[-2]
                 assert key in mapping, f"Key {key} not found in mapping"
                 new_key, dim = mapping[key]
                 name = name.replace(key, new_key)
-                for i in range(mp):
-                    new_param = param
-                    if "experts" in name and "shared_experts" not in name:
-                        idx = int(name.split(".")[-3])
-                        if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
-                            continue
-                    elif dim is not None:
-                        assert param.size(dim) % mp == 0, f"Dimension {dim} must be divisible by {mp}"
-                        shard_size = param.size(dim) // mp
-                        new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
-                    state_dicts[i][name] = new_param
-    os.makedirs(save_path, exist_ok=True)
+                if "experts" in name and "shared_experts" not in name:
+                    idx = int(name.split(".")[-3])
+                    target_index = idx // n_local_experts
+                    if target_index < mp:
+                        state_dicts[target_index][name] = param
+                elif dim is not None:
+                    assert param.size(dim) % mp == 0, f"Dimension {dim} must be divisible by {mp}"
+                    shard_size = param.size(dim) // mp
+                    for i in range(mp):
+                        state_dicts[i][name] = param[:, i * shard_size : (i + 1) * shard_size] if dim == 1 else param[i * shard_size : (i + 1) * shard_size]
     for i in trange(mp):
-        save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))
-    for file_path in glob(os.path.join(hf_ckpt_path, "*token*")):
-        new_file_path = os.path.join(save_path, os.path.basename(file_path))
-        shutil.copyfile(file_path, new_file_path)
+        save_file(state_dicts[i], save_path / f"model{i}-mp{mp}.safetensors")
+    for file_path in hf_ckpt_path.glob("*token*"):
+        shutil.copyfile(file_path, save_path / file_path.name)
 
 if __name__ == "__main__":
@@ -92,5 +102,7 @@ if __name__ == "__main__":
     parser.add_argument("--n-experts", type=int, required=True)
     parser.add_argument("--model-parallel", type=int, required=True)
     args = parser.parse_args()
     assert args.n_experts % args.model_parallel == 0, "Number of experts must be divisible by model parallelism"
     main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)
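
Note: the behavioral core of this refactor is that expert weights are now assigned to their model-parallel shard by a direct index lookup (idx // n_local_experts) instead of a per-shard range check, and dense weights are split by tensor slicing instead of narrow(). Below is a minimal standalone sketch, not part of the commit, using illustrative expert counts and tensor shapes, that checks the old and new formulations agree:

import torch

mp = 4                                  # model-parallel shards (illustrative)
n_experts = 16                          # total routed experts (illustrative)
n_local_experts = n_experts // mp

# Expert weights: the old code looped over every shard and skipped out-of-range
# expert indices; the new code computes the owning shard directly.
for idx in range(n_experts):
    target_index = idx // n_local_experts                      # new: direct lookup
    old_targets = [i for i in range(mp)
                   if i * n_local_experts <= idx < (i + 1) * n_local_experts]
    assert old_targets == [target_index]                       # both rules agree

# Dense weights sharded along a dimension: narrow() and slicing give equal shards.
param = torch.randn(8, 12)              # illustrative shape
dim = 1
shard_size = param.size(dim) // mp
for i in range(mp):
    a = param.narrow(dim, i * shard_size, shard_size)          # old style
    b = param[:, i * shard_size:(i + 1) * shard_size]          # new style
    assert torch.equal(a, b)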