DeepSeek-V3/inference/generate.py
2025-02-08 09:14:14 +05:30

318 lines
9.9 KiB
Python

import os
import json
from argparse import ArgumentParser
from typing import List, Optional, Dict, Any, Tuple
from dataclasses import dataclass
import torch
import torch.distributed as dist
from transformers import AutoTokenizer
from safetensors.torch import load_model
from model import Transformer, ModelArgs
@dataclass
class GenerationConfig:
max_new_tokens: int
temperature: float
eos_id: int
class TokenSampler:
@staticmethod
def sample(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
"""
Samples a token from the logits using temperature scaling.
Args:
logits (torch.Tensor): The logits tensor for token predictions.
temperature (float): Temperature for scaling logits.
Returns:
torch.Tensor: The sampled token.
"""
logits = logits / max(temperature, 1e-5)
probs = torch.softmax(logits, dim=-1)
return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
class TextGenerator:
def __init__(self, model: Transformer, tokenizer: Any):
self.model = model
self.tokenizer = tokenizer
@torch.inference_mode()
def generate(
self,
prompt_tokens: List[List[int]],
config: GenerationConfig
) -> List[List[int]]:
"""
Generates new tokens based on the given prompt tokens.
Args:
prompt_tokens: A list of lists containing the prompt tokens for each sequence.
config: Generation configuration parameters.
Returns:
List[List[int]]: Generated tokens for each sequence.
"""
prompt_lens = [len(t) for t in prompt_tokens]
if max(prompt_lens) > self.model.max_seq_len:
raise ValueError(f"Prompt length exceeds model maximum sequence length (max_seq_len={self.model.max_seq_len})")
total_len = min(self.model.max_seq_len, config.max_new_tokens + max(prompt_lens))
tokens = self._initialize_tokens(prompt_tokens, total_len)
completion_tokens = self._generate_tokens(
tokens, prompt_lens, total_len, config
)
return completion_tokens
def _initialize_tokens(
self, prompt_tokens: List[List[int]], total_len: int
) -> torch.Tensor:
tokens = torch.full(
(len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda"
)
for i, t in enumerate(prompt_tokens):
tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
return tokens
def _generate_tokens(
self,
tokens: torch.Tensor,
prompt_lens: List[int],
total_len: int,
config: GenerationConfig
) -> List[List[int]]:
prev_pos = 0
finished = torch.tensor([False] * len(prompt_lens), device="cuda")
prompt_mask = tokens != -1
for cur_pos in range(min(prompt_lens), total_len):
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
next_token = self._get_next_token(logits, config.temperature)
next_token = torch.where(
prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token
)
tokens[:, cur_pos] = next_token
finished |= torch.logical_and(
~prompt_mask[:, cur_pos], next_token == config.eos_id
)
prev_pos = cur_pos
if finished.all():
break
return self._process_completion_tokens(
tokens, prompt_lens, config.max_new_tokens, config.eos_id
)
def _get_next_token(
self, logits: torch.Tensor, temperature: float
) -> torch.Tensor:
if temperature > 0:
return TokenSampler.sample(logits, temperature)
return logits.argmax(dim=-1)
def _process_completion_tokens(
self,
tokens: torch.Tensor,
prompt_lens: List[int],
max_new_tokens: int,
eos_id: int
) -> List[List[int]]:
completion_tokens = []
for i, toks in enumerate(tokens.tolist()):
toks = toks[prompt_lens[i]:prompt_lens[i] + max_new_tokens]
if eos_id in toks:
toks = toks[:toks.index(eos_id)]
completion_tokens.append(toks)
return completion_tokens
class DistributedEnvironment:
def __init__(self):
self.world_size = int(os.getenv("WORLD_SIZE", "1"))
self.rank = int(os.getenv("RANK", "0"))
self.local_rank = int(os.getenv("LOCAL_RANK", "0"))
def setup(self):
if self.world_size > 1:
dist.init_process_group("nccl")
if self.rank != 0:
global print
print = lambda *_, **__: None
torch.cuda.set_device(self.local_rank)
def cleanup(self):
if self.world_size > 1:
dist.destroy_process_group()
def broadcast_prompt(self, prompt: Optional[str] = None) -> str:
if self.world_size == 1:
return input(">>> ")
elif self.rank == 0:
prompt = input(">>> ")
objects = [prompt]
dist.broadcast_object_list(objects, 0)
return prompt
else:
objects = [None]
dist.broadcast_object_list(objects, 0)
return objects[0]
class ChatSession:
def __init__(
self,
generator: TextGenerator,
config: GenerationConfig,
dist_env: DistributedEnvironment
):
self.generator = generator
self.config = config
self.dist_env = dist_env
self.messages = []
def run_interactive(self):
while True:
prompt = self.dist_env.broadcast_prompt()
if prompt == "/exit":
break
elif prompt == "/clear":
self.messages.clear()
continue
completion = self._process_message(prompt)
print(completion)
self.messages.append({"role": "assistant", "content": completion})
def run_batch(self, input_file: str):
with open(input_file) as f:
prompts = [line.strip() for line in f.readlines()]
if len(prompts) > self.generator.model.args.max_batch_size:
raise ValueError(f"Number of prompts exceeds maximum batch size ({self.generator.model.args.max_batch_size})")
completions = self._process_batch(prompts)
for prompt, completion in zip(prompts, completions):
print("Prompt:", prompt)
print("Completion:", completion)
print()
def _process_message(self, prompt: str) -> str:
self.messages.append({"role": "user", "content": prompt})
prompt_tokens = self.generator.tokenizer.apply_chat_template(
self.messages, add_generation_prompt=True
)
completion_tokens = self.generator.generate(
[prompt_tokens], self.config
)
return self.generator.tokenizer.decode(
completion_tokens[0], skip_special_tokens=True
)
def _process_batch(self, prompts: List[str]) -> List[str]:
prompt_tokens = [
self.generator.tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
add_generation_prompt=True
)
for prompt in prompts
]
completion_tokens = self.generator.generate(
prompt_tokens, self.config
)
return self.generator.tokenizer.batch_decode(
completion_tokens, skip_special_tokens=True
)
def initialize_model(
ckpt_path: str, config_path: str, dist_env: DistributedEnvironment
) -> Tuple[Transformer, Any]:
"""Initialize the model and tokenizer."""
torch.set_default_dtype(torch.bfloat16)
torch.set_num_threads(8)
torch.manual_seed(965)
with open(config_path) as f:
args = ModelArgs(**json.load(f))
print(args)
with torch.device("cuda"):
model = Transformer(args)
tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
# Warmup
tokenizer.decode(
TextGenerator(model, tokenizer).generate(
[tokenizer.encode("DeepSeek")],
GenerationConfig(max_new_tokens=2, temperature=1.0, eos_id=-1)
)[0]
)
load_model(
model,
os.path.join(
ckpt_path,
f"model{dist_env.rank}-mp{dist_env.world_size}.safetensors"
)
)
return model, tokenizer
def main(
ckpt_path: str,
config: str,
input_file: str = "",
interactive: bool = True,
max_new_tokens: int = 100,
temperature: float = 1.0,
) -> None:
dist_env = DistributedEnvironment()
dist_env.setup()
model, tokenizer = initialize_model(ckpt_path, config, dist_env)
generator = TextGenerator(model, tokenizer)
gen_config = GenerationConfig(
max_new_tokens=max_new_tokens,
temperature=temperature,
eos_id=tokenizer.eos_token_id
)
session = ChatSession(generator, gen_config, dist_env)
if interactive:
session.run_interactive()
else:
session.run_batch(input_file)
dist_env.cleanup()
if __name__ == "__main__":
parser = ArgumentParser(description="Distributed text generation system")
parser.add_argument("--ckpt-path", type=str, required=True)
parser.add_argument("--config", type=str, required=True)
parser.add_argument("--input-file", type=str, default="")
parser.add_argument("--interactive", action="store_true")
parser.add_argument("--max-new-tokens", type=int, default=200)
parser.add_argument("--temperature", type=float, default=0.2)
args = parser.parse_args()
if not args.input_file and not args.interactive:
raise ValueError("Either input-file or interactive mode must be specified")
main(
args.ckpt_path,
args.config,
args.input_file,
args.interactive,
args.max_new_tokens,
args.temperature
)