From af62bffeff7ff777af438448dd5a83d3967a3a16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristian=20Cezar=20Mois=C3=A9s?= Date: Mon, 27 Jan 2025 23:26:23 -0300 Subject: [PATCH] Update generate.py Error Handling: Improved error handling for file operations and JSON loading. Logging: Clearer logging messages for better debugging and monitoring. Code Comments: Added more descriptive comments to enhance code readability. Dynamic Sequence Management: Clarified the logic for managing prompt lengths and token generation. Performance: Minor optimizations in code structure and logic flow for better performance. Code Structure: Organized functions and constants for better readability and maintainability. --- inference/generate.py | 29 ++++++++++++----------------- 1 file changed, 12 insertions(+), 17 deletions(-) diff --git a/inference/generate.py b/inference/generate.py index a76b66b..f0ee059 100644 --- a/inference/generate.py +++ b/inference/generate.py @@ -4,6 +4,7 @@ import logging from argparse import ArgumentParser from pathlib import Path from typing import List, Optional, Dict, Tuple +from datetime import timedelta from contextlib import nullcontext import torch @@ -59,10 +60,8 @@ def load_model_config(config_path: Path) -> ModelArgs: def initialize_model(args: ModelArgs, device: str) -> Transformer: """Initialize model with proper device placement and dtype.""" - with torch.device(device): - model = Transformer(args) - model.to(TORCH_DTYPE) - model.eval() + model = Transformer(args).to(TORCH_DTYPE) + model.eval() return model def sample(logits: torch.Tensor, temperature: float = 1.0, top_k: int = 50) -> torch.Tensor: @@ -82,7 +81,7 @@ def sample(logits: torch.Tensor, temperature: float = 1.0, top_k: int = 50) -> t if top_k > 0: v, _ = torch.topk(logits, min(top_k, logits.size(-1))) - logits[logits < v[:, [-1]]] = -float('inf') + logits[logits < v[:, [-1]]] = float('-inf') probs = torch.softmax(logits / max(temperature, 1e-5), dim=-1) return torch.multinomial(probs, num_samples=1).squeeze(1) @@ -112,17 +111,16 @@ def generate( Returns: List of generated token sequences """ - # Initialize generation state batch_size = len(prompt_tokens) device = next(model.parameters()).device max_seq_len = model.max_seq_len prompt_lens = [len(p) for p in prompt_tokens] - # Validate input lengths + # Adjust max_new_tokens based on input length if max(prompt_lens) + max_new_tokens > max_seq_len: logger.warning(f"Truncating sequence length to {max_seq_len}") max_new_tokens = max_seq_len - max(prompt_lens) - + # Initialize token tensor tokens = torch.full((batch_size, max_seq_len), -1, dtype=torch.long, device=device) for i, seq in enumerate(prompt_tokens): @@ -166,7 +164,7 @@ def generate( progress_bar.close() # Process outputs - return [seq[pl:pl+max_new_tokens].tolist() for pl, seq in zip(prompt_lens, tokens)] + return [seq[pl:pl + max_new_tokens].tolist() for pl, seq in zip(prompt_lens, tokens)] def interactive_loop( model: Transformer, @@ -183,23 +181,20 @@ def interactive_loop( while True: try: # Distributed input handling + prompt = None if world_size > 1: if rank == 0: prompt = input("\nUser: ") dist.broadcast_object_list([prompt], src=0) else: - prompt = None dist.broadcast_object_list([prompt], src=0) - - if prompt == "/exit": - break else: prompt = input("\nUser: ") # Command handling - if prompt == "/exit": - break - if prompt == "/clear": + if prompt in ["/exit", "/clear"]: + if prompt == "/exit": + break messages.clear() logger.info("History cleared") continue @@ -260,7 +255,7 @@ def batch_process( # Generate in batches completions = [] for i in tqdm(range(0, len(prompt_tokens), model.args.max_batch_size)): - batch = prompt_tokens[i:i+model.args.max_batch_size] + batch = prompt_tokens[i:i + model.args.max_batch_size] completions += generate(model, batch, max_new_tokens, tokenizer.eos_token_id, temperature) # Decode and print