# DeepSeek-V3/inference/optimized_generate.py
#
# Listing metadata: 364 lines, 14 KiB, Python.

import os
import json
from argparse import ArgumentParser
from typing import List, Optional
import torch
from transformers import AutoTokenizer
from safetensors.torch import load_model
from model import Transformer, ModelArgs
def sample(logits, temperature: float = 1.0):
    """
    Draw one token index per row of ``logits`` after temperature scaling.

    The draw uses the exponential-race trick. Each softmax probability is
    divided by an independent Exp(1) sample, and the argmax of the result
    selects index ``i`` with probability ``probs[i]``.

    Args:
        logits (torch.Tensor): Unnormalized scores; the last dimension indexes
            the vocabulary. Not modified.
        temperature (float, optional): Divisor applied to ``logits``. It is
            clamped below at 1e-5 so a zero temperature never divides by zero.
            Defaults to 1.0.

    Returns:
        torch.Tensor: The sampled indices, with the last dimension of
        ``logits`` reduced away.
    """
    scale = max(temperature, 1e-5)
    distribution = torch.softmax(logits / scale, dim=-1)
    exp_noise = torch.empty_like(distribution).exponential_(1)
    # In-place division is safe: ``distribution`` is a fresh tensor owned here.
    return distribution.div_(exp_noise).argmax(dim=-1)
@torch.inference_mode()
def generate(
    model: "Transformer",
    prompt_tokens: List[List[int]],
    max_new_tokens: int,
    eos_id: int,
    temperature: float = 1.0,
    device: str = "mps",
    chunk_size: int = 512
) -> List[List[int]]:
    """
    Generates new tokens for each prompt with incremental decoding.

    To keep memory at a single sequence, multiple prompts are decoded one at a
    time through a recursive call.

    For one sequence, the first forward pass feeds the whole prompt at
    ``start_pos=0``. Every later pass feeds only the single newest token at its
    own position.
    NOTE(review): this assumes the model keeps keys/values for earlier
    positions between ``forward(tokens, start_pos)`` calls (a KV cache).
    Confirm against ``model.py`` if the model implementation changes.

    Args:
        model (Transformer): The transformer model used for token generation.
            Must expose ``max_seq_len`` and ``forward(tokens, start_pos)``.
        prompt_tokens (List[List[int]]): Prompt token ids, one list per sequence.
        max_new_tokens (int): The maximum number of new tokens to generate per prompt.
        eos_id (int): The end-of-sequence token ID. Generation stops after it is
            produced, and it is excluded from the returned tokens.
        temperature (float, optional): Sampling temperature. Values <= 0 select
            greedily with argmax. Defaults to 1.0.
        device (str, optional): Device on which the token buffer is allocated.
            Defaults to "mps".
        chunk_size (int, optional): Kept for backward compatibility and currently
            unused. Decoding feeds only unseen tokens, so no sliding window is
            needed. Re-feeding a multi-token window at a non-zero ``start_pos``
            is also unsafe for models that build a ``(seqlen, seqlen)`` causal
            mask. Defaults to 512.

    Returns:
        List[List[int]]: One list per prompt with the generated token ids.
        The prompt is excluded, and each list is truncated before the first
        ``eos_id``. Returns ``[]`` when ``prompt_tokens`` is empty.

    Raises:
        AssertionError: If a prompt is longer than ``model.max_seq_len``.
        ValueError: If a prompt is empty (there would be no token to condition on).
    """
    if not prompt_tokens:
        return []
    prompt_lens = [len(t) for t in prompt_tokens]
    assert max(prompt_lens) <= model.max_seq_len, f"Prompt length exceeds model maximum sequence length (max_seq_len={model.max_seq_len})"
    if min(prompt_lens) == 0:
        raise ValueError("Every prompt must contain at least one token")

    batch_size = len(prompt_tokens)
    if batch_size > 1:
        # Process prompts one at a time for memory efficiency.
        print(f"Processing {batch_size} prompts in sequence to conserve memory...")
        return [
            generate(model, [prompt], max_new_tokens, eos_id, temperature, device, chunk_size)[0]
            for prompt in prompt_tokens
        ]

    prompt_len = prompt_lens[0]
    total_len = min(model.max_seq_len, max_new_tokens + prompt_len)
    # -1 marks positions that have not been generated yet.
    tokens = torch.full((1, total_len), -1, dtype=torch.long, device=device)
    tokens[0, :prompt_len] = torch.tensor(prompt_tokens[0], dtype=torch.long, device=device)

    show_progress = max_new_tokens > 100
    prev_pos = 0
    for cur_pos in range(prompt_len, total_len):
        # Periodically release cached allocator memory on GPU/MPS backends.
        if cur_pos % 50 == 0:
            if device == "mps":
                torch.mps.empty_cache()
            elif device == "cuda":
                torch.cuda.empty_cache()
        # Feed only tokens the model has not seen yet: the whole prompt on the
        # first step, then exactly one token per step.
        logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
        if temperature > 0:
            next_token = sample(logits, temperature)
        else:
            next_token = logits.argmax(dim=-1)
        tokens[:, cur_pos] = next_token
        prev_pos = cur_pos
        if bool((next_token == eos_id).all()):
            break
        if show_progress and cur_pos % 20 == 0:
            progress = (cur_pos - prompt_len) / max_new_tokens * 100
            print(f"\rGenerating: {progress:.1f}% complete", end="", flush=True)
    if show_progress:
        print("\rGeneration complete! ")

    completion = tokens[0, prompt_len:prompt_len + max_new_tokens].tolist()
    if eos_id in completion:
        completion = completion[:completion.index(eos_id)]
    return [completion]
