import shutil
from argparse import ArgumentParser
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import torch
from safetensors.torch import safe_open, save_file
from tqdm import tqdm, trange

# Constants and type definitions
TensorMapping = Dict[str, Tuple[str, Optional[int]]]
StateDict = Dict[str, torch.Tensor]

# Define mapping as a constant at module level
TENSOR_MAPPING: TensorMapping = {
    "embed_tokens": ("embed", 0),
    "input_layernorm": ("attn_norm", None),
    "post_attention_layernorm": ("ffn_norm", None),
    "q_proj": ("wq", 0),
    "q_a_proj": ("wq_a", None),
    "q_a_layernorm": ("q_norm", None),
    "q_b_proj": ("wq_b", 0),
    "kv_a_proj_with_mqa": ("wkv_a", None),
    "kv_a_layernorm": ("kv_norm", None),
    "kv_b_proj": ("wkv_b", 0),
    "o_proj": ("wo", 1),
    "gate": ("gate", None),
    "gate_proj": ("w1", 0),
    "down_proj": ("w2", 1),
    "up_proj": ("w3", 0),
    "norm": ("norm", None),
    "lm_head": ("head", 0),
    "scale": ("scale", None),
}
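
# Each entry maps a HuggingFace tensor key to (converted key, shard dimension).
# A shard dimension of None means the tensor is replicated on every model-parallel
# rank; routed expert weights are an exception and are assigned whole to the rank
# that owns the expert (see convert_checkpoint below).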


def process_tensor_name(name: str) -> str:
    """
    Process tensor name by removing prefixes and replacing common patterns.

    Args:
        name: Original tensor name

    Returns:
        Processed tensor name
    """
    if name.startswith("model."):
        name = name[len("model."):]

    replacements = {
        "self_attn": "attn",
        "mlp": "ffn",
        "weight_scale_inv": "scale",
        "e_score_correction_bias": "bias",
    }

    for old, new in replacements.items():
        name = name.replace(old, new)

    return name
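
# For example, process_tensor_name("model.layers.3.mlp.down_proj.weight_scale_inv")
# returns "layers.3.ffn.down_proj.scale".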


def shard_tensor(param: torch.Tensor, mp_idx: int, mp_count: int, dim: int) -> torch.Tensor:
    """
    Shard a tensor along specified dimension for model parallelism.

    Args:
        param: Input tensor to shard
        mp_idx: Index of current model parallel rank
        mp_count: Total number of model parallel ranks
        dim: Dimension along which to shard

    Returns:
        Sharded tensor slice
    """
    if param.size(dim) % mp_count != 0:
        raise ValueError(f"Tensor size {param.size(dim)} not divisible by mp_count {mp_count}")

    shard_size = param.size(dim) // mp_count
    return param.narrow(dim, mp_idx * shard_size, shard_size).contiguous()
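
# Illustrative example: for a tensor of shape (8, 4) with mp_count=2 and dim=0,
# rank 0 (mp_idx=0) receives rows 0-3 and rank 1 receives rows 4-7.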


def convert_checkpoint(
    hf_ckpt_path: Union[str, Path],
    save_path: Union[str, Path],
    n_experts: int,
    mp: int
) -> None:
    """
    Convert HuggingFace checkpoint files into model-parallel shards and save them.

    Args:
        hf_ckpt_path: Path to input checkpoint directory
        save_path: Path to output directory for converted checkpoints
        n_experts: Total number of experts in the model
        mp: Model parallelism factor

    Raises:
        ValueError: If n_experts is not divisible by mp
        FileNotFoundError: If the input path does not exist or contains no safetensors files
    """
    if n_experts % mp != 0:
        raise ValueError(f"Number of experts ({n_experts}) must be divisible by model parallel size ({mp})")

    hf_ckpt_path = Path(hf_ckpt_path)
    save_path = Path(save_path)

    if not hf_ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint path {hf_ckpt_path} does not exist")

    safetensor_files = list(hf_ckpt_path.glob("*.safetensors"))
    if not safetensor_files:
        raise FileNotFoundError(f"No safetensor files found in {hf_ckpt_path}")

    torch.set_num_threads(8)
    n_local_experts = n_experts // mp
    state_dicts: List[StateDict] = [{} for _ in range(mp)]

    # Process each checkpoint file
    for file_path in tqdm(safetensor_files, desc="Processing checkpoint files"):
        with safe_open(file_path, framework="pt", device="cpu") as f:
            for name in f.keys():
                # Skip tensors from layer 61, which are not included in the converted checkpoint
                if "model.layers.61" in name:
                    continue

                param: torch.Tensor = f.get_tensor(name)
                name = process_tensor_name(name)

                key = name.split(".")[-2]
                if key not in TENSOR_MAPPING:
                    raise ValueError(f"Unknown tensor key: {key}")

                new_key, dim = TENSOR_MAPPING[key]
                name = name.replace(key, new_key)

                # Distribute tensors across model parallel ranks
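                # Worked example (illustrative numbers): with n_experts=8 and mp=2,
                # n_local_experts is 4, so rank 0 keeps experts 0-3 and rank 1 keeps
                # experts 4-7; non-expert tensors are either replicated (dim is None)
                # or sliced along dim with shard_tensor.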
                for i in range(mp):
                    new_param = param
                    if "experts" in name and "shared_experts" not in name:
                        idx = int(name.split(".")[-3])
                        if not (i * n_local_experts <= idx < (i + 1) * n_local_experts):
                            continue
                    elif dim is not None:
                        new_param = shard_tensor(param, i, mp, dim)
                    state_dicts[i][name] = new_param

    # Save converted checkpoints
    save_path.mkdir(parents=True, exist_ok=True)

    for i in trange(mp, desc="Saving converted checkpoints"):
        output_file = save_path / f"model{i}-mp{mp}.safetensors"
        save_file(state_dicts[i], str(output_file))
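
    # With mp=2, for example, this produces model0-mp2.safetensors and
    # model1-mp2.safetensors in save_path.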

    # Copy tokenizer files
    for file_path in hf_ckpt_path.glob("*token*"):
        shutil.copyfile(file_path, save_path / file_path.name)


def main():
    """Parse command line arguments and run the conversion."""
    parser = ArgumentParser(description="Convert HuggingFace checkpoints to custom format")
    parser.add_argument("--hf-ckpt-path", type=str, required=True,
                        help="Path to input HuggingFace checkpoint directory")
    parser.add_argument("--save-path", type=str, required=True,
                        help="Path to output directory for converted checkpoints")
    parser.add_argument("--n-experts", type=int, required=True,
                        help="Total number of experts in the model")
    parser.add_argument("--model-parallel", type=int, required=True,
                        help="Model parallelism factor")

    args = parser.parse_args()

    try:
        convert_checkpoint(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)
    except Exception as e:
        print(f"Error during conversion: {e}")
        raise
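

# Example invocation (illustrative; assumes this file is saved as convert.py, and
# the paths and values below are placeholders):
#   python convert.py --hf-ckpt-path /path/to/hf_checkpoint --save-path /path/to/converted \
#       --n-experts <num_experts> --model-parallel <mp_size>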

if __name__ == "__main__":
    main()