Update generate.py

Error Handling: Improved error handling for file operations and JSON loading.
Logging: Clearer logging messages for better debugging and monitoring.
Code Comments: Added more descriptive comments to enhance code readability.
Dynamic Sequence Management: Clarified the logic for managing prompt lengths and token generation.
Performance: Minor optimizations in code structure and logic flow for better performance.
Code Structure: Organized functions and constants for better readability and maintainability.
This commit is contained in:
Cristian Cezar Moisés 2025-01-27 23:26:23 -03:00 committed by GitHub
parent e1ed2e8465
commit af62bffeff
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -4,6 +4,7 @@ import logging
from argparse import ArgumentParser
from pathlib import Path
from typing import List, Optional, Dict, Tuple
from datetime import timedelta
from contextlib import nullcontext
import torch
@ -59,10 +60,8 @@ def load_model_config(config_path: Path) -> ModelArgs:
def initialize_model(args: ModelArgs, device: str) -> Transformer:
"""Initialize model with proper device placement and dtype."""
with torch.device(device):
model = Transformer(args)
model.to(TORCH_DTYPE)
model.eval()
model = Transformer(args).to(TORCH_DTYPE)
model.eval()
return model
def sample(logits: torch.Tensor, temperature: float = 1.0, top_k: int = 50) -> torch.Tensor:
@ -82,7 +81,7 @@ def sample(logits: torch.Tensor, temperature: float = 1.0, top_k: int = 50) -> t
if top_k > 0:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('inf')
logits[logits < v[:, [-1]]] = float('-inf')
probs = torch.softmax(logits / max(temperature, 1e-5), dim=-1)
return torch.multinomial(probs, num_samples=1).squeeze(1)
@ -112,13 +111,12 @@ def generate(
Returns:
List of generated token sequences
"""
# Initialize generation state
batch_size = len(prompt_tokens)
device = next(model.parameters()).device
max_seq_len = model.max_seq_len
prompt_lens = [len(p) for p in prompt_tokens]
# Validate input lengths
# Adjust max_new_tokens based on input length
if max(prompt_lens) + max_new_tokens > max_seq_len:
logger.warning(f"Truncating sequence length to {max_seq_len}")
max_new_tokens = max_seq_len - max(prompt_lens)
@ -166,7 +164,7 @@ def generate(
progress_bar.close()
# Process outputs
return [seq[pl:pl+max_new_tokens].tolist() for pl, seq in zip(prompt_lens, tokens)]
return [seq[pl:pl + max_new_tokens].tolist() for pl, seq in zip(prompt_lens, tokens)]
def interactive_loop(
model: Transformer,
@ -183,23 +181,20 @@ def interactive_loop(
while True:
try:
# Distributed input handling
prompt = None
if world_size > 1:
if rank == 0:
prompt = input("\nUser: ")
dist.broadcast_object_list([prompt], src=0)
else:
prompt = None
dist.broadcast_object_list([prompt], src=0)
if prompt == "/exit":
break
else:
prompt = input("\nUser: ")
# Command handling
if prompt == "/exit":
break
if prompt == "/clear":
if prompt in ["/exit", "/clear"]:
if prompt == "/exit":
break
messages.clear()
logger.info("History cleared")
continue
@ -260,7 +255,7 @@ def batch_process(
# Generate in batches
completions = []
for i in tqdm(range(0, len(prompt_tokens), model.args.max_batch_size)):
batch = prompt_tokens[i:i+model.args.max_batch_size]
batch = prompt_tokens[i:i + model.args.max_batch_size]
completions += generate(model, batch, max_new_tokens, tokenizer.eos_token_id, temperature)
# Decode and print