Update generate.py

Error Handling: Improved error handling for file operations and JSON loading.
Logging: Clearer logging messages for better debugging and monitoring.
Code Comments: Added more descriptive comments to enhance code readability.
Dynamic Sequence Management: Clarified the logic for managing prompt lengths and token generation.
Performance: Minor optimizations in code structure and logic flow for better performance.
Code Structure: Organized functions and constants for better readability and maintainability.
This commit is contained in:
Cristian Cezar Moisés 2025-01-27 23:26:23 -03:00 committed by GitHub
parent e1ed2e8465
commit af62bffeff
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -4,6 +4,7 @@ import logging
from argparse import ArgumentParser from argparse import ArgumentParser
from pathlib import Path from pathlib import Path
from typing import List, Optional, Dict, Tuple from typing import List, Optional, Dict, Tuple
from datetime import timedelta
from contextlib import nullcontext from contextlib import nullcontext
import torch import torch
@ -59,10 +60,8 @@ def load_model_config(config_path: Path) -> ModelArgs:
def initialize_model(args: ModelArgs, device: str) -> Transformer: def initialize_model(args: ModelArgs, device: str) -> Transformer:
"""Initialize model with proper device placement and dtype.""" """Initialize model with proper device placement and dtype."""
with torch.device(device): model = Transformer(args).to(TORCH_DTYPE)
model = Transformer(args) model.eval()
model.to(TORCH_DTYPE)
model.eval()
return model return model
def sample(logits: torch.Tensor, temperature: float = 1.0, top_k: int = 50) -> torch.Tensor: def sample(logits: torch.Tensor, temperature: float = 1.0, top_k: int = 50) -> torch.Tensor:
@ -82,7 +81,7 @@ def sample(logits: torch.Tensor, temperature: float = 1.0, top_k: int = 50) -> t
if top_k > 0: if top_k > 0:
v, _ = torch.topk(logits, min(top_k, logits.size(-1))) v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('inf') logits[logits < v[:, [-1]]] = float('-inf')
probs = torch.softmax(logits / max(temperature, 1e-5), dim=-1) probs = torch.softmax(logits / max(temperature, 1e-5), dim=-1)
return torch.multinomial(probs, num_samples=1).squeeze(1) return torch.multinomial(probs, num_samples=1).squeeze(1)
@ -112,13 +111,12 @@ def generate(
Returns: Returns:
List of generated token sequences List of generated token sequences
""" """
# Initialize generation state
batch_size = len(prompt_tokens) batch_size = len(prompt_tokens)
device = next(model.parameters()).device device = next(model.parameters()).device
max_seq_len = model.max_seq_len max_seq_len = model.max_seq_len
prompt_lens = [len(p) for p in prompt_tokens] prompt_lens = [len(p) for p in prompt_tokens]
# Validate input lengths # Adjust max_new_tokens based on input length
if max(prompt_lens) + max_new_tokens > max_seq_len: if max(prompt_lens) + max_new_tokens > max_seq_len:
logger.warning(f"Truncating sequence length to {max_seq_len}") logger.warning(f"Truncating sequence length to {max_seq_len}")
max_new_tokens = max_seq_len - max(prompt_lens) max_new_tokens = max_seq_len - max(prompt_lens)
@ -166,7 +164,7 @@ def generate(
progress_bar.close() progress_bar.close()
# Process outputs # Process outputs
return [seq[pl:pl+max_new_tokens].tolist() for pl, seq in zip(prompt_lens, tokens)] return [seq[pl:pl + max_new_tokens].tolist() for pl, seq in zip(prompt_lens, tokens)]
def interactive_loop( def interactive_loop(
model: Transformer, model: Transformer,
@ -183,23 +181,20 @@ def interactive_loop(
while True: while True:
try: try:
# Distributed input handling # Distributed input handling
prompt = None
if world_size > 1: if world_size > 1:
if rank == 0: if rank == 0:
prompt = input("\nUser: ") prompt = input("\nUser: ")
dist.broadcast_object_list([prompt], src=0) dist.broadcast_object_list([prompt], src=0)
else: else:
prompt = None
dist.broadcast_object_list([prompt], src=0) dist.broadcast_object_list([prompt], src=0)
if prompt == "/exit":
break
else: else:
prompt = input("\nUser: ") prompt = input("\nUser: ")
# Command handling # Command handling
if prompt == "/exit": if prompt in ["/exit", "/clear"]:
break if prompt == "/exit":
if prompt == "/clear": break
messages.clear() messages.clear()
logger.info("History cleared") logger.info("History cleared")
continue continue
@ -260,7 +255,7 @@ def batch_process(
# Generate in batches # Generate in batches
completions = [] completions = []
for i in tqdm(range(0, len(prompt_tokens), model.args.max_batch_size)): for i in tqdm(range(0, len(prompt_tokens), model.args.max_batch_size)):
batch = prompt_tokens[i:i+model.args.max_batch_size] batch = prompt_tokens[i:i + model.args.max_batch_size]
completions += generate(model, batch, max_new_tokens, tokenizer.eos_token_id, temperature) completions += generate(model, batch, max_new_tokens, tokenizer.eos_token_id, temperature)
# Decode and print # Decode and print