From 60e3466ebffd308d3e9ddab58aa3934d857d8d05 Mon Sep 17 00:00:00 2001
From: CodingParadigm1 <124928232+CodingParadigm1@users.noreply.github.com>
Date: Sun, 2 Feb 2025 03:04:15 -0700
Subject: [PATCH] Revert "optimized"

This reverts commit 3f5a2ebfc93005dafe7cf702e9de51b636ceb77a.
---
 inference/convert.py | 133 +++++++++++++++++--------------------------
 1 file changed, 51 insertions(+), 82 deletions(-)

diff --git a/inference/convert.py b/inference/convert.py
index 1b393d8..c606ce8 100644
--- a/inference/convert.py
+++ b/inference/convert.py
@@ -1,14 +1,12 @@
 import os
 import shutil
-import mmap
-import threading
 from argparse import ArgumentParser
-from pathlib import Path
+from glob import glob
 from tqdm import tqdm, trange
-from concurrent.futures import ThreadPoolExecutor, as_completed
+
 import torch
 from safetensors.torch import safe_open, save_file
-from collections import defaultdict
+
 
 mapping = {
     "embed_tokens": ("embed", 0),
@@ -31,89 +29,61 @@ mapping = {
     "scale": ("scale", None),
 }
 
-# Thread Lock for Safe Dictionary Access
-state_lock = threading.Lock()
-
-def fast_copy(src: Path, dst: Path):
-    """Efficiently copies large files using shutil for optimal memory usage"""
-    if dst.exists():
-        dst.unlink()  # Remove file if it already exists
-    if src.stat().st_size < 10 * 1024 * 1024:  # If file < 10MB, use shutil
-        shutil.copyfile(src, dst)
-    else:
-        with open(src, "rb") as f_src, open(dst, "wb") as f_dst:
-            shutil.copyfileobj(f_src, f_dst, length=16*1024*1024)
-
-def copy_token_file(file_path, save_path):
-    """Helper function for parallel copying of token files"""
-    fast_copy(file_path, Path(save_path) / file_path.name)
-
-def inner_safe_open(name: str, f, mp, state_dicts, n_local_experts):
-    """Processes tensor files and maps keys correctly"""
-    with torch.no_grad():
-        param: torch.Tensor = f.get_tensor(name)
-        name = name[len("model."):] if name.startswith("model.") else name
-        name = name.replace("self_attn", "attn").replace("mlp", "ffn")
-        name = name.replace("weight_scale_inv", "scale").replace("e_score_correction_bias", "bias")
-        key = name.split(".")[-2]
-        assert key in mapping
-        new_key, dim = mapping[key]
-        name = name.replace(key, new_key)
-
-        for i in range(mp):
-            new_param = param
-            if "experts" in name and "shared_experts" not in name:
-                idx = int(name.split(".")[-3])
-                if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
-                    continue
-            elif dim is not None:
-                shard_size = param.size(dim) // mp
-                new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
-
-            # Lock to avoid race conditions
-            with state_lock:
-                state_dicts[i][name] = new_param
-
-def process_file(file_path, mp, state_dicts, n_local_experts):
-    """Processes a single safetensor file"""
-    with safe_open(file_path, framework="pt", device="cpu") as f:
-        for name in f.keys():
-            if "model.layers.61" not in name:
-                inner_safe_open(name, f, mp, state_dicts, n_local_experts)
 
 def main(hf_ckpt_path, save_path, n_experts, mp):
-    """Converts and saves model checkpoint files into a specified format."""
+    """
+    Converts and saves model checkpoint files into a specified format.
+
+    Args:
+        hf_ckpt_path (str): Path to the directory containing the input checkpoint files.
+        save_path (str): Path to the directory where the converted checkpoint files will be saved.
+        n_experts (int): Total number of experts in the model.
+        mp (int): Model parallelism factor.
+
+    Returns:
+        None
+    """
+    torch.set_num_threads(8)
     n_local_experts = n_experts // mp
+    state_dicts = [{} for _ in range(mp)]
 
-    # Use defaultdict to prevent key errors in multi-threading
-    state_dicts = [defaultdict(dict) for _ in range(mp)]
-
-    file_list = list(Path(hf_ckpt_path).glob("*.safetensors"))
-    token_files = list(Path(hf_ckpt_path).glob("*token*"))
+    for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))):
+        with safe_open(file_path, framework="pt", device="cpu") as f:
+            for name in f.keys():
+                if "model.layers.61" in name:
+                    continue
+                param: torch.Tensor = f.get_tensor(name)
+                if name.startswith("model."):
+                    name = name[len("model."):]
+                name = name.replace("self_attn", "attn")
+                name = name.replace("mlp", "ffn")
+                name = name.replace("weight_scale_inv", "scale")
+                name = name.replace("e_score_correction_bias", "bias")
+                key = name.split(".")[-2]
+                assert key in mapping
+                new_key, dim = mapping[key]
+                name = name.replace(key, new_key)
+                for i in range(mp):
+                    new_param = param
+                    if "experts" in name and "shared_experts" not in name:
+                        idx = int(name.split(".")[-3])
+                        if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
+                            continue
+                    elif dim is not None:
+                        assert param.size(dim) % mp == 0
+                        shard_size = param.size(dim) // mp
+                        new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
+                    state_dicts[i][name] = new_param
 
-    Path(save_path).mkdir(parents=True, exist_ok=True)
+    os.makedirs(save_path, exist_ok=True)
 
-    # Parallel Processing with ThreadPoolExecutor
-    with ThreadPoolExecutor() as executor:
-        futures = {
-            executor.submit(process_file, file, mp, state_dicts, n_local_experts): file
-            for file in file_list
-        }
-        for future in tqdm(as_completed(futures), desc="Processing safetensors", total=len(file_list)):
-            future.result()  # Ensure exceptions are raised
-
-    # Save processed model shards
-    for i in trange(mp, desc="Saving model shards"):
+    for i in trange(mp):
         save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))
 
-    # Parallel Token File Copying
-    with ThreadPoolExecutor() as executor:
-        futures = {
-            executor.submit(copy_token_file, file, save_path): file
-            for file in token_files
-        }
-        for future in tqdm(as_completed(futures), desc="Copying token files", total=len(token_files)):
-            future.result()  # Ensure exceptions are raised
+    for file_path in glob(os.path.join(hf_ckpt_path, "*token*")):
+        new_file_path = os.path.join(save_path, os.path.basename(file_path))
+        shutil.copyfile(file_path, new_file_path)
+
 
 if __name__ == "__main__":
     parser = ArgumentParser()
@@ -122,6 +92,5 @@ if __name__ == "__main__":
     parser.add_argument("--n-experts", type=int, required=True)
    parser.add_argument("--model-parallel", type=int, required=True)
     args = parser.parse_args()
-
-    assert args.n_experts % args.model_parallel == 0, "n_experts must be divisible by model_parallel"
+    assert args.n_experts % args.model_parallel == 0
    main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)
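
Sharding sketch (illustrative, not part of the patch): the restored inner loop
distributes parameters across mp ranks in two ways. Expert tensors go whole to
the rank that owns their expert index; other mapped tensors are cut along dim
with torch.narrow. The shapes and values below are made up to show the rule in
isolation.

    import torch

    mp = 2                             # model-parallel factor
    n_experts = 8                      # total experts in the model
    n_local_experts = n_experts // mp  # experts stored on each rank

    # Non-expert tensor with a split dim: each rank takes one contiguous slice.
    param = torch.randn(16, 32)
    dim = 0
    assert param.size(dim) % mp == 0
    shard_size = param.size(dim) // mp
    shards = [param.narrow(dim, i * shard_size, shard_size).contiguous()
              for i in range(mp)]
    # Concatenating the shards along the split dim recovers the full tensor.
    assert torch.equal(torch.cat(shards, dim), param)

    # Expert tensor: expert idx is stored, unsplit, on exactly one rank.
    for idx in range(n_experts):
        rank = idx // n_local_experts
        assert rank * n_local_experts <= idx < (rank + 1) * n_local_experts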
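Usage note (not part of the patch): the revert keeps the script's CLI
unchanged, so a conversion run still looks like the sketch below. Paths and
counts are illustrative only; --n-experts must remain divisible by
--model-parallel, as the assert at the bottom of the script enforces.

    python inference/convert.py \
        --hf-ckpt-path /path/to/hf_ckpt \
        --save-path /path/to/converted_ckpt \
        --n-experts 256 \
        --model-parallel 8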