diff --git a/inference/convert.py b/inference/convert.py index c606ce8..595b26d 100644 --- a/inference/convert.py +++ b/inference/convert.py @@ -2,13 +2,19 @@ import os import shutil from argparse import ArgumentParser from glob import glob -from tqdm import tqdm, trange +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union import torch from safetensors.torch import safe_open, save_file +from tqdm import tqdm, trange +# Constants and type definitions +TensorMapping = Dict[str, Tuple[str, Optional[int]]] +StateDict = Dict[str, torch.Tensor] -mapping = { +# Define mapping as a constant at module level +TENSOR_MAPPING: TensorMapping = { "embed_tokens": ("embed", 0), "input_layernorm": ("attn_norm", None), "post_attention_layernorm": ("ffn_norm", None), @@ -29,68 +35,144 @@ mapping = { "scale": ("scale", None), } - -def main(hf_ckpt_path, save_path, n_experts, mp): +def process_tensor_name(name: str) -> str: """ - Converts and saves model checkpoint files into a specified format. - + Process tensor name by removing prefixes and replacing common patterns. + Args: - hf_ckpt_path (str): Path to the directory containing the input checkpoint files. - save_path (str): Path to the directory where the converted checkpoint files will be saved. - n_experts (int): Total number of experts in the model. - mp (int): Model parallelism factor. + name: Original tensor name Returns: - None + Processed tensor name """ + if name.startswith("model."): + name = name[len("model."):] + + replacements = { + "self_attn": "attn", + "mlp": "ffn", + "weight_scale_inv": "scale", + "e_score_correction_bias": "bias" + } + + for old, new in replacements.items(): + name = name.replace(old, new) + + return name + +def shard_tensor(param: torch.Tensor, mp_idx: int, mp_count: int, dim: int) -> torch.Tensor: + """ + Shard a tensor along specified dimension for model parallelism. + + Args: + param: Input tensor to shard + mp_idx: Index of current model parallel rank + mp_count: Total number of model parallel ranks + dim: Dimension along which to shard + + Returns: + Sharded tensor slice + """ + if param.size(dim) % mp_count != 0: + raise ValueError(f"Tensor size {param.size(dim)} not divisible by mp_count {mp_count}") + + shard_size = param.size(dim) // mp_count + return param.narrow(dim, mp_idx * shard_size, shard_size).contiguous() + +def convert_checkpoint( + hf_ckpt_path: Union[str, Path], + save_path: Union[str, Path], + n_experts: int, + mp: int +) -> None: + """ + Convert and save model checkpoint files into a specified format. + + Args: + hf_ckpt_path: Path to input checkpoint directory + save_path: Path to output directory for converted checkpoints + n_experts: Total number of experts in model + mp: Model parallelism factor + + Raises: + ValueError: If n_experts is not divisible by mp + FileNotFoundError: If input path doesn't exist or contain safetensors + """ + if n_experts % mp != 0: + raise ValueError(f"Number of experts ({n_experts}) must be divisible by model parallel size ({mp})") + + hf_ckpt_path = Path(hf_ckpt_path) + save_path = Path(save_path) + + if not hf_ckpt_path.exists(): + raise FileNotFoundError(f"Checkpoint path {hf_ckpt_path} does not exist") + + safetensor_files = list(hf_ckpt_path.glob("*.safetensors")) + if not safetensor_files: + raise FileNotFoundError(f"No safetensor files found in {hf_ckpt_path}") + torch.set_num_threads(8) n_local_experts = n_experts // mp - state_dicts = [{} for _ in range(mp)] + state_dicts: List[StateDict] = [{} for _ in range(mp)] - for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))): + # Process each checkpoint file + for file_path in tqdm(safetensor_files, desc="Processing checkpoint files"): with safe_open(file_path, framework="pt", device="cpu") as f: for name in f.keys(): if "model.layers.61" in name: continue + param: torch.Tensor = f.get_tensor(name) - if name.startswith("model."): - name = name[len("model."):] - name = name.replace("self_attn", "attn") - name = name.replace("mlp", "ffn") - name = name.replace("weight_scale_inv", "scale") - name = name.replace("e_score_correction_bias", "bias") + name = process_tensor_name(name) + key = name.split(".")[-2] - assert key in mapping - new_key, dim = mapping[key] + if key not in TENSOR_MAPPING: + raise ValueError(f"Unknown tensor key: {key}") + + new_key, dim = TENSOR_MAPPING[key] name = name.replace(key, new_key) + + # Distribute tensors across model parallel ranks for i in range(mp): new_param = param if "experts" in name and "shared_experts" not in name: idx = int(name.split(".")[-3]) - if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts: + if not (i * n_local_experts <= idx < (i + 1) * n_local_experts): continue elif dim is not None: - assert param.size(dim) % mp == 0 - shard_size = param.size(dim) // mp - new_param = param.narrow(dim, i * shard_size, shard_size).contiguous() + new_param = shard_tensor(param, i, mp, dim) state_dicts[i][name] = new_param - os.makedirs(save_path, exist_ok=True) + # Save converted checkpoints + save_path.mkdir(parents=True, exist_ok=True) + + for i in trange(mp, desc="Saving converted checkpoints"): + output_file = save_path / f"model{i}-mp{mp}.safetensors" + save_file(state_dicts[i], str(output_file)) - for i in trange(mp): - save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors")) - - for file_path in glob(os.path.join(hf_ckpt_path, "*token*")): - new_file_path = os.path.join(save_path, os.path.basename(file_path)) - shutil.copyfile(file_path, new_file_path) + # Copy tokenizer files + for file_path in hf_ckpt_path.glob("*token*"): + shutil.copyfile(file_path, save_path / file_path.name) +def main(): + """Parse command line arguments and run the conversion.""" + parser = ArgumentParser(description="Convert HuggingFace checkpoints to custom format") + parser.add_argument("--hf-ckpt-path", type=str, required=True, + help="Path to input HuggingFace checkpoint directory") + parser.add_argument("--save-path", type=str, required=True, + help="Path to output directory for converted checkpoints") + parser.add_argument("--n-experts", type=int, required=True, + help="Total number of experts in the model") + parser.add_argument("--model-parallel", type=int, required=True, + help="Model parallelism factor") + + args = parser.parse_args() + + try: + convert_checkpoint(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel) + except Exception as e: + print(f"Error during conversion: {str(e)}") + raise if __name__ == "__main__": - parser = ArgumentParser() - parser.add_argument("--hf-ckpt-path", type=str, required=True) - parser.add_argument("--save-path", type=str, required=True) - parser.add_argument("--n-experts", type=int, required=True) - parser.add_argument("--model-parallel", type=int, required=True) - args = parser.parse_args() - assert args.n_experts % args.model_parallel == 0 - main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel) + main() \ No newline at end of file