def get_optimal_device(force_cpu: bool = False) -> str:
    """
    Determines the best available device for model inference.

    The preference order is CUDA, then Apple MPS, then CPU.

    Args:
        force_cpu (bool, optional): Force CPU usage even if a GPU is available.
            Defaults to False.

    Returns:
        str: Device string ("cuda", "mps", or "cpu").
    """
    if force_cpu:
        return "cpu"
    if torch.cuda.is_available():
        return "cuda"
    # ``torch.backends.mps.is_available()`` is the documented MPS probe.
    # ``torch.mps.is_available()`` is not present in every PyTorch release
    # that ships the ``torch.mps`` module, so calling it can raise AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"  # Apple Silicon GPU
    return "cpu"
def optimize_model(model: "Transformer", quantize: bool = True, device: str = "cpu") -> "Transformer":
    """
    Applies optimizations to the model for more efficient inference.

    The only optimization is dynamic int8 quantization of ``torch.nn.Linear``
    submodules. It is applied only when ``quantize`` is true and ``device`` is
    ``"cpu"``; on other devices the model is returned untouched.

    Args:
        model (Transformer): The transformer model to optimize. Modified in
            place when quantization applies.
        quantize (bool, optional): Whether to apply int8 quantization. Defaults to True.
        device (str, optional): Target device. Defaults to "cpu".

    Returns:
        Transformer: The same model object, quantized in place when applicable.
    """
    if quantize and device == "cpu":
        # inplace=True swaps the Linear submodules on ``model`` itself. The
        # default (inplace=False) deep-copies the whole model first, briefly
        # doubling memory for a large checkpoint.
        # NOTE(review): modules are matched by their exact type. If the model
        # defines its own linear-layer class, those layers stay unquantized.
        # Confirm against model.py.
        model = torch.ao.quantization.quantize_dynamic(
            model, {torch.nn.Linear}, dtype=torch.qint8, inplace=True
        )
    return model
def _empty_device_cache(device: str) -> None:
    """Release cached allocator memory on MPS/CUDA; do nothing on CPU."""
    if device == "mps":
        torch.mps.empty_cache()
    elif device == "cuda":
        torch.cuda.empty_cache()


