Refactored convert.py

pratiyankkumar 2025-01-28 09:20:16 +05:30
parent b5d872ead0
commit 70ff909fdc


@@ -2,13 +2,19 @@ import os
 import shutil
 from argparse import ArgumentParser
 from glob import glob
-from tqdm import tqdm, trange
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 from safetensors.torch import safe_open, save_file
+from tqdm import tqdm, trange
 
-
-mapping = {
+# Constants and type definitions
+TensorMapping = Dict[str, Tuple[str, Optional[int]]]
+StateDict = Dict[str, torch.Tensor]
+
+# Define mapping as a constant at module level
+TENSOR_MAPPING: TensorMapping = {
     "embed_tokens": ("embed", 0),
     "input_layernorm": ("attn_norm", None),
     "post_attention_layernorm": ("ffn_norm", None),
@@ -29,68 +35,144 @@ mapping = {
     "scale": ("scale", None),
 }
 
-def main(hf_ckpt_path, save_path, n_experts, mp):
+
+def process_tensor_name(name: str) -> str:
     """
-    Converts and saves model checkpoint files into a specified format.
+    Process tensor name by removing prefixes and replacing common patterns.
 
     Args:
-        hf_ckpt_path (str): Path to the directory containing the input checkpoint files.
-        save_path (str): Path to the directory where the converted checkpoint files will be saved.
-        n_experts (int): Total number of experts in the model.
-        mp (int): Model parallelism factor.
+        name: Original tensor name
 
     Returns:
-        None
+        Processed tensor name
     """
+    if name.startswith("model."):
+        name = name[len("model."):]
+
+    replacements = {
+        "self_attn": "attn",
+        "mlp": "ffn",
+        "weight_scale_inv": "scale",
+        "e_score_correction_bias": "bias"
+    }
+    for old, new in replacements.items():
+        name = name.replace(old, new)
+    return name
+
+
+def shard_tensor(param: torch.Tensor, mp_idx: int, mp_count: int, dim: int) -> torch.Tensor:
+    """
+    Shard a tensor along specified dimension for model parallelism.
+
+    Args:
+        param: Input tensor to shard
+        mp_idx: Index of current model parallel rank
+        mp_count: Total number of model parallel ranks
+        dim: Dimension along which to shard
+
+    Returns:
+        Sharded tensor slice
+    """
+    if param.size(dim) % mp_count != 0:
+        raise ValueError(f"Tensor size {param.size(dim)} not divisible by mp_count {mp_count}")
+    shard_size = param.size(dim) // mp_count
+    return param.narrow(dim, mp_idx * shard_size, shard_size).contiguous()
+
+
+def convert_checkpoint(
+    hf_ckpt_path: Union[str, Path],
+    save_path: Union[str, Path],
+    n_experts: int,
+    mp: int
+) -> None:
+    """
+    Convert and save model checkpoint files into a specified format.
+
+    Args:
+        hf_ckpt_path: Path to input checkpoint directory
+        save_path: Path to output directory for converted checkpoints
+        n_experts: Total number of experts in model
+        mp: Model parallelism factor
+
+    Raises:
+        ValueError: If n_experts is not divisible by mp
+        FileNotFoundError: If input path doesn't exist or contain safetensors
+    """
+    if n_experts % mp != 0:
+        raise ValueError(f"Number of experts ({n_experts}) must be divisible by model parallel size ({mp})")
+
+    hf_ckpt_path = Path(hf_ckpt_path)
+    save_path = Path(save_path)
+    if not hf_ckpt_path.exists():
+        raise FileNotFoundError(f"Checkpoint path {hf_ckpt_path} does not exist")
+
+    safetensor_files = list(hf_ckpt_path.glob("*.safetensors"))
+    if not safetensor_files:
+        raise FileNotFoundError(f"No safetensor files found in {hf_ckpt_path}")
+
     torch.set_num_threads(8)
     n_local_experts = n_experts // mp
-    state_dicts = [{} for _ in range(mp)]
+    state_dicts: List[StateDict] = [{} for _ in range(mp)]
 
-    for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))):
+    # Process each checkpoint file
+    for file_path in tqdm(safetensor_files, desc="Processing checkpoint files"):
         with safe_open(file_path, framework="pt", device="cpu") as f:
             for name in f.keys():
                 if "model.layers.61" in name:
                     continue
                 param: torch.Tensor = f.get_tensor(name)
-                if name.startswith("model."):
-                    name = name[len("model."):]
-                name = name.replace("self_attn", "attn")
-                name = name.replace("mlp", "ffn")
-                name = name.replace("weight_scale_inv", "scale")
-                name = name.replace("e_score_correction_bias", "bias")
+                name = process_tensor_name(name)
                 key = name.split(".")[-2]
-                assert key in mapping
-                new_key, dim = mapping[key]
+                if key not in TENSOR_MAPPING:
+                    raise ValueError(f"Unknown tensor key: {key}")
+                new_key, dim = TENSOR_MAPPING[key]
                 name = name.replace(key, new_key)
+
+                # Distribute tensors across model parallel ranks
                 for i in range(mp):
                     new_param = param
                     if "experts" in name and "shared_experts" not in name:
                         idx = int(name.split(".")[-3])
-                        if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
+                        if not (i * n_local_experts <= idx < (i + 1) * n_local_experts):
                             continue
                     elif dim is not None:
-                        assert param.size(dim) % mp == 0
-                        shard_size = param.size(dim) // mp
-                        new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
+                        new_param = shard_tensor(param, i, mp, dim)
                     state_dicts[i][name] = new_param
 
-    os.makedirs(save_path, exist_ok=True)
+    # Save converted checkpoints
+    save_path.mkdir(parents=True, exist_ok=True)
 
-    for i in trange(mp):
-        save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))
+    for i in trange(mp, desc="Saving converted checkpoints"):
+        output_file = save_path / f"model{i}-mp{mp}.safetensors"
+        save_file(state_dicts[i], str(output_file))
 
-    for file_path in glob(os.path.join(hf_ckpt_path, "*token*")):
-        new_file_path = os.path.join(save_path, os.path.basename(file_path))
-        shutil.copyfile(file_path, new_file_path)
+    # Copy tokenizer files
+    for file_path in hf_ckpt_path.glob("*token*"):
+        shutil.copyfile(file_path, save_path / file_path.name)
+
+
+def main():
+    """Parse command line arguments and run the conversion."""
+    parser = ArgumentParser(description="Convert HuggingFace checkpoints to custom format")
+    parser.add_argument("--hf-ckpt-path", type=str, required=True,
+                        help="Path to input HuggingFace checkpoint directory")
+    parser.add_argument("--save-path", type=str, required=True,
+                        help="Path to output directory for converted checkpoints")
+    parser.add_argument("--n-experts", type=int, required=True,
+                        help="Total number of experts in the model")
+    parser.add_argument("--model-parallel", type=int, required=True,
+                        help="Model parallelism factor")
+    args = parser.parse_args()
+
+    try:
+        convert_checkpoint(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)
+    except Exception as e:
+        print(f"Error during conversion: {str(e)}")
+        raise
 
 
 if __name__ == "__main__":
-    parser = ArgumentParser()
-    parser.add_argument("--hf-ckpt-path", type=str, required=True)
-    parser.add_argument("--save-path", type=str, required=True)
-    parser.add_argument("--n-experts", type=int, required=True)
-    parser.add_argument("--model-parallel", type=int, required=True)
-    args = parser.parse_args()
-    assert args.n_experts % args.model_parallel == 0
-    main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)
+    main()
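The extracted shard_tensor helper is just a checked torch.narrow: rank mp_idx keeps the mp_idx-th contiguous chunk along dim. A standalone sketch of that behavior, where the 8x4 weight and mp_count=2 are made-up demo values:

import torch

def shard_tensor(param: torch.Tensor, mp_idx: int, mp_count: int, dim: int) -> torch.Tensor:
    # Same logic as the helper in this commit: take rank mp_idx's
    # contiguous slice of size param.size(dim) // mp_count along dim.
    if param.size(dim) % mp_count != 0:
        raise ValueError(f"Tensor size {param.size(dim)} not divisible by mp_count {mp_count}")
    shard_size = param.size(dim) // mp_count
    return param.narrow(dim, mp_idx * shard_size, shard_size).contiguous()

weight = torch.arange(32.0).reshape(8, 4)              # illustrative 8x4 weight
shards = [shard_tensor(weight, i, 2, dim=0) for i in range(2)]
assert all(s.shape == (4, 4) for s in shards)          # each rank holds half the rows
assert torch.equal(torch.cat(shards, dim=0), weight)   # shards reassemble the original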
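With the new main() entry point the command line is unchanged apart from the added help strings. A hypothetical invocation (the paths and the 256/8 split are illustrative only; --n-experts must remain divisible by --model-parallel or convert_checkpoint raises ValueError):

python convert.py \
    --hf-ckpt-path /path/to/hf_checkpoint \
    --save-path /path/to/converted \
    --n-experts 256 \
    --model-parallel 8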