Mirror of https://github.com/deepseek-ai/DeepSeek-V3.git (synced 2025-04-20 02:28:57 -04:00)
Update generate.py
- Error Handling: improved error handling for file operations and JSON loading.
- Logging: clearer logging messages for better debugging and monitoring.
- Code Comments: added more descriptive comments to enhance readability.
- Dynamic Sequence Management: clarified the logic for managing prompt lengths and token generation.
- Performance: minor optimizations in code structure and logic flow.
- Code Structure: organized functions and constants for better readability and maintainability.
This commit is contained in:
  parent e1ed2e8465
  commit af62bffeff
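The error handling and logging improvements mentioned above for config loading are not part of the hunks shown below. As a rough illustration only, a JSON config load with that kind of error handling might look like the following (load_config_dict is a hypothetical helper, not code from this commit):

import json
import logging
from pathlib import Path

logger = logging.getLogger(__name__)

def load_config_dict(config_path: Path) -> dict:
    """Read a JSON config, logging a clear error before re-raising on failure."""
    try:
        with config_path.open("r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        logger.error("Config file not found: %s", config_path)
        raise
    except json.JSONDecodeError as e:
        logger.error("Invalid JSON in %s: %s", config_path, e)
        raise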
@@ -4,6 +4,7 @@ import logging
 from argparse import ArgumentParser
 from pathlib import Path
 from typing import List, Optional, Dict, Tuple
+from datetime import timedelta
 from contextlib import nullcontext
 
 import torch
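The hunk adds the import only; the call site for timedelta is not shown in this diff. A common use, assumed here rather than taken from the commit, is a longer timeout for the distributed process group (this sketch expects the usual rank and master-address environment variables to be set):

from datetime import timedelta

import torch.distributed as dist

# Assumption: the new import supports a process-group timeout; the real call
# site is not visible in this diff.
dist.init_process_group(backend="nccl", timeout=timedelta(minutes=30))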
@@ -59,10 +60,8 @@ def load_model_config(config_path: Path) -> ModelArgs:
 
 def initialize_model(args: ModelArgs, device: str) -> Transformer:
     """Initialize model with proper device placement and dtype."""
-    with torch.device(device):
-        model = Transformer(args)
-        model.to(TORCH_DTYPE)
-        model.eval()
+    model = Transformer(args).to(TORCH_DTYPE)
+    model.eval()
     return model
 
 def sample(logits: torch.Tensor, temperature: float = 1.0, top_k: int = 50) -> torch.Tensor:
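For context on the initialize_model change: the removed lines built the model inside a torch.device context, while the new lines construct it on the default device and cast the dtype in one call. A minimal sketch of the difference, assuming PyTorch 2.0+ and an available CUDA device:

import torch
from torch import nn

# Inside the context, parameters are allocated directly on the target device.
with torch.device("cuda"):
    direct = nn.Linear(8, 8)
# Without the context, parameters start on the default device (CPU) and are copied.
moved = nn.Linear(8, 8).to("cuda")
print(direct.weight.device, moved.weight.device)  # both report cuda:0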
@@ -82,7 +81,7 @@ def sample(logits: torch.Tensor, temperature: float = 1.0, top_k: int = 50) -> torch.Tensor:
 
     if top_k > 0:
         v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
-        logits[logits < v[:, [-1]]] = -float('inf')
+        logits[logits < v[:, [-1]]] = float('-inf')
 
     probs = torch.softmax(logits / max(temperature, 1e-5), dim=-1)
     return torch.multinomial(probs, num_samples=1).squeeze(1)
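The only change here is cosmetic: float('-inf') instead of -float('inf'), and both evaluate to negative infinity. For reference, a self-contained sketch of the top-k filtering and temperature sampling this function performs, written with masked_fill so the input tensor is not modified in place:

import torch

def sample_top_k(logits: torch.Tensor, temperature: float = 1.0, top_k: int = 50) -> torch.Tensor:
    # Mask everything below the k-th largest logit per row to negative infinity.
    if top_k > 0:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        logits = logits.masked_fill(logits < v[:, [-1]], float("-inf"))
    # Temperature-scaled softmax, then draw one token id per row.
    probs = torch.softmax(logits / max(temperature, 1e-5), dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(1)

token_ids = sample_top_k(torch.randn(2, 10), temperature=0.8, top_k=5)  # shape (2,)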
@@ -112,13 +111,12 @@ def generate(
     Returns:
         List of generated token sequences
     """
-    # Initialize generation state
     batch_size = len(prompt_tokens)
     device = next(model.parameters()).device
     max_seq_len = model.max_seq_len
     prompt_lens = [len(p) for p in prompt_tokens]
 
-    # Validate input lengths
+    # Adjust max_new_tokens based on input length
     if max(prompt_lens) + max_new_tokens > max_seq_len:
         logger.warning(f"Truncating sequence length to {max_seq_len}")
         max_new_tokens = max_seq_len - max(prompt_lens)
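The renamed comment matches what the branch does: instead of rejecting over-long inputs, it shrinks max_new_tokens so the longest prompt plus its continuation fits the model context. With toy numbers standing in for the real model fields:

# Toy values only; max_seq_len stands in for model.max_seq_len.
max_seq_len = 4096
max_new_tokens = 512
prompt_tokens = [[1] * 100, [1] * 4000]              # two prompts: 100 and 4000 tokens

prompt_lens = [len(p) for p in prompt_tokens]
if max(prompt_lens) + max_new_tokens > max_seq_len:
    max_new_tokens = max_seq_len - max(prompt_lens)  # 4096 - 4000 = 96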
@@ -166,7 +164,7 @@ def generate(
     progress_bar.close()
 
     # Process outputs
-    return [seq[pl:pl+max_new_tokens].tolist() for pl, seq in zip(prompt_lens, tokens)]
+    return [seq[pl:pl + max_new_tokens].tolist() for pl, seq in zip(prompt_lens, tokens)]
 
 def interactive_loop(
     model: Transformer,
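Aside from the whitespace change, the return line slices each sequence so only the continuation is kept: it drops the prompt prefix and anything beyond max_new_tokens. A toy illustration:

import torch

tokens = torch.tensor([[11, 12, 13, 21, 22, 23, 0, 0]])  # prompt of length 3, then generated tokens and padding
prompt_lens = [3]
max_new_tokens = 3
completions = [seq[pl:pl + max_new_tokens].tolist() for pl, seq in zip(prompt_lens, tokens)]
# completions == [[21, 22, 23]] -- the prompt and trailing padding are dropped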
@@ -183,23 +181,20 @@ def interactive_loop(
     while True:
         try:
             # Distributed input handling
+            prompt = None
             if world_size > 1:
                 if rank == 0:
                     prompt = input("\nUser: ")
                     dist.broadcast_object_list([prompt], src=0)
                 else:
-                    prompt = None
                     dist.broadcast_object_list([prompt], src=0)
-
-                if prompt == "/exit":
-                    break
             else:
                 prompt = input("\nUser: ")
 
             # Command handling
-            if prompt == "/exit":
-                break
-            if prompt == "/clear":
+            if prompt in ["/exit", "/clear"]:
+                if prompt == "/exit":
+                    break
                 messages.clear()
                 logger.info("History cleared")
                 continue
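One torch.distributed detail relevant to this hunk: broadcast_object_list fills the passed-in list in place on receiving ranks, so the broadcast value has to be read back from that list. A minimal helper following that pattern (broadcast_prompt is an illustration, not code from this commit, and assumes an already initialized process group):

import torch.distributed as dist

def broadcast_prompt(rank: int) -> str:
    # Only rank 0 reads from stdin; the object list is then broadcast to all ranks.
    obj = [input("\nUser: ")] if rank == 0 else [None]
    dist.broadcast_object_list(obj, src=0)
    return obj[0]  # non-source ranks must read the value back from the list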
@@ -260,7 +255,7 @@ def batch_process(
     # Generate in batches
     completions = []
     for i in tqdm(range(0, len(prompt_tokens), model.args.max_batch_size)):
-        batch = prompt_tokens[i:i+model.args.max_batch_size]
+        batch = prompt_tokens[i:i + model.args.max_batch_size]
         completions += generate(model, batch, max_new_tokens, tokenizer.eos_token_id, temperature)
 
     # Decode and print
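The loop walks prompt_tokens in chunks of model.args.max_batch_size and concatenates the per-batch generations. A toy illustration of the chunking with assumed values:

max_batch_size = 4                              # stands in for model.args.max_batch_size
prompt_tokens = [[i] for i in range(10)]        # ten toy prompts
batches = [prompt_tokens[i:i + max_batch_size]
           for i in range(0, len(prompt_tokens), max_batch_size)]
# -> three batches of sizes 4, 4 and 2; each would be passed to generate()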