def _run_interactive(model, tokenizer, max_new_tokens, temperature, device, chunk_size) -> None:
    """Read prompts from stdin and print replies, keeping the chat history.

    "/exit" or end of input quits; "/clear" resets the conversation history.
    """
    messages = []
    while True:
        try:
            prompt = input(">>> ")
        except EOFError:
            # End of input (Ctrl-D or a closed pipe) ends the session like
            # "/exit" instead of escaping as a traceback.
            print()
            break
        if prompt == "/exit":
            break
        if prompt == "/clear":
            messages.clear()
            continue
        messages.append({"role": "user", "content": prompt})
        prompt_tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
        # Show a waiting message for longer generations
        if max_new_tokens > 50:
            print("Generating response...")
        completion_tokens = generate(
            model,
            [prompt_tokens],
            max_new_tokens,
            tokenizer.eos_token_id,
            temperature,
            device,
            chunk_size
        )
        completion = tokenizer.decode(completion_tokens[0], skip_special_tokens=True)
        print(completion)
        messages.append({"role": "assistant", "content": completion})
        # Clear cache after each generation to prevent memory buildup
        _empty_device_cache(device)


def _run_batch(model, tokenizer, input_file, max_batch_size, max_new_tokens, temperature, device, chunk_size) -> None:
    """Generate one completion per non-blank line of ``input_file`` and print the pairs."""
    with open(input_file, encoding="utf-8") as f:
        # Skip blank lines: an empty user turn is not a meaningful prompt.
        prompts = [line.strip() for line in f if line.strip()]
    if not prompts:
        print(f"No prompts found in {input_file}")
        return
    assert len(prompts) <= max_batch_size, f"Number of prompts exceeds maximum batch size ({max_batch_size})"
    prompt_tokens = [
        tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True)
        for prompt in prompts
    ]
    completion_tokens = generate(
        model,
        prompt_tokens,
        max_new_tokens,
        tokenizer.eos_token_id,
        temperature,
        device,
        chunk_size
    )
    completions = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True)
    for prompt, completion in zip(prompts, completions):
        print("Prompt:", prompt)
        print("Completion:", completion)
        print()


