reduced breaking changes

This commit is contained in:
CodingParadigm1 2025-02-02 03:49:23 -07:00
parent 60e3466ebf
commit 1cb3e2a63f

View File

@ -3,7 +3,7 @@ import shutil
from argparse import ArgumentParser from argparse import ArgumentParser
from glob import glob from glob import glob
from tqdm import tqdm, trange from tqdm import tqdm, trange
import asyncio as sync
import torch import torch
from safetensors.torch import safe_open, save_file from safetensors.torch import safe_open, save_file
@ -29,6 +29,32 @@ mapping = {
"scale": ("scale", None), "scale": ("scale", None),
} }
async def set_param(param, name, i, n_local_experts, mp, state_dicts, dim):
new_param = param
if "experts" in name and "shared_experts" not in name:
idx = int(name.split(".")[-3])
if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
return
elif dim is not None:
assert param.size(dim) % mp == 0
shard_size = param.size(dim) // mp
new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
state_dicts[i][name] = new_param
async def inner_safe_open(name, f, state_dicts, mp, n_local_experts):
if "model.layers.61" not in name:
param: torch.Tensor = f.get_tensor(name)
if name.startswith("model."):
name = name[len("model."):]
name = name.replace("self_attn", "attn")
name = name.replace("mlp", "ffn")
name = name.replace("weight_scale_inv", "scale")
name = name.replace("e_score_correction_bias", "bias")
key = name.split(".")[-2]
assert key in mapping
new_key, dim = mapping[key]
name = name.replace(key, new_key)
await sync.gather(*(set_param(param, name, i, n_local_experts, mp, state_dicts, dim) for i in range(mp)))
def main(hf_ckpt_path, save_path, n_experts, mp): def main(hf_ckpt_path, save_path, n_experts, mp):
""" """
@ -46,43 +72,20 @@ def main(hf_ckpt_path, save_path, n_experts, mp):
torch.set_num_threads(8) torch.set_num_threads(8)
n_local_experts = n_experts // mp n_local_experts = n_experts // mp
state_dicts = [{} for _ in range(mp)] state_dicts = [{} for _ in range(mp)]
tasks = []
for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))): for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))):
with safe_open(file_path, framework="pt", device="cpu") as f: async with sync.to_thread(safe_open, file_path, framework="pt", device="cpu") as f:
for name in f.keys(): await sync.gather(*(inner_safe_open(name, f, state_dicts, mp, n_local_experts) for name in f.keys()))
if "model.layers.61" in name:
continue
param: torch.Tensor = f.get_tensor(name)
if name.startswith("model."):
name = name[len("model."):]
name = name.replace("self_attn", "attn")
name = name.replace("mlp", "ffn")
name = name.replace("weight_scale_inv", "scale")
name = name.replace("e_score_correction_bias", "bias")
key = name.split(".")[-2]
assert key in mapping
new_key, dim = mapping[key]
name = name.replace(key, new_key)
for i in range(mp):
new_param = param
if "experts" in name and "shared_experts" not in name:
idx = int(name.split(".")[-3])
if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
continue
elif dim is not None:
assert param.size(dim) % mp == 0
shard_size = param.size(dim) // mp
new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
state_dicts[i][name] = new_param
os.makedirs(save_path, exist_ok=True) os.makedirs(save_path, exist_ok=True)
for i in trange(mp): await sync.gather(*(sync.to_thread(save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))) for i in trange(mp)))
save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))
async def set_file_path(file_path):
await sync.to_thread(shutil.copyfile, file_path, os.path.join(save_path, os.path.basename(file_path)))
await sync.gather(*(set_file_path(file_path) for file_path in glob(os.path.join(hf_ckpt_path, "*token*"))))
for file_path in glob(os.path.join(hf_ckpt_path, "*token*")):
new_file_path = os.path.join(save_path, os.path.basename(file_path))
shutil.copyfile(file_path, new_file_path)
if __name__ == "__main__": if __name__ == "__main__":
@ -93,4 +96,4 @@ if __name__ == "__main__":
parser.add_argument("--model-parallel", type=int, required=True) parser.add_argument("--model-parallel", type=int, required=True)
args = parser.parse_args() args = parser.parse_args()
assert args.n_experts % args.model_parallel == 0 assert args.n_experts % args.model_parallel == 0
main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel) sync.run(main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel))