Merge pull request #2 from wowrakibul/fix/convert-py-improvements

Improve convert.py with error handling and code optimization
This commit is contained in:
Wow Rakibul 2025-02-09 02:02:17 +06:00 committed by GitHub
commit 361d0bcc1c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -3,7 +3,6 @@ import shutil
from argparse import ArgumentParser from argparse import ArgumentParser
from glob import glob from glob import glob
from tqdm import tqdm, trange from tqdm import tqdm, trange
import torch import torch
from safetensors.torch import safe_open, save_file from safetensors.torch import safe_open, save_file
@ -30,7 +29,7 @@ mapping = {
} }
def main(hf_ckpt_path, save_path, n_experts, mp): def main(hf_ckpt_path: str, save_path: str, n_experts: int, mp: int) -> None:
""" """
Converts and saves model checkpoint files into a specified format. Converts and saves model checkpoint files into a specified format.
@ -43,46 +42,50 @@ def main(hf_ckpt_path, save_path, n_experts, mp):
Returns: Returns:
None None
""" """
torch.set_num_threads(8) try:
n_local_experts = n_experts // mp torch.set_num_threads(8)
state_dicts = [{} for _ in range(mp)] n_local_experts = n_experts // mp
state_dicts = [{} for _ in range(mp)]
for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))): for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))):
with safe_open(file_path, framework="pt", device="cpu") as f: with safe_open(file_path, framework="pt", device="cpu") as f:
for name in f.keys(): for name in f.keys():
if "model.layers.61" in name: if "model.layers.61" in name:
continue continue
param: torch.Tensor = f.get_tensor(name) param: torch.Tensor = f.get_tensor(name)
if name.startswith("model."): if name.startswith("model."):
name = name[len("model."):] name = name[len("model."):]
name = name.replace("self_attn", "attn") name = name.replace("self_attn", "attn")
name = name.replace("mlp", "ffn") name = name.replace("mlp", "ffn")
name = name.replace("weight_scale_inv", "scale") name = name.replace("weight_scale_inv", "scale")
name = name.replace("e_score_correction_bias", "bias") name = name.replace("e_score_correction_bias", "bias")
key = name.split(".")[-2] key = name.split(".")[-2]
assert key in mapping, f"Key {key} not found in mapping" assert key in mapping, f"Key {key} not found in mapping"
new_key, dim = mapping[key] new_key, dim = mapping[key]
name = name.replace(key, new_key) name = name.replace(key, new_key)
for i in range(mp): for i in range(mp):
new_param = param new_param = param
if "experts" in name and "shared_experts" not in name: if "experts" in name and "shared_experts" not in name:
idx = int(name.split(".")[-3]) idx = int(name.split(".")[-3])
if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts: if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
continue continue
elif dim is not None: elif dim is not None:
assert param.size(dim) % mp == 0, f"Dimension {dim} must be divisible by {mp}" assert param.size(dim) % mp == 0, f"Dimension {dim} must be divisible by {mp}"
shard_size = param.size(dim) // mp shard_size = param.size(dim) // mp
new_param = param.narrow(dim, i * shard_size, shard_size).contiguous() new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
state_dicts[i][name] = new_param state_dicts[i][name] = new_param
os.makedirs(save_path, exist_ok=True) os.makedirs(save_path, exist_ok=True)
for i in trange(mp): for i in trange(mp):
save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors")) save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))
for file_path in glob(os.path.join(hf_ckpt_path, "*token*")): for file_path in glob(os.path.join(hf_ckpt_path, "*token*")):
new_file_path = os.path.join(save_path, os.path.basename(file_path)) new_file_path = os.path.join(save_path, os.path.basename(file_path))
shutil.copyfile(file_path, new_file_path) shutil.copyfile(file_path, new_file_path)
except Exception as e:
print(f"An error occurred: {e}")
if __name__ == "__main__": if __name__ == "__main__":