diff --git a/inference/convert.py b/inference/convert.py
index 1fc1ce0..1073ffb 100644
--- a/inference/convert.py
+++ b/inference/convert.py
@@ -70,10 +70,8 @@ def main(hf_ckpt_path, save_path, n_experts, mp):
         None
     """
     torch.set_num_threads(8)
-    n_local_experts = n_experts // mp
-    state_dicts = [{} for _ in range(mp)]
-    tensor_dir = glob(os.path.join(hf_ckpt_path, "*.safetensors"))
-    token_dir = glob(os.path.join(hf_ckpt_path, "*token*"))
+    n_local_experts, state_dicts = n_experts // mp, [{} for _ in range(mp)]
+    tensor_dir, token_dir = list(glob(os.path.join(hf_ckpt_path, "*.safetensors"))), list(glob(os.path.join(hf_ckpt_path, "*token*")))
     for file_path in tqdm(tensor_dir):
         cm = await sync.to_thread(safe_open, file_path, framework="pt", device="cpu")
         async with cm as f:
diff --git a/inference/fp8_cast_bf16.py b/inference/fp8_cast_bf16.py
index 4037342..3047edc 100644
--- a/inference/fp8_cast_bf16.py
+++ b/inference/fp8_cast_bf16.py
@@ -3,13 +3,42 @@ import json
 from argparse import ArgumentParser
 from glob import glob
 from tqdm import tqdm
-
+from asyncio import gather, to_thread, run
 import torch
 from safetensors.torch import load_file, save_file
 
 from kernel import weight_dequant
 
-def main(fp8_path, bf16_path):
+def inner_tensor_file(safetensor_file, bf16_path, loaded_files, fp8_weight_names, get_tensor):
+    file_name = os.path.basename(safetensor_file)
+    current_state_dict = load_file(safetensor_file, device="cuda")
+    loaded_files[file_name] = current_state_dict
+    new_state_dict = {}
+    for weight_name, weight in current_state_dict.items():
+        if weight_name.endswith("_scale_inv"):
+            continue
+        elif weight.element_size() == 1:  # FP8 weight
+            scale_inv_name = f"{weight_name}_scale_inv"
+            try:
+                # Get scale_inv from the correct file
+                scale_inv = get_tensor(scale_inv_name)
+                fp8_weight_names.append(weight_name)
+                new_state_dict[weight_name] = weight_dequant(weight, scale_inv)
+            except KeyError:
+                print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
+                new_state_dict[weight_name] = weight
+        else:
+            new_state_dict[weight_name] = weight
+    new_safetensor_file = os.path.join(bf16_path, file_name)
+    save_file(new_state_dict, new_safetensor_file)
+
+    # Memory management: keep only the 2 most recently used files
+    if len(loaded_files) > 2:
+        oldest_file = next(iter(loaded_files))
+        del loaded_files[oldest_file]
+        torch.cuda.empty_cache()
+
+async def main(fp8_path, bf16_path):
     """
     Converts FP8 weights to BF16 and saves the converted weights.
@@ -37,8 +66,7 @@ def main(fp8_path, bf16_path):
     weight_map = model_index["weight_map"]
 
     # Cache for loaded safetensor files
-    loaded_files = {}
-    fp8_weight_names = []
+    loaded_files, fp8_weight_names = {}, []
 
     # Helper function to get tensor from the correct file
     def get_tensor(tensor_name):
@@ -62,45 +90,15 @@ def main(fp8_path, bf16_path):
     safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
     safetensor_files.sort()
 
-    for safetensor_file in tqdm(safetensor_files):
-        file_name = os.path.basename(safetensor_file)
-        current_state_dict = load_file(safetensor_file, device="cuda")
-        loaded_files[file_name] = current_state_dict
-
-        new_state_dict = {}
-        for weight_name, weight in current_state_dict.items():
-            if weight_name.endswith("_scale_inv"):
-                continue
-            elif weight.element_size() == 1:  # FP8 weight
-                scale_inv_name = f"{weight_name}_scale_inv"
-                try:
-                    # Get scale_inv from the correct file
-                    scale_inv = get_tensor(scale_inv_name)
-                    fp8_weight_names.append(weight_name)
-                    new_state_dict[weight_name] = weight_dequant(weight, scale_inv)
-                except KeyError:
-                    print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
-                    new_state_dict[weight_name] = weight
-            else:
-                new_state_dict[weight_name] = weight
-
-        new_safetensor_file = os.path.join(bf16_path, file_name)
-        save_file(new_state_dict, new_safetensor_file)
-
-        # Memory management: keep only the 2 most recently used files
-        if len(loaded_files) > 2:
-            oldest_file = next(iter(loaded_files))
-            del loaded_files[oldest_file]
-            torch.cuda.empty_cache()
+    await gather(*(to_thread(inner_tensor_file, safetensor_file, bf16_path, loaded_files, fp8_weight_names, get_tensor) for safetensor_file in tqdm(safetensor_files)))
+
     # Update model index
     new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
     for weight_name in fp8_weight_names:
         scale_inv_name = f"{weight_name}_scale_inv"
