From e965eec9c045f7e61fa50590f8151f3dd3995367 Mon Sep 17 00:00:00 2001 From: Anand Date: Wed, 29 Jan 2025 22:32:59 +0530 Subject: [PATCH] =?UTF-8?q?=1B[200~feat:=20Added=20logging,=20parallel=20p?= =?UTF-8?q?rocessing,=20and=20CPU=20processing=20option=20for=20FP8=20to?= =?UTF-8?q?=20BF16=20conversion?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- inference/fp8_cast_bf16.py | 45 +++++++++---- inference/generate.py | 131 +++++++++++++++++++++++++------------ 2 files changed, 121 insertions(+), 55 deletions(-) diff --git a/inference/fp8_cast_bf16.py b/inference/fp8_cast_bf16.py index 4037342..181e999 100644 --- a/inference/fp8_cast_bf16.py +++ b/inference/fp8_cast_bf16.py @@ -3,32 +3,45 @@ import json from argparse import ArgumentParser from glob import glob from tqdm import tqdm +import logging +from concurrent.futures import ThreadPoolExecutor import torch from safetensors.torch import load_file, save_file from kernel import weight_dequant -def main(fp8_path, bf16_path): +def setup_logging(): + logging.basicConfig( + filename="conversion.log", + level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(message)s" + ) + +def main(fp8_path, bf16_path, use_cpu): """ Converts FP8 weights to BF16 and saves the converted weights. - + This function reads FP8 weights from the specified directory, converts them to BF16, and saves the converted weights to another specified directory. It also updates the model index file to reflect the changes. - + Args: fp8_path (str): The path to the directory containing the FP8 weights and model index file. bf16_path (str): The path to the directory where the converted BF16 weights will be saved. - + use_cpu (bool): Whether to use CPU instead of GPU. + Raises: KeyError: If a required scale_inv tensor is missing for a weight. - + Notes: - The function assumes that the FP8 weights are stored in safetensor files. - The function caches loaded safetensor files to optimize memory usage. - The function updates the model index file to remove references to scale_inv tensors. """ + setup_logging() + + device = "cpu" if use_cpu else "cuda" torch.set_default_dtype(torch.bfloat16) os.makedirs(bf16_path, exist_ok=True) model_index_file = os.path.join(fp8_path, "model.safetensors.index.json") @@ -39,7 +52,7 @@ def main(fp8_path, bf16_path): # Cache for loaded safetensor files loaded_files = {} fp8_weight_names = [] - + # Helper function to get tensor from the correct file def get_tensor(tensor_name): """ @@ -57,18 +70,19 @@ def main(fp8_path, bf16_path): file_name = weight_map[tensor_name] if file_name not in loaded_files: file_path = os.path.join(fp8_path, file_name) - loaded_files[file_name] = load_file(file_path, device="cuda") + loaded_files[file_name] = load_file(file_path, device=device) return loaded_files[file_name][tensor_name] - + safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors"))) safetensor_files.sort() - for safetensor_file in tqdm(safetensor_files): + + def process_file(safetensor_file): file_name = os.path.basename(safetensor_file) - current_state_dict = load_file(safetensor_file, device="cuda") + current_state_dict = load_file(safetensor_file, device=device) loaded_files[file_name] = current_state_dict new_state_dict = {} - for weight_name, weight in current_state_dict.items(): + for weight_name, weight in tqdm(current_state_dict.items(), desc=f"Processing {file_name}"): if weight_name.endswith("_scale_inv"): continue elif weight.element_size() == 1: # FP8 weight @@ -79,7 +93,7 @@ def main(fp8_path, bf16_path): fp8_weight_names.append(weight_name) new_state_dict[weight_name] = weight_dequant(weight, scale_inv) except KeyError: - print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion") + logging.warning(f"Missing scale_inv tensor for {weight_name}, skipping conversion") new_state_dict[weight_name] = weight else: new_state_dict[weight_name] = weight @@ -93,6 +107,9 @@ def main(fp8_path, bf16_path): del loaded_files[oldest_file] torch.cuda.empty_cache() + with ThreadPoolExecutor() as executor: + list(tqdm(executor.map(process_file, safetensor_files), total=len(safetensor_files), desc="Converting files")) + # Update model index new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json") for weight_name in fp8_weight_names: @@ -102,11 +119,13 @@ def main(fp8_path, bf16_path): with open(new_model_index_file, "w") as f: json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2) + logging.info("Conversion completed successfully.") if __name__ == "__main__": parser = ArgumentParser() parser.add_argument("--input-fp8-hf-path", type=str, required=True) parser.add_argument("--output-bf16-hf-path", type=str, required=True) + parser.add_argument("--use-cpu", action="store_true", help="Use CPU for processing instead of GPU") args = parser.parse_args() - main(args.input_fp8_hf_path, args.output_bf16_hf_path) + main(args.input_fp8_hf_path, args.output_bf16_hf_path, args.use_cpu) diff --git a/inference/generate.py b/inference/generate.py index fbf3ab8..bb90b72 100644 --- a/inference/generate.py +++ b/inference/generate.py @@ -1,5 +1,6 @@ import os import json +import logging from argparse import ArgumentParser from typing import List @@ -10,6 +11,9 @@ from safetensors.torch import load_model from model import Transformer, ModelArgs +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) def sample(logits, temperature: float = 1.0): """ @@ -49,32 +53,37 @@ def generate( List[List[int]]: A list of lists containing the generated tokens for each sequence. """ prompt_lens = [len(t) for t in prompt_tokens] - assert max(prompt_lens) <= model.max_seq_len + assert max(prompt_lens) <= model.max_seq_len, "Prompt length exceeds model max sequence length" + total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens)) tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda") + for i, t in enumerate(prompt_tokens): tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda") + prev_pos = 0 finished = torch.tensor([False] * len(prompt_tokens), device="cuda") prompt_mask = tokens != -1 + for cur_pos in range(min(prompt_lens), total_len): logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos) - if temperature > 0: - next_token = sample(logits, temperature) - else: - next_token = logits.argmax(dim=-1) + next_token = sample(logits, temperature) if temperature > 0 else logits.argmax(dim=-1) + next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token) tokens[:, cur_pos] = next_token finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id) + prev_pos = cur_pos if finished.all(): break + completion_tokens = [] for i, toks in enumerate(tokens.tolist()): - toks = toks[prompt_lens[i]:prompt_lens[i]+max_new_tokens] + toks = toks[prompt_lens[i]:prompt_lens[i] + max_new_tokens] if eos_id in toks: toks = toks[:toks.index(eos_id)] completion_tokens.append(toks) + return completion_tokens @@ -100,60 +109,96 @@ def main( world_size = int(os.getenv("WORLD_SIZE", "1")) rank = int(os.getenv("RANK", "0")) local_rank = int(os.getenv("LOCAL_RANK", "0")) + if world_size > 1: dist.init_process_group("nccl") - global print + if rank != 0: - print = lambda *_, **__: None + logger.setLevel(logging.WARNING) + torch.cuda.set_device(local_rank) torch.set_default_dtype(torch.bfloat16) torch.set_num_threads(8) torch.manual_seed(965) - with open(config) as f: - args = ModelArgs(**json.load(f)) - print(args) + + # Load model args + try: + with open(config) as f: + args = ModelArgs(**json.load(f)) + except FileNotFoundError as e: + logger.error(f"Config file not found: {e}") + return + except json.JSONDecodeError as e: + logger.error(f"Error parsing config file: {e}") + return + + logger.info(f"Model args: {args}") + + # Load the model on GPU with torch.device("cuda"): model = Transformer(args) + tokenizer = AutoTokenizer.from_pretrained(ckpt_path) - tokenizer.decode(generate(model, [tokenizer.encode("DeepSeek")], 2, -1, 1.)[0]) - load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors")) - + + # Generate a test sequence to verify everything is working + test_prompt = "DeepSeek" + test_tokens = tokenizer.encode(test_prompt) + generated_tokens = generate(model, [test_tokens], 2, tokenizer.eos_token_id, 1.0) + logger.info(f"Generated test output: {tokenizer.decode(generated_tokens[0])}") + + # Load model weights + try: + load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors")) + except Exception as e: + logger.error(f"Error loading model: {e}") + return + + # Interactive mode or batch processing if interactive: messages = [] while True: - if world_size == 1: + if world_size == 1 or rank == 0: prompt = input(">>> ") - elif rank == 0: - prompt = input(">>> ") - objects = [prompt] - dist.broadcast_object_list(objects, 0) - else: + if prompt == "/exit": + break + elif prompt == "/clear": + messages.clear() + continue + messages.append({"role": "user", "content": prompt}) + prompt_tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True) + completion_tokens = generate(model, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature) + completion = tokenizer.decode(completion_tokens[0], skip_special_tokens=True) + logger.info(f"Generated completion: {completion}") + messages.append({"role": "assistant", "content": completion}) + elif rank != 0: + # Synchronize input across multiple nodes objects = [None] dist.broadcast_object_list(objects, 0) prompt = objects[0] - if prompt == "/exit": - break - elif prompt == "/clear": - messages.clear() - continue - messages.append({"role": "user", "content": prompt}) - prompt_tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True) - completion_tokens = generate(model, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature) - completion = tokenizer.decode(completion_tokens[0], skip_special_tokens=True) - print(completion) - messages.append({"role": "assistant", "content": completion}) + else: - with open(input_file) as f: - prompts = [line.strip() for line in f.readlines()] - assert len(prompts) <= args.max_batch_size + # Batch processing mode + if not input_file: + logger.error("Input file is required for batch processing mode") + return + try: + with open(input_file) as f: + prompts = [line.strip() for line in f.readlines()] + except FileNotFoundError as e: + logger.error(f"Input file not found: {e}") + return + + assert len(prompts) <= args.max_batch_size, "Exceeds batch size limit" + prompt_tokens = [tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True) for prompt in prompts] completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature) completions = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True) + for prompt, completion in zip(prompts, completions): print("Prompt:", prompt) print("Completion:", completion) print() - + if world_size > 1: dist.destroy_process_group() @@ -174,12 +219,14 @@ if __name__ == "__main__": AssertionError: If neither input-file nor interactive mode is specified. """ parser = ArgumentParser() - parser.add_argument("--ckpt-path", type=str, required=True) - parser.add_argument("--config", type=str, required=True) - parser.add_argument("--input-file", type=str, default="") - parser.add_argument("--interactive", action="store_true") - parser.add_argument("--max-new-tokens", type=int, default=200) - parser.add_argument("--temperature", type=float, default=0.2) + parser.add_argument("--ckpt-path", type=str, required=True, help="Path to the model checkpoint directory.") + parser.add_argument("--config", type=str, required=True, help="Path to the model configuration file.") + parser.add_argument("--input-file", type=str, default="", help="File containing prompts for batch processing.") + parser.add_argument("--interactive", action="store_true", help="Enable interactive mode.") + parser.add_argument("--max-new-tokens", type=int, default=200, help="Maximum number of new tokens to generate.") + parser.add_argument("--temperature", type=float, default=0.2, help="Temperature for sampling.") + args = parser.parse_args() - assert args.input_file or args.interactive + assert args.input_file or args.interactive, "Either input-file or interactive mode must be specified" + main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)