Revert "optimized"

This reverts commit 3f5a2ebfc9.
CodingParadigm1 committed 2025-02-02 03:04:15 -07:00
parent 3f5a2ebfc9
commit 60e3466ebf


@@ -1,14 +1,12 @@
 import os
 import shutil
-import mmap
-import threading
 from argparse import ArgumentParser
-from pathlib import Path
+from glob import glob
 from tqdm import tqdm, trange
-from concurrent.futures import ThreadPoolExecutor, as_completed
 import torch
 from safetensors.torch import safe_open, save_file
-from collections import defaultdict
 
 mapping = {
     "embed_tokens": ("embed", 0),
@@ -31,89 +29,61 @@ mapping = {
     "scale": ("scale", None),
 }
 
-# Thread Lock for Safe Dictionary Access
-state_lock = threading.Lock()
-
-def fast_copy(src: Path, dst: Path):
-    """Efficiently copies large files using shutil for optimal memory usage"""
-    if dst.exists():
-        dst.unlink()  # Remove file if it already exists
-    if src.stat().st_size < 10 * 1024 * 1024:  # If file < 10MB, use shutil
-        shutil.copyfile(src, dst)
-    else:
-        with open(src, "rb") as f_src, open(dst, "wb") as f_dst:
-            shutil.copyfileobj(f_src, f_dst, length=16 * 1024 * 1024)
-
-def copy_token_file(file_path, save_path):
-    """Helper function for parallel copying of token files"""
-    fast_copy(file_path, Path(save_path) / file_path.name)
-
-def inner_safe_open(name: str, f, mp, state_dicts, n_local_experts):
-    """Processes tensor files and maps keys correctly"""
-    with torch.no_grad():
-        param: torch.Tensor = f.get_tensor(name)
-        name = name[len("model."):] if name.startswith("model.") else name
-        name = name.replace("self_attn", "attn").replace("mlp", "ffn")
-        name = name.replace("weight_scale_inv", "scale").replace("e_score_correction_bias", "bias")
-        key = name.split(".")[-2]
-        assert key in mapping
-        new_key, dim = mapping[key]
-        name = name.replace(key, new_key)
-        for i in range(mp):
-            new_param = param
-            if "experts" in name and "shared_experts" not in name:
-                idx = int(name.split(".")[-3])
-                if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
-                    continue
-            elif dim is not None:
-                shard_size = param.size(dim) // mp
-                new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
-            # Lock to avoid race conditions
-            with state_lock:
-                state_dicts[i][name] = new_param
-
-def process_file(file_path, mp, state_dicts, n_local_experts):
-    """Processes a single safetensor file"""
-    with safe_open(file_path, framework="pt", device="cpu") as f:
-        for name in f.keys():
-            if "model.layers.61" not in name:
-                inner_safe_open(name, f, mp, state_dicts, n_local_experts)
-
 def main(hf_ckpt_path, save_path, n_experts, mp):
-    """Converts and saves model checkpoint files into a specified format."""
+    """
+    Converts and saves model checkpoint files into a specified format.
+
+    Args:
+        hf_ckpt_path (str): Path to the directory containing the input checkpoint files.
+        save_path (str): Path to the directory where the converted checkpoint files will be saved.
+        n_experts (int): Total number of experts in the model.
+        mp (int): Model parallelism factor.
+
+    Returns:
+        None
+    """
+    torch.set_num_threads(8)
     n_local_experts = n_experts // mp
-    # Use defaultdict to prevent key errors in multi-threading
-    state_dicts = [defaultdict(dict) for _ in range(mp)]
-
-    file_list = list(Path(hf_ckpt_path).glob("*.safetensors"))
-    token_files = list(Path(hf_ckpt_path).glob("*token*"))
+    state_dicts = [{} for _ in range(mp)]
+
+    for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))):
+        with safe_open(file_path, framework="pt", device="cpu") as f:
+            for name in f.keys():
+                if "model.layers.61" in name:
+                    continue
+                param: torch.Tensor = f.get_tensor(name)
+                if name.startswith("model."):
+                    name = name[len("model."):]
+                name = name.replace("self_attn", "attn")
+                name = name.replace("mlp", "ffn")
+                name = name.replace("weight_scale_inv", "scale")
+                name = name.replace("e_score_correction_bias", "bias")
+                key = name.split(".")[-2]
+                assert key in mapping
+                new_key, dim = mapping[key]
+                name = name.replace(key, new_key)
+                for i in range(mp):
+                    new_param = param
+                    if "experts" in name and "shared_experts" not in name:
+                        idx = int(name.split(".")[-3])
+                        if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
+                            continue
+                    elif dim is not None:
+                        assert param.size(dim) % mp == 0
+                        shard_size = param.size(dim) // mp
+                        new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
+                    state_dicts[i][name] = new_param
 
-    Path(save_path).mkdir(parents=True, exist_ok=True)
-
-    # Parallel Processing with ThreadPoolExecutor
-    with ThreadPoolExecutor() as executor:
-        futures = {
-            executor.submit(process_file, file, mp, state_dicts, n_local_experts): file
-            for file in file_list
-        }
-        for future in tqdm(as_completed(futures), desc="Processing safetensors", total=len(file_list)):
-            future.result()  # Ensure exceptions are raised
-
-    # Save processed model shards
-    for i in trange(mp, desc="Saving model shards"):
+    os.makedirs(save_path, exist_ok=True)
+
+    for i in trange(mp):
         save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))
 
-    # Parallel Token File Copying
-    with ThreadPoolExecutor() as executor:
-        futures = {
-            executor.submit(copy_token_file, file, save_path): file
-            for file in token_files
-        }
-        for future in tqdm(as_completed(futures), desc="Copying token files", total=len(token_files)):
-            future.result()  # Ensure exceptions are raised
+    for file_path in glob(os.path.join(hf_ckpt_path, "*token*")):
+        new_file_path = os.path.join(save_path, os.path.basename(file_path))
+        shutil.copyfile(file_path, new_file_path)
 
 if __name__ == "__main__":
     parser = ArgumentParser()
@@ -122,6 +92,5 @@ if __name__ == "__main__":
     parser.add_argument("--n-experts", type=int, required=True)
    parser.add_argument("--model-parallel", type=int, required=True)
     args = parser.parse_args()
-    assert args.n_experts % args.model_parallel == 0, "n_experts must be divisible by model_parallel"
+    assert args.n_experts % args.model_parallel == 0
     main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)
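
For reference, the loop restored above implements a two-way sharding rule: expert weights are stored whole in the state dict of the one rank that owns their expert index, while any other tensor whose mapping entry carries a split dimension is sliced evenly across ranks with narrow. Below is a minimal standalone sketch of that rule; the helper names (owner, shard), the expert count, and the tensor shape are illustrative assumptions, not code from this repository.

import torch

mp = 2                             # model-parallel ranks (illustrative)
n_experts = 8                      # hypothetical total expert count
n_local_experts = n_experts // mp  # experts stored per rank

def owner(expert_idx: int) -> int:
    """Rank whose state dict receives a given expert's weights."""
    return expert_idx // n_local_experts

def shard(param: torch.Tensor, dim: int | None, rank: int) -> torch.Tensor:
    """Slice param along dim for rank; dim=None means the tensor is replicated."""
    if dim is None:
        return param
    assert param.size(dim) % mp == 0
    shard_size = param.size(dim) // mp
    return param.narrow(dim, rank * shard_size, shard_size).contiguous()

w = torch.randn(16, 4)             # e.g. a projection weight split along dim 0
print(shard(w, 0, rank=0).shape)   # torch.Size([8, 4])
print(shard(w, 0, rank=1).shape)   # torch.Size([8, 4])
print(owner(5))                    # expert 5 lands on rank 1 (5 // 4 == 1)

Because the restored loop is single-threaded, its unsynchronized state_dicts[i][name] = new_param writes are safe; the reverted version needed state_lock only because several ThreadPoolExecutor workers mutated the same dictionaries concurrently.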