Mirror of https://github.com/deepseek-ai/DeepSeek-V3.git (synced 2025-02-22 21:58:58 -05:00)
Refactor checkpoint conversion script for improved readability and efficiency
This commit is contained in:
parent 2f7b80eece
commit 897291478c
@@ -2,6 +2,7 @@ import os
 import shutil
 from argparse import ArgumentParser
 from glob import glob
+from pathlib import Path
 from tqdm import tqdm, trange
 
 import torch
@@ -43,46 +44,55 @@ def main(hf_ckpt_path, save_path, n_experts, mp):
     Returns:
         None
     """
+    assert mp > 0, "Model parallelism (mp) must be greater than 0"
+
     torch.set_num_threads(8)
     n_local_experts = n_experts // mp
     state_dicts = [{} for _ in range(mp)]
 
-    for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))):
+    hf_ckpt_path = Path(hf_ckpt_path)
+    save_path = Path(save_path)
+    save_path.mkdir(parents=True, exist_ok=True)
+
+    for file_path in tqdm(hf_ckpt_path.glob("*.safetensors")):
         with safe_open(file_path, framework="pt", device="cpu") as f:
             for name in f.keys():
                 if "model.layers.61" in name:
                     continue
+
                 param: torch.Tensor = f.get_tensor(name)
+
                 if name.startswith("model."):
                     name = name[len("model."):]
-                name = name.replace("self_attn", "attn")
-                name = name.replace("mlp", "ffn")
-                name = name.replace("weight_scale_inv", "scale")
-                name = name.replace("e_score_correction_bias", "bias")
+
+                name = (
+                    name.replace("self_attn", "attn")
+                    .replace("mlp", "ffn")
+                    .replace("weight_scale_inv", "scale")
+                    .replace("e_score_correction_bias", "bias")
+                )
+
                 key = name.split(".")[-2]
                 assert key in mapping, f"Key {key} not found in mapping"
                 new_key, dim = mapping[key]
                 name = name.replace(key, new_key)
-                for i in range(mp):
-                    new_param = param
                 if "experts" in name and "shared_experts" not in name:
                     idx = int(name.split(".")[-3])
-                    if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
-                        continue
+                    target_index = idx // n_local_experts
+                    if target_index < mp:
+                        state_dicts[target_index][name] = param
                 elif dim is not None:
                     assert param.size(dim) % mp == 0, f"Dimension {dim} must be divisible by {mp}"
                     shard_size = param.size(dim) // mp
-                    new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
-                    state_dicts[i][name] = new_param

-    os.makedirs(save_path, exist_ok=True)
+                    for i in range(mp):
+                        state_dicts[i][name] = param[:, i * shard_size : (i + 1) * shard_size] if dim == 1 else param[i * shard_size : (i + 1) * shard_size]

     for i in trange(mp):
-        save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))
+        save_file(state_dicts[i], save_path / f"model{i}-mp{mp}.safetensors")

-    for file_path in glob(os.path.join(hf_ckpt_path, "*token*")):
-        new_file_path = os.path.join(save_path, os.path.basename(file_path))
-        shutil.copyfile(file_path, new_file_path)
+    for file_path in hf_ckpt_path.glob("*token*"):
+        shutil.copyfile(file_path, save_path / file_path.name)


 if __name__ == "__main__":
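The behavioral core of this hunk is the expert-routing change: instead of looping over every model-parallel rank and skipping experts that fall outside that rank's range, the new code computes the owning shard directly as idx // n_local_experts. A minimal sketch of why the two are equivalent, using illustrative values (n_experts=8, mp=2) that are not taken from the commit:

    # Sketch only: the direct mapping picks the same shard as the old range test.
    n_experts, mp = 8, 2              # illustrative values, not from the commit
    n_local_experts = n_experts // mp

    for idx in range(n_experts):
        # Old logic: the only rank i whose window
        # [i * n_local_experts, (i + 1) * n_local_experts) contains idx keeps the expert.
        old_target = next(
            i for i in range(mp)
            if i * n_local_experts <= idx < (i + 1) * n_local_experts
        )
        # New logic: compute that rank directly.
        new_target = idx // n_local_experts
        assert old_target == new_target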
@@ -92,5 +102,7 @@ if __name__ == "__main__":
     parser.add_argument("--n-experts", type=int, required=True)
     parser.add_argument("--model-parallel", type=int, required=True)
     args = parser.parse_args()
+
     assert args.n_experts % args.model_parallel == 0, "Number of experts must be divisible by model parallelism"
+
     main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)
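The command-line interface is unchanged by this refactor, and the model-parallel degree must still divide the expert count. For reference, an invocation of the conversion script would look roughly like the following (paths and the parallel degree are placeholders; DeepSeek-V3 itself uses 256 routed experts):

    python convert.py --hf-ckpt-path /path/to/DeepSeek-V3 \
        --save-path /path/to/DeepSeek-V3-mp16 \
        --n-experts 256 \
        --model-parallel 16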