-        if scale_inv_name in weight_map:
-            weight_map.pop(scale_inv_name)
-    with open(new_model_index_file, "w") as f:
-        json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)
+        if scale_inv_name in weight_map: weight_map.pop(scale_inv_name)
+    with open(new_model_index_file, "w") as f: json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)
 
 
 if __name__ == "__main__":
@@ -108,5 +106,5 @@ if __name__ == "__main__":
     parser.add_argument("--input-fp8-hf-path", type=str, required=True)
     parser.add_argument("--output-bf16-hf-path", type=str, required=True)
     args = parser.parse_args()
-    main(args.input_fp8_hf_path, args.output_bf16_hf_path)
+    run(main(args.input_fp8_hf_path, args.output_bf16_hf_path))
diff --git a/inference/generate.py b/inference/generate.py
index fbf3ab8..01f2cb3 100644
--- a/inference/generate.py
+++ b/inference/generate.py
@@ -1,13 +1,13 @@
 import os
 import json
+from parser import Parser
 from argparse import ArgumentParser
 from typing import List
-
 import torch
 import torch.distributed as dist
 from transformers import AutoTokenizer
 from safetensors.torch import load_model
-
+from asyncio import gather, to_thread, run
 from model import Transformer, ModelArgs
@@ -36,6 +36,7 @@ def generate(
     temperature: float = 1.0
 ) -> List[List[int]]:
     """
+    Generates new tokens based on the given prompt tokens using the specified model.
 
     Args:
@@ -47,38 +48,35 @@ def generate(
     Returns:
         List[List[int]]: A list of lists containing the generated tokens for each sequence.
+ """ prompt_lens = [len(t) for t in prompt_tokens] assert max(prompt_lens) <= model.max_seq_len total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens)) tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda") - for i, t in enumerate(prompt_tokens): - tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda") + for i, t in enumerate(prompt_tokens): tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda") prev_pos = 0 finished = torch.tensor([False] * len(prompt_tokens), device="cuda") prompt_mask = tokens != -1 - for cur_pos in range(min(prompt_lens), total_len): + def inner_cur_pos(): logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos) - if temperature > 0: - next_token = sample(logits, temperature) - else: - next_token = logits.argmax(dim=-1) + if temperature > 0: next_token = sample(logits, temperature) + else: next_token = logits.argmax(dim=-1) next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token) tokens[:, cur_pos] = next_token finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id) prev_pos = cur_pos - if finished.all(): - break + if finished.all(): return + gather(*(to_thread(cur_pos) for cur_pos in range(min(prompt_lens), total_len))) completion_tokens = [] for i, toks in enumerate(tokens.tolist()): toks = toks[prompt_lens[i]:prompt_lens[i]+max_new_tokens] - if eos_id in toks: - toks = toks[:toks.index(eos_id)] + if eos_id in toks: toks = toks[:toks.index(eos_id)] completion_tokens.append(toks) return completion_tokens -def main( +async def main( ckpt_path: str, config: str, input_file: str = "", @@ -131,8 +129,7 @@ def main( objects = [None] dist.broadcast_object_list(objects, 0) prompt = objects[0] - if prompt == "/exit": - break + if prompt == "/exit": break elif prompt == "/clear": messages.clear() continue @@ -143,8 +140,7 @@ def main( print(completion) messages.append({"role": "assistant", "content": completion}) else: - with open(input_file) as f: - prompts = [line.strip() for line in f.readlines()] + with open(input_file) as f: prompts = [line.strip() for line in f.readlines()] assert len(prompts) <= args.max_batch_size prompt_tokens = [tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True) for prompt in prompts] completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature) @@ -154,8 +150,7 @@ def main( print("Completion:", completion) print() - if world_size > 1: - dist.destroy_process_group() + if world_size > 1: dist.destroy_process_group() if __name__ == "__main__": @@ -173,13 +168,14 @@ if __name__ == "__main__": Raises: AssertionError: If neither input-file nor interactive mode is specified. 
""" - parser = ArgumentParser() - parser.add_argument("--ckpt-path", type=str, required=True) - parser.add_argument("--config", type=str, required=True) - parser.add_argument("--input-file", type=str, default="") - parser.add_argument("--interactive", action="store_true") - parser.add_argument("--max-new-tokens", type=int, default=200) - parser.add_argument("--temperature", type=float, default=0.2) - args = parser.parse_args() + arg_variables = [ + ("--ckpt-path", type:=str, required:=True), + ("--config", type:=str, required:=True), + ("--input-file", type:=str, default:=""), + ("--interactive", action:="store_true"), + ("--max-new-tokens", type:=int, default:=200), + ("--temperature", type:=float, default:=0.2) + ] + args = Parser(arg_list=arg_variables).apply_args().return_args() assert args.input_file or args.interactive - main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature) + run(main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)) diff --git a/inference/parser.py b/inference/parser.py new file mode 100644 index 0000000..847ce1a --- /dev/null +++ b/inference/parser.py @@ -0,0 +1,12 @@ +from argparse import ArgumentParser + +class Parser(): + def __init__(self, parser = ArgumentParser(), arg_list = []): + self.parser = parser + self.arg_list = arg_list + def apply_args(self): + for arg in self.arg_list: self.parser.add_argument(*arg) + return self + def return_args(self): + return self.parser.parse_args() + \ No newline at end of file