diff --git a/inference/convert.py b/inference/convert.py
index c606ce8..3277cea 100644
--- a/inference/convert.py
+++ b/inference/convert.py
@@ -1,14 +1,23 @@
 import os
 import shutil
+import logging
 from argparse import ArgumentParser
 from glob import glob
-from tqdm import tqdm, trange
+from typing import Dict, Tuple, List, Optional
+from tqdm import tqdm
 
 import torch
 from safetensors.torch import safe_open, save_file
 
+# Configure logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
 
-mapping = {
+# Type aliases
+TensorMapping = Dict[str, Tuple[str, Optional[int]]]
+
+# Tensor name mapping: HF key -> (target key, shard dimension or None)
+MAPPING: TensorMapping = {
     "embed_tokens": ("embed", 0),
     "input_layernorm": ("attn_norm", None),
     "post_attention_layernorm": ("ffn_norm", None),
@@ -29,68 +38,167 @@ mapping = {
     "scale": ("scale", None),
 }
 
+def validate_paths(hf_ckpt_path: str, save_path: str) -> None:
+    """Validate input and output paths."""
+    if not os.path.isdir(hf_ckpt_path):
+        raise ValueError(f"Input directory {hf_ckpt_path} does not exist")
+
+    os.makedirs(save_path, exist_ok=True)
+    if not os.access(save_path, os.W_OK):
+        raise PermissionError(f"No write permission for output directory {save_path}")
 
-def main(hf_ckpt_path, save_path, n_experts, mp):
+def process_tensor_name(name: str) -> str:
+    """Normalize a HuggingFace tensor name to the target naming scheme."""
+    # Strip the 'model.' prefix if present
+    if name.startswith("model."):
+        name = name[len("model."):]
+
+    # Rename known substrings
+    replacements = {
+        "self_attn": "attn",
+        "mlp": "ffn",
+        "weight_scale_inv": "scale",
+        "e_score_correction_bias": "bias",
+    }
+    for old, new in replacements.items():
+        name = name.replace(old, new)
+
+    return name
+
+def split_tensor(param: torch.Tensor, dim: Optional[int], mp: int, idx: int) -> torch.Tensor:
+    """Return the shard of param owned by rank idx (the whole tensor if dim is None)."""
+    if dim is None:
+        return param
+
+    if param.size(dim) % mp != 0:
+        raise ValueError(f"Dimension {dim} of tensor with shape {param.shape} "
+                         f"is not divisible by model parallelism factor {mp}")
+
+    shard_size = param.size(dim) // mp
+    return param.narrow(dim, idx * shard_size, shard_size).contiguous()
+
+def process_checkpoint_files(
+    hf_ckpt_path: str,
+    mp: int,
+    n_local_experts: int,
+    state_dicts: List[Dict[str, torch.Tensor]],
+) -> None:
+    """Read all checkpoint shards and populate the per-rank state dictionaries."""
+    for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors")),
+                          desc="Processing checkpoint files"):
+        try:
+            with safe_open(file_path, framework="pt", device="cpu") as f:
+                for name in tqdm(f.keys(), desc=f"Processing {os.path.basename(file_path)}", leave=False):
+                    if "model.layers.61" in name:
+                        logger.debug(f"Skipping layer 61 tensor: {name}")
+                        continue
+
+                    param = f.get_tensor(name)
+                    processed_name = process_tensor_name(name)
+
+                    key = processed_name.split(".")[-2]
+                    if key not in MAPPING:
+                        raise KeyError(f"Unexpected tensor key: {key} in tensor {name}")
+
+                    new_key, dim = MAPPING[key]
+                    final_name = processed_name.replace(key, new_key)
+
+                    for i in range(mp):
+                        if "experts" in final_name and "shared_experts" not in final_name:
+                            expert_idx = int(final_name.split(".")[-3])
+                            if not (i * n_local_experts <= expert_idx < (i + 1) * n_local_experts):
+                                continue
+                            # Routed expert tensors are copied whole to the rank that owns them
+                            state_dicts[i][final_name] = param
+                        else:
+                            # Everything else is sharded along dim (or replicated if dim is None)
+                            state_dicts[i][final_name] = split_tensor(param, dim, mp, i)
+        except Exception as e:
+            logger.error(f"Error processing file {file_path}: {str(e)}")
+            raise
+
+def save_output_files(
+    state_dicts: List[Dict[str, torch.Tensor]],
+    save_path: str,
+    mp: int,
+    hf_ckpt_path: str,
+) -> None:
+    """Save the per-rank state dictionaries and copy tokenizer files."""
+    for i in tqdm(range(mp), desc="Saving output files"):
+        output_file = os.path.join(save_path, f"model{i}-mp{mp}.safetensors")
+        save_file(state_dicts[i], output_file, metadata={"format": "pt"})
+
+    # Copy tokenizer-related files
+    for file_path in glob(os.path.join(hf_ckpt_path, "*token*")):
+        try:
+            shutil.copy(file_path, os.path.join(save_path, os.path.basename(file_path)))
+        except IOError as e:
+            logger.error(f"Error copying file {file_path}: {str(e)}")
+
+def main(
+    hf_ckpt_path: str,
+    save_path: str,
+    n_experts: int,
+    mp: int,
+) -> None:
     """
-    Converts and saves model checkpoint files into a specified format.
-
+    Convert and split model checkpoints for model-parallel inference.
+
     Args:
-        hf_ckpt_path (str): Path to the directory containing the input checkpoint files.
-        save_path (str): Path to the directory where the converted checkpoint files will be saved.
-        n_experts (int): Total number of experts in the model.
-        mp (int): Model parallelism factor.
-
-    Returns:
-        None
+        hf_ckpt_path: Path to HuggingFace format checkpoint directory
+        save_path: Output directory for converted checkpoints
+        n_experts: Total number of experts in the model
+        mp: Model parallelism factor
     """
     torch.set_num_threads(8)
+    validate_paths(hf_ckpt_path, save_path)
+
+    if n_experts % mp != 0:
+        raise ValueError(f"Number of experts {n_experts} must be divisible by model parallelism factor {mp}")
+
     n_local_experts = n_experts // mp
     state_dicts = [{} for _ in range(mp)]
-
-    for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))):
-        with safe_open(file_path, framework="pt", device="cpu") as f:
-            for name in f.keys():
-                if "model.layers.61" in name:
-                    continue
-                param: torch.Tensor = f.get_tensor(name)
-                if name.startswith("model."):
-                    name = name[len("model."):]
-                name = name.replace("self_attn", "attn")
-                name = name.replace("mlp", "ffn")
-                name = name.replace("weight_scale_inv", "scale")
-                name = name.replace("e_score_correction_bias", "bias")
-                key = name.split(".")[-2]
-                assert key in mapping
-                new_key, dim = mapping[key]
-                name = name.replace(key, new_key)
-                for i in range(mp):
-                    new_param = param
-                    if "experts" in name and "shared_experts" not in name:
-                        idx = int(name.split(".")[-3])
-                        if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
-                            continue
-                    elif dim is not None:
-                        assert param.size(dim) % mp == 0
-                        shard_size = param.size(dim) // mp
-                        new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
-                    state_dicts[i][name] = new_param
-
-    os.makedirs(save_path, exist_ok=True)
-
-    for i in trange(mp):
-        save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))
-
-    for file_path in glob(os.path.join(hf_ckpt_path, "*token*")):
-        new_file_path = os.path.join(save_path, os.path.basename(file_path))
-        shutil.copyfile(file_path, new_file_path)
-
+
+    process_checkpoint_files(hf_ckpt_path, mp, n_local_experts, state_dicts)
+    save_output_files(state_dicts, save_path, mp, hf_ckpt_path)
+
+    logger.info(f"Successfully converted checkpoints. Output saved to {save_path}")
 
 if __name__ == "__main__":
-    parser = ArgumentParser()
-    parser.add_argument("--hf-ckpt-path", type=str, required=True)
-    parser.add_argument("--save-path", type=str, required=True)
-    parser.add_argument("--n-experts", type=int, required=True)
-    parser.add_argument("--model-parallel", type=int, required=True)
+    parser = ArgumentParser(description="Convert HuggingFace checkpoints for model-parallel inference")
+    parser.add_argument(
+        "--hf-ckpt-path",
+        type=str,
+        required=True,
+        help="Path to input HuggingFace checkpoint directory",
+    )
+    parser.add_argument(
+        "--save-path",
+        type=str,
+        required=True,
+        help="Output directory for converted checkpoints",
+    )
+    parser.add_argument(
+        "--n-experts",
+        type=int,
+        required=True,
+        help="Total number of experts in the model",
+    )
+    parser.add_argument(
+        "--model-parallel",
+        type=int,
+        required=True,
+        help="Model parallelism factor",
+    )
     args = parser.parse_args()
-    assert args.n_experts % args.model_parallel == 0
-    main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)
+
+    try:
+        main(
+            args.hf_ckpt_path,
+            args.save_path,
+            args.n_experts,
+            args.model_parallel,
+        )
+    except Exception as e:
+        logger.error(f"Conversion failed: {str(e)}")
+        raise
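
Reviewer note: a minimal post-conversion sanity check, offered as a sketch rather than part of the change. The save_path and mp values below are illustrative stand-ins for whatever --save-path and --model-parallel were used; the snippet relies only on the safetensors API the script itself imports. Routed expert tensors should appear on exactly one rank, while all other tensors appear on every rank.

import os

from safetensors.torch import safe_open

save_path = "/tmp/converted"  # assumed --save-path from the conversion run
mp = 2                        # assumed --model-parallel value

for i in range(mp):
    shard = os.path.join(save_path, f"model{i}-mp{mp}.safetensors")
    with safe_open(shard, framework="pt", device="cpu") as f:
        names = list(f.keys())
        # Count routed expert tensors landed on this rank
        experts = [n for n in names if "experts" in n and "shared_experts" not in n]
        print(f"rank {i}: {len(names)} tensors, {len(experts)} routed expert tensors")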