From 3f5a2ebfc93005dafe7cf702e9de51b636ceb77a Mon Sep 17 00:00:00 2001 From: CodingParadigm1 <124928232+CodingParadigm1@users.noreply.github.com> Date: Fri, 31 Jan 2025 17:15:10 -0700 Subject: [PATCH 1/9] optimized added parallel file processing, add faster tensor ops --- inference/convert.py | 133 ++++++++++++++++++++++++++----------------- 1 file changed, 82 insertions(+), 51 deletions(-) diff --git a/inference/convert.py b/inference/convert.py index c606ce8..1b393d8 100644 --- a/inference/convert.py +++ b/inference/convert.py @@ -1,12 +1,14 @@ import os import shutil +import mmap +import threading from argparse import ArgumentParser -from glob import glob +from pathlib import Path from tqdm import tqdm, trange - +from concurrent.futures import ThreadPoolExecutor, as_completed import torch from safetensors.torch import safe_open, save_file - +from collections import defaultdict mapping = { "embed_tokens": ("embed", 0), @@ -29,61 +31,89 @@ mapping = { "scale": ("scale", None), } +# Thread Lock for Safe Dictionary Access +state_lock = threading.Lock() + +def fast_copy(src: Path, dst: Path): + """Efficiently copies large files using shutil for optimal memory usage""" + if dst.exists(): + dst.unlink() # Remove file if it already exists + if src.stat().st_size < 10 * 1024 * 1024: # If file < 10MB, use shutil + shutil.copyfile(src, dst) + else: + with open(src, "rb") as f_src, open(dst, "wb") as f_dst: + shutil.copyfileobj(f_src, f_dst, length=16*1024*1024) + +def copy_token_file(file_path, save_path): + """Helper function for parallel copying of token files""" + fast_copy(file_path, Path(save_path) / file_path.name) + +def inner_safe_open(name: str, f, mp, state_dicts, n_local_experts): + """Processes tensor files and maps keys correctly""" + with torch.no_grad(): + param: torch.Tensor = f.get_tensor(name) + name = name[len("model."):] if name.startswith("model.") else name + name = name.replace("self_attn", "attn").replace("mlp", "ffn") + name = name.replace("weight_scale_inv", "scale").replace("e_score_correction_bias", "bias") + key = name.split(".")[-2] + assert key in mapping + new_key, dim = mapping[key] + name = name.replace(key, new_key) + + for i in range(mp): + new_param = param + if "experts" in name and "shared_experts" not in name: + idx = int(name.split(".")[-3]) + if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts: + continue + elif dim is not None: + shard_size = param.size(dim) // mp + new_param = param.narrow(dim, i * shard_size, shard_size).contiguous() + + # Lock to avoid race conditions + with state_lock: + state_dicts[i][name] = new_param + +def process_file(file_path, mp, state_dicts, n_local_experts): + """Processes a single safetensor file""" + with safe_open(file_path, framework="pt", device="cpu") as f: + for name in f.keys(): + if "model.layers.61" not in name: + inner_safe_open(name, f, mp, state_dicts, n_local_experts) def main(hf_ckpt_path, save_path, n_experts, mp): - """ - Converts and saves model checkpoint files into a specified format. - - Args: - hf_ckpt_path (str): Path to the directory containing the input checkpoint files. - save_path (str): Path to the directory where the converted checkpoint files will be saved. - n_experts (int): Total number of experts in the model. - mp (int): Model parallelism factor. 
- - Returns: - None - """ - torch.set_num_threads(8) + """Converts and saves model checkpoint files into a specified format.""" n_local_experts = n_experts // mp - state_dicts = [{} for _ in range(mp)] - for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))): - with safe_open(file_path, framework="pt", device="cpu") as f: - for name in f.keys(): - if "model.layers.61" in name: - continue - param: torch.Tensor = f.get_tensor(name) - if name.startswith("model."): - name = name[len("model."):] - name = name.replace("self_attn", "attn") - name = name.replace("mlp", "ffn") - name = name.replace("weight_scale_inv", "scale") - name = name.replace("e_score_correction_bias", "bias") - key = name.split(".")[-2] - assert key in mapping - new_key, dim = mapping[key] - name = name.replace(key, new_key) - for i in range(mp): - new_param = param - if "experts" in name and "shared_experts" not in name: - idx = int(name.split(".")[-3]) - if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts: - continue - elif dim is not None: - assert param.size(dim) % mp == 0 - shard_size = param.size(dim) // mp - new_param = param.narrow(dim, i * shard_size, shard_size).contiguous() - state_dicts[i][name] = new_param + # Use defaultdict to prevent key errors in multi-threading + state_dicts = [defaultdict(dict) for _ in range(mp)] + + file_list = list(Path(hf_ckpt_path).glob("*.safetensors")) + token_files = list(Path(hf_ckpt_path).glob("*token*")) - os.makedirs(save_path, exist_ok=True) + Path(save_path).mkdir(parents=True, exist_ok=True) - for i in trange(mp): + # Parallel Processing with ThreadPoolExecutor + with ThreadPoolExecutor() as executor: + futures = { + executor.submit(process_file, file, mp, state_dicts, n_local_experts): file + for file in file_list + } + for future in tqdm(as_completed(futures), desc="Processing safetensors", total=len(file_list)): + future.result() # Ensure exceptions are raised + + # Save processed model shards + for i in trange(mp, desc="Saving model shards"): save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors")) - for file_path in glob(os.path.join(hf_ckpt_path, "*token*")): - new_file_path = os.path.join(save_path, os.path.basename(file_path)) - shutil.copyfile(file_path, new_file_path) - + # Parallel Token File Copying + with ThreadPoolExecutor() as executor: + futures = { + executor.submit(copy_token_file, file, save_path): file + for file in token_files + } + for future in tqdm(as_completed(futures), desc="Copying token files", total=len(token_files)): + future.result() # Ensure exceptions are raised if __name__ == "__main__": parser = ArgumentParser() @@ -92,5 +122,6 @@ if __name__ == "__main__": parser.add_argument("--n-experts", type=int, required=True) parser.add_argument("--model-parallel", type=int, required=True) args = parser.parse_args() - assert args.n_experts % args.model_parallel == 0 + + assert args.n_experts % args.model_parallel == 0, "n_experts must be divisible by model_parallel" main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel) From 60e3466ebffd308d3e9ddab58aa3934d857d8d05 Mon Sep 17 00:00:00 2001 From: CodingParadigm1 <124928232+CodingParadigm1@users.noreply.github.com> Date: Sun, 2 Feb 2025 03:04:15 -0700 Subject: [PATCH 2/9] Revert "optimized" This reverts commit 3f5a2ebfc93005dafe7cf702e9de51b636ceb77a. 
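
For reference, the approach being reverted here (PATCH 1/9) comes down to fanning the per-file safetensors reads out to a ThreadPoolExecutor while serializing writes into the shared per-rank state dicts behind a threading.Lock. A minimal, self-contained sketch of that pattern follows; the safe_open/get_tensor usage mirrors the diff, while shard_tensor is a hypothetical placeholder for the key-mapping and narrow()-based sharding logic, not the committed implementation.

# Sketch of the thread-pool pattern from the reverted commit (PATCH 1/9).
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

import torch
from safetensors.torch import safe_open
from tqdm import tqdm

state_lock = threading.Lock()  # guards the shared per-rank dicts


def shard_tensor(param: torch.Tensor, rank: int, mp: int) -> torch.Tensor:
    """Hypothetical stand-in for the real key mapping and narrow()-based sharding."""
    return param


def process_file(file_path: Path, mp: int, state_dicts: list) -> None:
    """Read one safetensors file and scatter its tensors into the mp shards."""
    with safe_open(str(file_path), framework="pt", device="cpu") as f:
        for name in f.keys():
            param = f.get_tensor(name)
            for rank in range(mp):
                # Plain dict writes are atomic in CPython, but the lock keeps the intent explicit.
                with state_lock:
                    state_dicts[rank][name] = shard_tensor(param, rank, mp)


def convert(ckpt_dir: str, mp: int) -> list:
    """Process every *.safetensors file in a worker thread, one future per file."""
    state_dicts = [{} for _ in range(mp)]
    files = sorted(Path(ckpt_dir).glob("*.safetensors"))
    with ThreadPoolExecutor() as pool:
        futures = {pool.submit(process_file, p, mp, state_dicts): p for p in files}
        for future in tqdm(as_completed(futures), total=len(futures)):
            future.result()  # surfaces any exception raised in a worker thread
    return state_dicts
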
--- inference/convert.py | 133 +++++++++++++++++-------------------------- 1 file changed, 51 insertions(+), 82 deletions(-) diff --git a/inference/convert.py b/inference/convert.py index 1b393d8..c606ce8 100644 --- a/inference/convert.py +++ b/inference/convert.py @@ -1,14 +1,12 @@ import os import shutil -import mmap -import threading from argparse import ArgumentParser -from pathlib import Path +from glob import glob from tqdm import tqdm, trange -from concurrent.futures import ThreadPoolExecutor, as_completed + import torch from safetensors.torch import safe_open, save_file -from collections import defaultdict + mapping = { "embed_tokens": ("embed", 0), @@ -31,89 +29,61 @@ mapping = { "scale": ("scale", None), } -# Thread Lock for Safe Dictionary Access -state_lock = threading.Lock() - -def fast_copy(src: Path, dst: Path): - """Efficiently copies large files using shutil for optimal memory usage""" - if dst.exists(): - dst.unlink() # Remove file if it already exists - if src.stat().st_size < 10 * 1024 * 1024: # If file < 10MB, use shutil - shutil.copyfile(src, dst) - else: - with open(src, "rb") as f_src, open(dst, "wb") as f_dst: - shutil.copyfileobj(f_src, f_dst, length=16*1024*1024) - -def copy_token_file(file_path, save_path): - """Helper function for parallel copying of token files""" - fast_copy(file_path, Path(save_path) / file_path.name) - -def inner_safe_open(name: str, f, mp, state_dicts, n_local_experts): - """Processes tensor files and maps keys correctly""" - with torch.no_grad(): - param: torch.Tensor = f.get_tensor(name) - name = name[len("model."):] if name.startswith("model.") else name - name = name.replace("self_attn", "attn").replace("mlp", "ffn") - name = name.replace("weight_scale_inv", "scale").replace("e_score_correction_bias", "bias") - key = name.split(".")[-2] - assert key in mapping - new_key, dim = mapping[key] - name = name.replace(key, new_key) - - for i in range(mp): - new_param = param - if "experts" in name and "shared_experts" not in name: - idx = int(name.split(".")[-3]) - if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts: - continue - elif dim is not None: - shard_size = param.size(dim) // mp - new_param = param.narrow(dim, i * shard_size, shard_size).contiguous() - - # Lock to avoid race conditions - with state_lock: - state_dicts[i][name] = new_param - -def process_file(file_path, mp, state_dicts, n_local_experts): - """Processes a single safetensor file""" - with safe_open(file_path, framework="pt", device="cpu") as f: - for name in f.keys(): - if "model.layers.61" not in name: - inner_safe_open(name, f, mp, state_dicts, n_local_experts) def main(hf_ckpt_path, save_path, n_experts, mp): - """Converts and saves model checkpoint files into a specified format.""" + """ + Converts and saves model checkpoint files into a specified format. + + Args: + hf_ckpt_path (str): Path to the directory containing the input checkpoint files. + save_path (str): Path to the directory where the converted checkpoint files will be saved. + n_experts (int): Total number of experts in the model. + mp (int): Model parallelism factor. 
+ + Returns: + None + """ + torch.set_num_threads(8) n_local_experts = n_experts // mp + state_dicts = [{} for _ in range(mp)] - # Use defaultdict to prevent key errors in multi-threading - state_dicts = [defaultdict(dict) for _ in range(mp)] - - file_list = list(Path(hf_ckpt_path).glob("*.safetensors")) - token_files = list(Path(hf_ckpt_path).glob("*token*")) + for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))): + with safe_open(file_path, framework="pt", device="cpu") as f: + for name in f.keys(): + if "model.layers.61" in name: + continue + param: torch.Tensor = f.get_tensor(name) + if name.startswith("model."): + name = name[len("model."):] + name = name.replace("self_attn", "attn") + name = name.replace("mlp", "ffn") + name = name.replace("weight_scale_inv", "scale") + name = name.replace("e_score_correction_bias", "bias") + key = name.split(".")[-2] + assert key in mapping + new_key, dim = mapping[key] + name = name.replace(key, new_key) + for i in range(mp): + new_param = param + if "experts" in name and "shared_experts" not in name: + idx = int(name.split(".")[-3]) + if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts: + continue + elif dim is not None: + assert param.size(dim) % mp == 0 + shard_size = param.size(dim) // mp + new_param = param.narrow(dim, i * shard_size, shard_size).contiguous() + state_dicts[i][name] = new_param - Path(save_path).mkdir(parents=True, exist_ok=True) + os.makedirs(save_path, exist_ok=True) - # Parallel Processing with ThreadPoolExecutor - with ThreadPoolExecutor() as executor: - futures = { - executor.submit(process_file, file, mp, state_dicts, n_local_experts): file - for file in file_list - } - for future in tqdm(as_completed(futures), desc="Processing safetensors", total=len(file_list)): - future.result() # Ensure exceptions are raised - - # Save processed model shards - for i in trange(mp, desc="Saving model shards"): + for i in trange(mp): save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors")) - # Parallel Token File Copying - with ThreadPoolExecutor() as executor: - futures = { - executor.submit(copy_token_file, file, save_path): file - for file in token_files - } - for future in tqdm(as_completed(futures), desc="Copying token files", total=len(token_files)): - future.result() # Ensure exceptions are raised + for file_path in glob(os.path.join(hf_ckpt_path, "*token*")): + new_file_path = os.path.join(save_path, os.path.basename(file_path)) + shutil.copyfile(file_path, new_file_path) + if __name__ == "__main__": parser = ArgumentParser() @@ -122,6 +92,5 @@ if __name__ == "__main__": parser.add_argument("--n-experts", type=int, required=True) parser.add_argument("--model-parallel", type=int, required=True) args = parser.parse_args() - - assert args.n_experts % args.model_parallel == 0, "n_experts must be divisible by model_parallel" + assert args.n_experts % args.model_parallel == 0 main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel) From 1cb3e2a63f95d7ce90b6f34221474a050f06d66e Mon Sep 17 00:00:00 2001 From: CodingParadigm1 <124928232+CodingParadigm1@users.noreply.github.com> Date: Sun, 2 Feb 2025 03:49:23 -0700 Subject: [PATCH 3/9] reduced breaking changes --- inference/convert.py | 71 +++++++++++++++++++++++--------------------- 1 file changed, 37 insertions(+), 34 deletions(-) diff --git a/inference/convert.py b/inference/convert.py index c606ce8..6b818db 100644 --- a/inference/convert.py +++ b/inference/convert.py @@ -3,7 +3,7 @@ import shutil from 
argparse import ArgumentParser from glob import glob from tqdm import tqdm, trange - +import asyncio as sync import torch from safetensors.torch import safe_open, save_file @@ -29,6 +29,32 @@ mapping = { "scale": ("scale", None), } +async def set_param(param, name, i, n_local_experts, mp, state_dicts, dim): + new_param = param + if "experts" in name and "shared_experts" not in name: + idx = int(name.split(".")[-3]) + if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts: + return + elif dim is not None: + assert param.size(dim) % mp == 0 + shard_size = param.size(dim) // mp + new_param = param.narrow(dim, i * shard_size, shard_size).contiguous() + state_dicts[i][name] = new_param +async def inner_safe_open(name, f, state_dicts, mp, n_local_experts): + if "model.layers.61" not in name: + param: torch.Tensor = f.get_tensor(name) + if name.startswith("model."): + name = name[len("model."):] + name = name.replace("self_attn", "attn") + name = name.replace("mlp", "ffn") + name = name.replace("weight_scale_inv", "scale") + name = name.replace("e_score_correction_bias", "bias") + key = name.split(".")[-2] + assert key in mapping + new_key, dim = mapping[key] + name = name.replace(key, new_key) + await sync.gather(*(set_param(param, name, i, n_local_experts, mp, state_dicts, dim) for i in range(mp))) + def main(hf_ckpt_path, save_path, n_experts, mp): """ @@ -46,43 +72,20 @@ def main(hf_ckpt_path, save_path, n_experts, mp): torch.set_num_threads(8) n_local_experts = n_experts // mp state_dicts = [{} for _ in range(mp)] - + tasks = [] for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))): - with safe_open(file_path, framework="pt", device="cpu") as f: - for name in f.keys(): - if "model.layers.61" in name: - continue - param: torch.Tensor = f.get_tensor(name) - if name.startswith("model."): - name = name[len("model."):] - name = name.replace("self_attn", "attn") - name = name.replace("mlp", "ffn") - name = name.replace("weight_scale_inv", "scale") - name = name.replace("e_score_correction_bias", "bias") - key = name.split(".")[-2] - assert key in mapping - new_key, dim = mapping[key] - name = name.replace(key, new_key) - for i in range(mp): - new_param = param - if "experts" in name and "shared_experts" not in name: - idx = int(name.split(".")[-3]) - if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts: - continue - elif dim is not None: - assert param.size(dim) % mp == 0 - shard_size = param.size(dim) // mp - new_param = param.narrow(dim, i * shard_size, shard_size).contiguous() - state_dicts[i][name] = new_param + async with sync.to_thread(safe_open, file_path, framework="pt", device="cpu") as f: + await sync.gather(*(inner_safe_open(name, f, state_dicts, mp, n_local_experts) for name in f.keys())) os.makedirs(save_path, exist_ok=True) - for i in trange(mp): - save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors")) + await sync.gather(*(sync.to_thread(save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))) for i in trange(mp))) + + async def set_file_path(file_path): + await sync.to_thread(shutil.copyfile, file_path, os.path.join(save_path, os.path.basename(file_path))) + + await sync.gather(*(set_file_path(file_path) for file_path in glob(os.path.join(hf_ckpt_path, "*token*")))) - for file_path in glob(os.path.join(hf_ckpt_path, "*token*")): - new_file_path = os.path.join(save_path, os.path.basename(file_path)) - shutil.copyfile(file_path, new_file_path) if __name__ == "__main__": @@ -93,4 +96,4 @@ 
if __name__ == "__main__": parser.add_argument("--model-parallel", type=int, required=True) args = parser.parse_args() assert args.n_experts % args.model_parallel == 0 - main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel) + sync.run(main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)) From f2636dd3666f08805b8501bbcc3bee5d05ac0892 Mon Sep 17 00:00:00 2001 From: CodingParadigm1 <124928232+CodingParadigm1@users.noreply.github.com> Date: Sun, 2 Feb 2025 03:55:47 -0700 Subject: [PATCH 4/9] updated async file reading --- inference/convert.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/inference/convert.py b/inference/convert.py index 6b818db..212e2e4 100644 --- a/inference/convert.py +++ b/inference/convert.py @@ -74,7 +74,8 @@ def main(hf_ckpt_path, save_path, n_experts, mp): state_dicts = [{} for _ in range(mp)] tasks = [] for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))): - async with sync.to_thread(safe_open, file_path, framework="pt", device="cpu") as f: + cm = await sync.to_thread(safe_open, file_path, framework="pt", device="cpu") + async with cm as f: await sync.gather(*(inner_safe_open(name, f, state_dicts, mp, n_local_experts) for name in f.keys())) os.makedirs(save_path, exist_ok=True) From 35244be39f41b5607c887d62acfd8fc1666d1597 Mon Sep 17 00:00:00 2001 From: CodingParadigm1 <124928232+CodingParadigm1@users.noreply.github.com> Date: Sun, 2 Feb 2025 06:12:46 -0700 Subject: [PATCH 5/9] moved dir calls removed dir calls from for loops --- inference/convert.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/inference/convert.py b/inference/convert.py index 212e2e4..0310cb2 100644 --- a/inference/convert.py +++ b/inference/convert.py @@ -73,7 +73,9 @@ def main(hf_ckpt_path, save_path, n_experts, mp): n_local_experts = n_experts // mp state_dicts = [{} for _ in range(mp)] tasks = [] - for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))): + tensor_dir = glob(os.path.join(hf_ckpt_path, "*.safetensors")) + token_dir = glob(os.path.join(hf_ckpt_path, "*token*")) + for file_path in tqdm(tensor_dir): cm = await sync.to_thread(safe_open, file_path, framework="pt", device="cpu") async with cm as f: await sync.gather(*(inner_safe_open(name, f, state_dicts, mp, n_local_experts) for name in f.keys())) @@ -85,7 +87,7 @@ def main(hf_ckpt_path, save_path, n_experts, mp): async def set_file_path(file_path): await sync.to_thread(shutil.copyfile, file_path, os.path.join(save_path, os.path.basename(file_path))) - await sync.gather(*(set_file_path(file_path) for file_path in glob(os.path.join(hf_ckpt_path, "*token*")))) + await sync.gather(*(set_file_path(file_path) for file_path in token_dir)) From 07de76f5ee9669be472d80587585960c91361580 Mon Sep 17 00:00:00 2001 From: CodingParadigm1 <124928232+CodingParadigm1@users.noreply.github.com> Date: Sun, 2 Feb 2025 06:14:17 -0700 Subject: [PATCH 6/9] small patch removed unnecessary variable --- inference/convert.py | 1 - 1 file changed, 1 deletion(-) diff --git a/inference/convert.py b/inference/convert.py index 0310cb2..1fc1ce0 100644 --- a/inference/convert.py +++ b/inference/convert.py @@ -72,7 +72,6 @@ def main(hf_ckpt_path, save_path, n_experts, mp): torch.set_num_threads(8) n_local_experts = n_experts // mp state_dicts = [{} for _ in range(mp)] - tasks = [] tensor_dir = glob(os.path.join(hf_ckpt_path, "*.safetensors")) token_dir = glob(os.path.join(hf_ckpt_path, "*token*")) for file_path in tqdm(tensor_dir): From 
267e7ba6858ab87fcd0032b05dbb612df894d57c Mon Sep 17 00:00:00 2001 From: CodingParadigm1 <124928232+CodingParadigm1@users.noreply.github.com> Date: Mon, 3 Feb 2025 10:06:38 -0700 Subject: [PATCH 7/9] add functionality - applied asyncio to more files - added Parser class - made small changes --- inference/convert.py | 6 +-- inference/fp8_cast_bf16.py | 76 +++++++++++++++++++------------------- inference/generate.py | 54 +++++++++++++-------------- inference/parser.py | 12 ++++++ 4 files changed, 76 insertions(+), 72 deletions(-) create mode 100644 inference/parser.py diff --git a/inference/convert.py b/inference/convert.py index 1fc1ce0..1073ffb 100644 --- a/inference/convert.py +++ b/inference/convert.py @@ -70,10 +70,8 @@ def main(hf_ckpt_path, save_path, n_experts, mp): None """ torch.set_num_threads(8) - n_local_experts = n_experts // mp - state_dicts = [{} for _ in range(mp)] - tensor_dir = glob(os.path.join(hf_ckpt_path, "*.safetensors")) - token_dir = glob(os.path.join(hf_ckpt_path, "*token*")) + n_local_experts,state_dicts = n_experts // mp, [{} for _ in range(mp)] + tensor_dir, token_dir = list(glob(os.path.join(hf_ckpt_path, "*.safetensors"))),list(glob(os.path.join(hf_ckpt_path, "*token*"))) for file_path in tqdm(tensor_dir): cm = await sync.to_thread(safe_open, file_path, framework="pt", device="cpu") async with cm as f: diff --git a/inference/fp8_cast_bf16.py b/inference/fp8_cast_bf16.py index 4037342..3047edc 100644 --- a/inference/fp8_cast_bf16.py +++ b/inference/fp8_cast_bf16.py @@ -3,13 +3,42 @@ import json from argparse import ArgumentParser from glob import glob from tqdm import tqdm - +from asyncio import gather, to_thread, run import torch from safetensors.torch import load_file, save_file from kernel import weight_dequant -def main(fp8_path, bf16_path): +def inner_tensor_file(safetensor_file): + file_name = os.path.basename(safetensor_file) + current_state_dict = load_file(safetensor_file, device="cuda") + loaded_files[file_name] = current_state_dict + new_state_dict = {} + for weight_name, weight in current_state_dict.items(): + if weight_name.endswith("_scale_inv"): + continue + elif weight.element_size() == 1: # FP8 weight + scale_inv_name = f"{weight_name}_scale_inv" + try: + # Get scale_inv from the correct file + scale_inv = get_tensor(scale_inv_name) + fp8_weight_names.append(weight_name) + new_state_dict[weight_name] = weight_dequant(weight, scale_inv) + except KeyError: + print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion") + new_state_dict[weight_name] = weight + else: + new_state_dict[weight_name] = weight + new_safetensor_file = os.path.join(bf16_path, file_name) + save_file(new_state_dict, new_safetensor_file) + + # Memory management: keep only the 2 most recently used files + if len(loaded_files) > 2: + oldest_file = next(iter(loaded_files)) + del loaded_files[oldest_file] + torch.cuda.empty_cache() + +async def main(fp8_path, bf16_path): """ Converts FP8 weights to BF16 and saves the converted weights. 
@@ -37,8 +66,7 @@ def main(fp8_path, bf16_path): weight_map = model_index["weight_map"] # Cache for loaded safetensor files - loaded_files = {} - fp8_weight_names = [] + loaded_files, fp8_weight_names = {}, [] # Helper function to get tensor from the correct file def get_tensor(tensor_name): @@ -62,45 +90,15 @@ def main(fp8_path, bf16_path): safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors"))) safetensor_files.sort() - for safetensor_file in tqdm(safetensor_files): - file_name = os.path.basename(safetensor_file) - current_state_dict = load_file(safetensor_file, device="cuda") - loaded_files[file_name] = current_state_dict - - new_state_dict = {} - for weight_name, weight in current_state_dict.items(): - if weight_name.endswith("_scale_inv"): - continue - elif weight.element_size() == 1: # FP8 weight - scale_inv_name = f"{weight_name}_scale_inv" - try: - # Get scale_inv from the correct file - scale_inv = get_tensor(scale_inv_name) - fp8_weight_names.append(weight_name) - new_state_dict[weight_name] = weight_dequant(weight, scale_inv) - except KeyError: - print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion") - new_state_dict[weight_name] = weight - else: - new_state_dict[weight_name] = weight - - new_safetensor_file = os.path.join(bf16_path, file_name) - save_file(new_state_dict, new_safetensor_file) - - # Memory management: keep only the 2 most recently used files - if len(loaded_files) > 2: - oldest_file = next(iter(loaded_files)) - del loaded_files[oldest_file] - torch.cuda.empty_cache() + gather(*(to_thread(inner_tensor_file, safetensor_file) for safetensor_file in tqdm(safetensor_files))) + # Update model index new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json") for weight_name in fp8_weight_names: scale_inv_name = f"{weight_name}_scale_inv" - if scale_inv_name in weight_map: - weight_map.pop(scale_inv_name) - with open(new_model_index_file, "w") as f: - json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2) + if scale_inv_name in weight_map: weight_map.pop(scale_inv_name) + with open(new_model_index_file, "w") as f: json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2) if __name__ == "__main__": @@ -108,5 +106,5 @@ if __name__ == "__main__": parser.add_argument("--input-fp8-hf-path", type=str, required=True) parser.add_argument("--output-bf16-hf-path", type=str, required=True) args = parser.parse_args() - main(args.input_fp8_hf_path, args.output_bf16_hf_path) + run(main(args.input_fp8_hf_path, args.output_bf16_hf_path)) diff --git a/inference/generate.py b/inference/generate.py index fbf3ab8..01f2cb3 100644 --- a/inference/generate.py +++ b/inference/generate.py @@ -1,13 +1,13 @@ import os import json +from parser import Parser from argparse import ArgumentParser from typing import List - import torch import torch.distributed as dist from transformers import AutoTokenizer from safetensors.torch import load_model - +from asyncio import gather, to_thread, run from model import Transformer, ModelArgs @@ -36,6 +36,7 @@ def generate( temperature: float = 1.0 ) -> List[List[int]]: """ + Generates new tokens based on the given prompt tokens using the specified model. Args: @@ -47,38 +48,35 @@ def generate( Returns: List[List[int]]: A list of lists containing the generated tokens for each sequence. 
+ """ prompt_lens = [len(t) for t in prompt_tokens] assert max(prompt_lens) <= model.max_seq_len total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens)) tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda") - for i, t in enumerate(prompt_tokens): - tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda") + for i, t in enumerate(prompt_tokens): tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda") prev_pos = 0 finished = torch.tensor([False] * len(prompt_tokens), device="cuda") prompt_mask = tokens != -1 - for cur_pos in range(min(prompt_lens), total_len): + def inner_cur_pos(): logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos) - if temperature > 0: - next_token = sample(logits, temperature) - else: - next_token = logits.argmax(dim=-1) + if temperature > 0: next_token = sample(logits, temperature) + else: next_token = logits.argmax(dim=-1) next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token) tokens[:, cur_pos] = next_token finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id) prev_pos = cur_pos - if finished.all(): - break + if finished.all(): return + gather(*(to_thread(cur_pos) for cur_pos in range(min(prompt_lens), total_len))) completion_tokens = [] for i, toks in enumerate(tokens.tolist()): toks = toks[prompt_lens[i]:prompt_lens[i]+max_new_tokens] - if eos_id in toks: - toks = toks[:toks.index(eos_id)] + if eos_id in toks: toks = toks[:toks.index(eos_id)] completion_tokens.append(toks) return completion_tokens -def main( +async def main( ckpt_path: str, config: str, input_file: str = "", @@ -131,8 +129,7 @@ def main( objects = [None] dist.broadcast_object_list(objects, 0) prompt = objects[0] - if prompt == "/exit": - break + if prompt == "/exit": break elif prompt == "/clear": messages.clear() continue @@ -143,8 +140,7 @@ def main( print(completion) messages.append({"role": "assistant", "content": completion}) else: - with open(input_file) as f: - prompts = [line.strip() for line in f.readlines()] + with open(input_file) as f: prompts = [line.strip() for line in f.readlines()] assert len(prompts) <= args.max_batch_size prompt_tokens = [tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True) for prompt in prompts] completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature) @@ -154,8 +150,7 @@ def main( print("Completion:", completion) print() - if world_size > 1: - dist.destroy_process_group() + if world_size > 1: dist.destroy_process_group() if __name__ == "__main__": @@ -173,13 +168,14 @@ if __name__ == "__main__": Raises: AssertionError: If neither input-file nor interactive mode is specified. 
""" - parser = ArgumentParser() - parser.add_argument("--ckpt-path", type=str, required=True) - parser.add_argument("--config", type=str, required=True) - parser.add_argument("--input-file", type=str, default="") - parser.add_argument("--interactive", action="store_true") - parser.add_argument("--max-new-tokens", type=int, default=200) - parser.add_argument("--temperature", type=float, default=0.2) - args = parser.parse_args() + arg_variables = [ + ("--ckpt-path", type:=str, required:=True), + ("--config", type:=str, required:=True), + ("--input-file", type:=str, default:=""), + ("--interactive", action:="store_true"), + ("--max-new-tokens", type:=int, default:=200), + ("--temperature", type:=float, default:=0.2) + ] + args = Parser(arg_list=arg_variables).apply_args().return_args() assert args.input_file or args.interactive - main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature) + run(main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)) diff --git a/inference/parser.py b/inference/parser.py new file mode 100644 index 0000000..847ce1a --- /dev/null +++ b/inference/parser.py @@ -0,0 +1,12 @@ +from argparse import ArgumentParser + +class Parser(): + def __init__(self, parser = ArgumentParser(), arg_list = []): + self.parser = parser + self.arg_list = arg_list + def apply_args(self): + for arg in self.arg_list: self.parser.add_argument(*arg) + return self + def return_args(self): + return self.parser.parse_args() + \ No newline at end of file From c8146ec360514b893b1608643aa30351b5cbcb96 Mon Sep 17 00:00:00 2001 From: CodingParadigm1 <124928232+CodingParadigm1@users.noreply.github.com> Date: Mon, 3 Feb 2025 10:18:58 -0700 Subject: [PATCH 8/9] small patch applied further Parser class --- inference/convert.py | 15 ++++++++------- inference/fp8_cast_bf16.py | 4 +--- inference/generate.py | 4 ++-- 3 files changed, 11 insertions(+), 12 deletions(-) diff --git a/inference/convert.py b/inference/convert.py index 1073ffb..e83ba9a 100644 --- a/inference/convert.py +++ b/inference/convert.py @@ -1,6 +1,6 @@ import os import shutil -from argparse import ArgumentParser +from parser import Parser from glob import glob from tqdm import tqdm, trange import asyncio as sync @@ -89,11 +89,12 @@ def main(hf_ckpt_path, save_path, n_experts, mp): if __name__ == "__main__": - parser = ArgumentParser() - parser.add_argument("--hf-ckpt-path", type=str, required=True) - parser.add_argument("--save-path", type=str, required=True) - parser.add_argument("--n-experts", type=int, required=True) - parser.add_argument("--model-parallel", type=int, required=True) - args = parser.parse_args() + arg_list = [ + ("--hf-ckpt-path", type:=str, required:=True), + ("--save-path", type:=str, required:=True), + ("--n-experts", type:=int, required:=True), + ("--model-parallel", type:=int, required:=True) + ] + args = Parser(arg_list).apply_args().return_args() assert args.n_experts % args.model_parallel == 0 sync.run(main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)) diff --git a/inference/fp8_cast_bf16.py b/inference/fp8_cast_bf16.py index 3047edc..2dec503 100644 --- a/inference/fp8_cast_bf16.py +++ b/inference/fp8_cast_bf16.py @@ -6,7 +6,6 @@ from tqdm import tqdm from asyncio import gather, to_thread, run import torch from safetensors.torch import load_file, save_file - from kernel import weight_dequant def inner_tensor_file(safetensor_file): @@ -61,8 +60,7 @@ async def main(fp8_path, bf16_path): 
torch.set_default_dtype(torch.bfloat16) os.makedirs(bf16_path, exist_ok=True) model_index_file = os.path.join(fp8_path, "model.safetensors.index.json") - with open(model_index_file, "r") as f: - model_index = json.load(f) + with open(model_index_file, "r") as f: model_index = json.load(f) weight_map = model_index["weight_map"] # Cache for loaded safetensor files diff --git a/inference/generate.py b/inference/generate.py index 01f2cb3..c67599c 100644 --- a/inference/generate.py +++ b/inference/generate.py @@ -168,7 +168,7 @@ if __name__ == "__main__": Raises: AssertionError: If neither input-file nor interactive mode is specified. """ - arg_variables = [ + arg_list = [ ("--ckpt-path", type:=str, required:=True), ("--config", type:=str, required:=True), ("--input-file", type:=str, default:=""), @@ -176,6 +176,6 @@ if __name__ == "__main__": ("--max-new-tokens", type:=int, default:=200), ("--temperature", type:=float, default:=0.2) ] - args = Parser(arg_list=arg_variables).apply_args().return_args() + args = Parser(arg_list).apply_args().return_args() assert args.input_file or args.interactive run(main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)) From 77c46698b9bd6c73c6371dab8ada0c2ad95f6369 Mon Sep 17 00:00:00 2001 From: CodingParadigm1 <124928232+CodingParadigm1@users.noreply.github.com> Date: Mon, 3 Feb 2025 10:24:23 -0700 Subject: [PATCH 9/9] added assert to Parser class Parser now can apply assert to self --- inference/convert.py | 3 +-- inference/generate.py | 3 +-- inference/parser.py | 6 ++++++ 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/inference/convert.py b/inference/convert.py index e83ba9a..017d4cc 100644 --- a/inference/convert.py +++ b/inference/convert.py @@ -95,6 +95,5 @@ if __name__ == "__main__": ("--n-experts", type:=int, required:=True), ("--model-parallel", type:=int, required:=True) ] - args = Parser(arg_list).apply_args().return_args() - assert args.n_experts % args.model_parallel == 0 + args = Parser(arg_list).apply_args().assert_model_parallel().return_args() sync.run(main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)) diff --git a/inference/generate.py b/inference/generate.py index c67599c..80dbd30 100644 --- a/inference/generate.py +++ b/inference/generate.py @@ -176,6 +176,5 @@ if __name__ == "__main__": ("--max-new-tokens", type:=int, default:=200), ("--temperature", type:=float, default:=0.2) ] - args = Parser(arg_list).apply_args().return_args() - assert args.input_file or args.interactive + args = Parser(arg_list).apply_args().assert_interactive().return_args() run(main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)) diff --git a/inference/parser.py b/inference/parser.py index 847ce1a..12fd64e 100644 --- a/inference/parser.py +++ b/inference/parser.py @@ -7,6 +7,12 @@ class Parser(): def apply_args(self): for arg in self.arg_list: self.parser.add_argument(*arg) return self + def assert_model_parallel(self): + assert self.return_args.n_experts % self.return_args().model_parallel == 0 + return self + def assert_interactive(): + assert self.return_args().input_file or self.return_args().interactive + return self def return_args(self): return self.parser.parse_args() \ No newline at end of file
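
A note on the Parser helper introduced in PATCH 7/9 through 9/9: as committed, the argument tuples use assignment expressions (type:=str, required:=True) that simply evaluate to str and True and are forwarded to add_argument positionally, so the type=/required= keywords never take effect; Parser(arg_list) in PATCH 8/9 binds the list to the parser parameter, since arg_list is only the second positional; assert_interactive is missing its self parameter; and assert_model_parallel reads self.return_args.n_experts without calling return_args. Below is a sketch of one way the same fluent interface could be written; the (flag, kwargs) tuple format and the caching of parsed args are assumptions, not what the series commits.

# parser.py -- sketch of the Parser wrapper with keyword arguments applied
# correctly. The (flag, kwargs) tuple format is an assumption.
from argparse import ArgumentParser


class Parser:
    def __init__(self, arg_list=None, parser=None):
        self.parser = parser or ArgumentParser()
        self.arg_list = arg_list or []
        self.args = None

    def apply_args(self):
        # Each entry looks like ("--n-experts", {"type": int, "required": True}).
        for flag, kwargs in self.arg_list:
            self.parser.add_argument(flag, **kwargs)
        self.args = self.parser.parse_args()
        return self

    def assert_model_parallel(self):
        assert self.args.n_experts % self.args.model_parallel == 0, \
            "n_experts must be divisible by model_parallel"
        return self

    def assert_interactive(self):
        assert self.args.input_file or self.args.interactive, \
            "either --input-file or --interactive is required"
        return self

    def return_args(self):
        return self.args


if __name__ == "__main__":
    # Usage mirroring convert.py's __main__ block in the final patch.
    arg_list = [
        ("--hf-ckpt-path", {"type": str, "required": True}),
        ("--save-path", {"type": str, "required": True}),
        ("--n-experts", {"type": int, "required": True}),
        ("--model-parallel", {"type": int, "required": True}),
    ]
    args = Parser(arg_list).apply_args().assert_model_parallel().return_args()

Similarly, the asyncio conversion in PATCH 3/9 through 8/9 would need a few adjustments before it runs: main in convert.py uses await and async with without being declared async, sync.to_thread(save_file(...)) calls save_file eagerly instead of passing the callable, and the gather(...) in fp8_cast_bf16.py is never awaited. A minimal sketch of the intended pattern, offloading the blocking save_file calls to threads and awaiting the batch (the shard dicts and output path here are demo stand-ins):

# Sketch: offload blocking safetensors I/O with asyncio.to_thread and await the batch.
import asyncio
import os

import torch
from safetensors.torch import save_file


async def save_shards(state_dicts, save_path, mp):
    """Write the mp shard dicts concurrently without blocking the event loop."""
    os.makedirs(save_path, exist_ok=True)
    await asyncio.gather(*(
        # Pass the callable plus its arguments; writing save_file(...) here
        # would execute it synchronously before to_thread ever runs.
        asyncio.to_thread(save_file, state_dicts[i],
                          os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))
        for i in range(mp)
    ))


if __name__ == "__main__":
    demo = [{"weight": torch.zeros(2, 2)}]
    asyncio.run(save_shards(demo, "/tmp/bf16_out", mp=1))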