def main(
    ckpt_path: str,
    config: str,
    input_file: str = "",
    interactive: bool = True,
    max_new_tokens: int = 100,
    temperature: float = 1.0,
    force_cpu: bool = False,
    quantize: bool = True,
    chunk_size: int = 512,
    reduced_model: bool = False,
) -> None:
    """
    Main function to load the model and perform interactive or batch text generation.

    Picks a device, loads the config, model, tokenizer and single-shard
    checkpoint, and optionally quantizes on CPU. It then runs a short warmup
    generation. Finally it either starts a chat loop or processes ``input_file``.

    Args:
        ckpt_path (str): Path to the model checkpoint directory (also used for the tokenizer).
        config (str): Path to the JSON model configuration file.
        input_file (str, optional): Path to a file with one prompt per line; blank
            lines are skipped. Used when ``interactive`` is False. Defaults to "".
        interactive (bool, optional): Whether to run in interactive mode. Defaults to True.
        max_new_tokens (int, optional): Maximum number of new tokens to generate. Defaults to 100.
        temperature (float, optional): Temperature for sampling. Defaults to 1.0.
        force_cpu (bool, optional): Force CPU usage even if a GPU is available. Defaults to False.
        quantize (bool, optional): Apply int8 dynamic quantization when running on CPU. Defaults to True.
        chunk_size (int, optional): Forwarded to ``generate``. Defaults to 512.
        reduced_model (bool, optional): Halve ``n_routed_experts`` and cap
            ``max_seq_len`` at 4096 in the loaded configuration before building
            the model. Defaults to False.
    """
    # Detect optimal device
    device = get_optimal_device(force_cpu)
    print(f"Using device: {device}")
    if device == "cuda":
        torch.cuda.set_device(0)
    # Use bfloat16 for CUDA/MPS, float32 for CPU
    torch.set_default_dtype(torch.float32 if device == "cpu" else torch.bfloat16)
    # Keep the previous cap of 8 intra-op threads, but never request more
    # threads than the machine reports, which would oversubscribe the CPU.
    torch.set_num_threads(min(8, os.cpu_count() or 1))
    torch.manual_seed(965)

    with open(config, encoding="utf-8") as f:
        config_data = json.load(f)
    if reduced_model:
        # NOTE(review): the checkpoint loaded below is unchanged. With fewer
        # experts its tensors may no longer match the model. Confirm this
        # flag works with the checkpoint in use.
        config_data["n_routed_experts"] = config_data.get("n_routed_experts", 64) // 2
        config_data["max_seq_len"] = min(config_data.get("max_seq_len", 16384), 4096)  # Reduce context size
    model_args = ModelArgs(**config_data)
    print(model_args)

    with torch.device(device):
        model = Transformer(model_args)
    tokenizer = AutoTokenizer.from_pretrained(ckpt_path)

    # Every device loads the single-shard (rank 0 of 1) checkpoint; the CUDA
    # path is limited to one GPU as well.
    checkpoint_path = os.path.join(ckpt_path, "model0-mp1.safetensors")
    print(f"Loading checkpoint from {checkpoint_path}")
    load_model(model, checkpoint_path)

    # Apply quantization and other optimizations
    model = optimize_model(model, quantize=quantize, device=device)
    model.to(device)

    # Generate a quick test sequence to ensure everything is working
    print("Running warmup generation...")
    tokenizer.decode(generate(model, [tokenizer.encode("DeepSeek")], 2, -1, 1., device)[0])
    print("Model loaded and ready!")

    if interactive:
        _run_interactive(model, tokenizer, max_new_tokens, temperature, device, chunk_size)
    else:
        _run_batch(
            model, tokenizer, input_file, model_args.max_batch_size,
            max_new_tokens, temperature, device, chunk_size,
        )
if __name__ == "__main__":
    # Command-line interface for optimized text generation on MacBooks.
    # Either --input-file (batch mode) or --interactive must be given.
    parser = ArgumentParser(description="Optimized text generation on MacBooks.")
    parser.add_argument("--ckpt-path", type=str, required=True,
                        help="Path to the model checkpoint directory.")
    parser.add_argument("--config", type=str, required=True,
                        help="Path to the model configuration file.")
    parser.add_argument("--input-file", type=str, default="",
                        help="File containing prompts for batch processing.")
    parser.add_argument("--interactive", action="store_true",
                        help="Enable interactive mode for generating text.")
    parser.add_argument("--max-new-tokens", type=int, default=200,
                        help="Maximum number of new tokens to generate.")
    parser.add_argument("--temperature", type=float, default=0.2,
                        help="Temperature for sampling.")
    parser.add_argument("--force-cpu", action="store_true",
                        help="Force CPU usage even if a GPU is available.")
    parser.add_argument("--no-quantize", action="store_true",
                        help="Disable quantization (higher quality but slower).")
    parser.add_argument("--chunk-size", type=int, default=512,
                        help="Size of processing chunks for memory efficiency.")
    parser.add_argument("--reduced-model", action="store_true",
                        help="Load a smaller version of the model if available.")
    args = parser.parse_args()
    if not (args.input_file or args.interactive):
        # parser.error prints usage and exits with status 2; unlike assert it
        # is not stripped when Python runs with -O.
        parser.error("Either input-file or interactive mode must be specified")
    main(
        ckpt_path=args.ckpt_path,
        config=args.config,
        input_file=args.input_file,
        interactive=args.interactive,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        force_cpu=args.force_cpu,
        quantize=not args.no_quantize,
        chunk_size=args.chunk_size,
        reduced_model=args.reduced_model,
    )