Update generate.py

Error Handling: Improved error handling for file operations and JSON loading.
Logging: Clearer logging messages for better debugging and monitoring.
Code Comments: Added more descriptive comments to enhance code readability.
Dynamic Sequence Management: Clarified the logic for managing prompt lengths and token generation.
Performance: Minor optimizations to code structure and logic flow.
Code Structure: Organized functions and constants for better readability and maintainability.
Cristian Cezar Moisés authored 2025-01-27 23:26:23 -03:00 (committed by GitHub)
parent e1ed2e8465
commit af62bffeff

generate.py

@@ -4,6 +4,7 @@ import logging
 from argparse import ArgumentParser
 from pathlib import Path
 from typing import List, Optional, Dict, Tuple
+from datetime import timedelta
 from contextlib import nullcontext
 import torch
@@ -59,9 +60,7 @@ def load_model_config(config_path: Path) -> ModelArgs:
 def initialize_model(args: ModelArgs, device: str) -> Transformer:
     """Initialize model with proper device placement and dtype."""
-    with torch.device(device):
-        model = Transformer(args)
-        model.to(TORCH_DTYPE)
+    model = Transformer(args).to(TORCH_DTYPE)
     model.eval()
     return model
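
For context, a minimal standalone sketch (not from this repository; ToyModel and the device/dtype choices are made up) of the two construction patterns this hunk swaps between: building the module under a torch.device context allocates its parameters directly on that device, while the new one-liner builds on the default device and only casts the dtype.

import torch
from torch import nn

class ToyModel(nn.Module):
    """Illustrative stand-in for the Transformer model."""
    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(8, 8)

device, dtype = "cpu", torch.bfloat16

# Old pattern: parameters are created on `device`, then cast to `dtype` in place.
with torch.device(device):
    m1 = ToyModel()
m1.to(dtype)

# New pattern: construct on the default device, set the dtype in a single call.
m2 = ToyModel().to(dtype)

print(next(m1.parameters()).dtype, next(m2.parameters()).dtype)  # both torch.bfloat16
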
@@ -82,7 +81,7 @@ def sample(logits: torch.Tensor, temperature: float = 1.0, top_k: int = 50) -> torch.Tensor:
     if top_k > 0:
         v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
-        logits[logits < v[:, [-1]]] = -float('inf')
+        logits[logits < v[:, [-1]]] = float('-inf')
     probs = torch.softmax(logits / max(temperature, 1e-5), dim=-1)
     return torch.multinomial(probs, num_samples=1).squeeze(1)
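
For reference, a minimal standalone sketch (not taken from this file; the batch and vocabulary sizes are made up) of the top-k plus temperature sampling used above. Note that -float('inf') and float('-inf') evaluate to the same value, so that part of the hunk is purely stylistic; the sketch uses masked_fill so the caller's logits tensor is not modified in place.

import torch

def sample(logits: torch.Tensor, temperature: float = 1.0, top_k: int = 50) -> torch.Tensor:
    """Sample one token id per batch row from top-k filtered, temperature-scaled logits."""
    if top_k > 0:
        # Keep only the k largest logits per row; mask everything else to -inf.
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        logits = logits.masked_fill(logits < v[:, [-1]], float('-inf'))
    # Clamp temperature to avoid division by zero, then turn logits into probabilities.
    probs = torch.softmax(logits / max(temperature, 1e-5), dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(1)

fake_logits = torch.randn(2, 10)                       # 2 rows over a 10-token vocabulary
print(sample(fake_logits, temperature=0.8, top_k=5))   # e.g. tensor([3, 7])
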
@@ -112,13 +111,12 @@ def generate(
     Returns:
         List of generated token sequences
     """
-    # Initialize generation state
     batch_size = len(prompt_tokens)
     device = next(model.parameters()).device
     max_seq_len = model.max_seq_len
     prompt_lens = [len(p) for p in prompt_tokens]
 
-    # Validate input lengths
+    # Adjust max_new_tokens based on input length
     if max(prompt_lens) + max_new_tokens > max_seq_len:
         logger.warning(f"Truncating sequence length to {max_seq_len}")
         max_new_tokens = max_seq_len - max(prompt_lens)
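
As a quick worked instance of the clamping above (numbers are made up): with a 4096-token window, a longest prompt of 4000 tokens, and 200 requested new tokens, the request overshoots the window, so max_new_tokens is reduced to 4096 - 4000 = 96.

# Illustrative numbers only, mirroring the clamping logic above.
max_seq_len = 4096
prompt_lens = [3500, 4000]
max_new_tokens = 200
if max(prompt_lens) + max_new_tokens > max_seq_len:   # 4000 + 200 > 4096
    max_new_tokens = max_seq_len - max(prompt_lens)   # 4096 - 4000
print(max_new_tokens)  # 96
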
@@ -183,23 +181,20 @@ def interactive_loop(
     while True:
         try:
             # Distributed input handling
+            prompt = None
             if world_size > 1:
                 if rank == 0:
                     prompt = input("\nUser: ")
                     dist.broadcast_object_list([prompt], src=0)
                 else:
-                    prompt = None
                     dist.broadcast_object_list([prompt], src=0)
-                if prompt == "/exit":
-                    break
             else:
                 prompt = input("\nUser: ")
 
             # Command handling
+            if prompt in ["/exit", "/clear"]:
                 if prompt == "/exit":
                     break
-            if prompt == "/clear":
                 messages.clear()
                 logger.info("History cleared")
                 continue
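
For reference, the rank-0 input plus broadcast_object_list pattern used above can be written as a small helper. This is a minimal sketch, not code from this commit: get_prompt is a hypothetical name, and it assumes torch.distributed has already been initialized. broadcast_object_list fills the passed-in list in place on non-source ranks, so every rank reads the prompt back from that same list.

import torch.distributed as dist

def get_prompt(rank: int, world_size: int) -> str:
    """Read a prompt on rank 0 and share it with every rank (illustrative only)."""
    if world_size == 1:
        return input("\nUser: ")
    # Only rank 0 reads from stdin; the other ranks receive the value via the broadcast.
    objects = [input("\nUser: ") if rank == 0 else None]
    dist.broadcast_object_list(objects, src=0)
    return objects[0]
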