# Mirror of https://github.com/deepseek-ai/DeepSeek-V3.git
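"""
Distributed text-generation driver for a DeepSeek-V3-style Transformer checkpoint.

The pieces below are: GenerationConfig (decoding parameters), TokenSampler
(temperature sampling), TextGenerator (batched autoregressive decoding),
DistributedEnvironment (multi-process setup and prompt broadcasting), and
ChatSession (interactive chat or batch completion front end).
"""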
import os
import json
from argparse import ArgumentParser
from typing import List, Optional, Dict, Any, Tuple
from dataclasses import dataclass

import torch
import torch.distributed as dist
from transformers import AutoTokenizer
from safetensors.torch import load_model

from model import Transformer, ModelArgs


@dataclass
class GenerationConfig:
    max_new_tokens: int
    temperature: float
    eos_id: int


class TokenSampler:
    @staticmethod
    def sample(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
        """
        Samples a token from the logits using temperature scaling.

        Args:
            logits (torch.Tensor): The logits tensor for token predictions.
            temperature (float): Temperature for scaling logits.

        Returns:
            torch.Tensor: The sampled token.
        """
        logits = logits / max(temperature, 1e-5)
        probs = torch.softmax(logits, dim=-1)
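        # Dividing the probabilities by i.i.d. Exponential(1) noise and taking
        # the argmax draws a sample from the categorical distribution over the
        # vocabulary (the Gumbel-max trick in its exponential form), so the
        # line below is a random sampling step, not a deterministic argmax.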
        return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)


class TextGenerator:
    def __init__(self, model: Transformer, tokenizer: Any):
        self.model = model
        self.tokenizer = tokenizer

    @torch.inference_mode()
    def generate(
        self,
        prompt_tokens: List[List[int]],
        config: GenerationConfig
    ) -> List[List[int]]:
        """
        Generates new tokens based on the given prompt tokens.

        Args:
            prompt_tokens: A list of lists containing the prompt tokens for each sequence.
            config: Generation configuration parameters.

        Returns:
            List[List[int]]: Generated tokens for each sequence.
        """
        prompt_lens = [len(t) for t in prompt_tokens]
        if max(prompt_lens) > self.model.max_seq_len:
            raise ValueError(f"Prompt length exceeds model maximum sequence length (max_seq_len={self.model.max_seq_len})")

        total_len = min(self.model.max_seq_len, config.max_new_tokens + max(prompt_lens))
        tokens = self._initialize_tokens(prompt_tokens, total_len)

        completion_tokens = self._generate_tokens(
            tokens, prompt_lens, total_len, config
        )
        return completion_tokens

    def _initialize_tokens(
        self, prompt_tokens: List[List[int]], total_len: int
    ) -> torch.Tensor:
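        # -1 marks positions that still need to be generated; the prompt mask
        # derived from it later keeps prompt tokens fixed during decoding.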
        tokens = torch.full(
            (len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda"
        )
        for i, t in enumerate(prompt_tokens):
            tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
        return tokens

    def _generate_tokens(
        self,
        tokens: torch.Tensor,
        prompt_lens: List[int],
        total_len: int,
        config: GenerationConfig
    ) -> List[List[int]]:
        prev_pos = 0
        finished = torch.tensor([False] * len(prompt_lens), device="cuda")
        prompt_mask = tokens != -1
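        # Decode one position at a time, starting at the shortest prompt. Only
        # the tokens added since prev_pos are fed to the model, so forward()
        # is expected to decode incrementally against its own cache; positions
        # that still belong to a longer prompt are forced back to the prompt
        # token via prompt_mask rather than sampled.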
        for cur_pos in range(min(prompt_lens), total_len):
            logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
            next_token = self._get_next_token(logits, config.temperature)
            next_token = torch.where(
                prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token
            )

            tokens[:, cur_pos] = next_token
            finished |= torch.logical_and(
                ~prompt_mask[:, cur_pos], next_token == config.eos_id
            )
            prev_pos = cur_pos

            if finished.all():
                break

        return self._process_completion_tokens(
            tokens, prompt_lens, config.max_new_tokens, config.eos_id
        )

    def _get_next_token(
        self, logits: torch.Tensor, temperature: float
    ) -> torch.Tensor:
        if temperature > 0:
            return TokenSampler.sample(logits, temperature)
        return logits.argmax(dim=-1)

    def _process_completion_tokens(
        self,
        tokens: torch.Tensor,
        prompt_lens: List[int],
        max_new_tokens: int,
        eos_id: int
    ) -> List[List[int]]:
        completion_tokens = []
        for i, toks in enumerate(tokens.tolist()):
            toks = toks[prompt_lens[i]:prompt_lens[i] + max_new_tokens]
            if eos_id in toks:
                toks = toks[:toks.index(eos_id)]
            completion_tokens.append(toks)
        return completion_tokens


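# A minimal usage sketch for TextGenerator, assuming `model` and `tokenizer`
# have been built elsewhere (e.g. by initialize_model below):
#
#     generator = TextGenerator(model, tokenizer)
#     cfg = GenerationConfig(max_new_tokens=64, temperature=0.7,
#                            eos_id=tokenizer.eos_token_id)
#     out = generator.generate([tokenizer.encode("Hello")], cfg)
#     print(tokenizer.decode(out[0], skip_special_tokens=True))

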
class DistributedEnvironment:
    def __init__(self):
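        # WORLD_SIZE, RANK and LOCAL_RANK are the variables exported by
        # torchrun / torch.distributed launchers; the defaults below let a
        # plain single-process run work without them.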
        self.world_size = int(os.getenv("WORLD_SIZE", "1"))
        self.rank = int(os.getenv("RANK", "0"))
        self.local_rank = int(os.getenv("LOCAL_RANK", "0"))

    def setup(self):
        if self.world_size > 1:
            dist.init_process_group("nccl")
        if self.rank != 0:
            global print
            print = lambda *_, **__: None
        torch.cuda.set_device(self.local_rank)

    def cleanup(self):
        if self.world_size > 1:
            dist.destroy_process_group()

    def broadcast_prompt(self, prompt: Optional[str] = None) -> str:
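        # Rank 0 reads from stdin and broadcasts the prompt so that every rank
        # runs the same forward passes on the same input.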
        if self.world_size == 1:
            return input(">>> ")
        elif self.rank == 0:
            prompt = input(">>> ")
            objects = [prompt]
            dist.broadcast_object_list(objects, 0)
            return prompt
        else:
            objects = [None]
            dist.broadcast_object_list(objects, 0)
            return objects[0]


class ChatSession:
    def __init__(
        self,
        generator: TextGenerator,
        config: GenerationConfig,
        dist_env: DistributedEnvironment
    ):
        self.generator = generator
        self.config = config
        self.dist_env = dist_env
        self.messages = []

    def run_interactive(self):
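        # Chat loop: "/exit" ends the session, "/clear" resets the history.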
        while True:
            prompt = self.dist_env.broadcast_prompt()
            if prompt == "/exit":
                break
            elif prompt == "/clear":
                self.messages.clear()
                continue

            completion = self._process_message(prompt)
            print(completion)
            self.messages.append({"role": "assistant", "content": completion})

    def run_batch(self, input_file: str):
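        # The input file holds one prompt per line; all prompts are decoded
        # as a single batch, so the count is bounded by max_batch_size below.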
        with open(input_file) as f:
            prompts = [line.strip() for line in f.readlines()]

        if len(prompts) > self.generator.model.args.max_batch_size:
            raise ValueError(f"Number of prompts exceeds maximum batch size ({self.generator.model.args.max_batch_size})")

        completions = self._process_batch(prompts)
        for prompt, completion in zip(prompts, completions):
            print("Prompt:", prompt)
            print("Completion:", completion)
            print()

    def _process_message(self, prompt: str) -> str:
        self.messages.append({"role": "user", "content": prompt})
        prompt_tokens = self.generator.tokenizer.apply_chat_template(
            self.messages, add_generation_prompt=True
        )
        completion_tokens = self.generator.generate(
            [prompt_tokens], self.config
        )
        return self.generator.tokenizer.decode(
            completion_tokens[0], skip_special_tokens=True
        )

    def _process_batch(self, prompts: List[str]) -> List[str]:
        prompt_tokens = [
            self.generator.tokenizer.apply_chat_template(
                [{"role": "user", "content": prompt}],
                add_generation_prompt=True
            )
            for prompt in prompts
        ]
        completion_tokens = self.generator.generate(
            prompt_tokens, self.config
        )
        return self.generator.tokenizer.batch_decode(
            completion_tokens, skip_special_tokens=True
        )


def initialize_model(
    ckpt_path: str, config_path: str, dist_env: DistributedEnvironment
) -> Tuple[Transformer, Any]:
    """Initialize the model and tokenizer."""
    torch.set_default_dtype(torch.bfloat16)
    torch.set_num_threads(8)
    torch.manual_seed(965)

    with open(config_path) as f:
        args = ModelArgs(**json.load(f))
    print(args)

    with torch.device("cuda"):
        model = Transformer(args)
    tokenizer = AutoTokenizer.from_pretrained(ckpt_path)

    # Warmup
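    # A short two-token generation (eos_id=-1, so it never stops early) that
    # exercises the full decode path once before the real weights are loaded,
    # presumably to trigger lazy CUDA initialization up front.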
    tokenizer.decode(
        TextGenerator(model, tokenizer).generate(
            [tokenizer.encode("DeepSeek")],
            GenerationConfig(max_new_tokens=2, temperature=1.0, eos_id=-1)
        )[0]
    )

    load_model(
        model,
        os.path.join(
            ckpt_path,
            f"model{dist_env.rank}-mp{dist_env.world_size}.safetensors"
        )
    )
    return model, tokenizer


def main(
    ckpt_path: str,
    config: str,
    input_file: str = "",
    interactive: bool = True,
    max_new_tokens: int = 100,
    temperature: float = 1.0,
) -> None:
    dist_env = DistributedEnvironment()
    dist_env.setup()

    model, tokenizer = initialize_model(ckpt_path, config, dist_env)
    generator = TextGenerator(model, tokenizer)
    gen_config = GenerationConfig(
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        eos_id=tokenizer.eos_token_id
    )

    session = ChatSession(generator, gen_config, dist_env)

    if interactive:
        session.run_interactive()
    else:
        session.run_batch(input_file)

    dist_env.cleanup()


if __name__ == "__main__":
    parser = ArgumentParser(description="Distributed text generation system")
    parser.add_argument("--ckpt-path", type=str, required=True)
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--input-file", type=str, default="")
    parser.add_argument("--interactive", action="store_true")
    parser.add_argument("--max-new-tokens", type=int, default=200)
    parser.add_argument("--temperature", type=float, default=0.2)
    args = parser.parse_args()

    if not args.input_file and not args.interactive:
        raise ValueError("Either input-file or interactive mode must be specified")

    main(
        args.ckpt_path,
        args.config,
        args.input_file,
        args.interactive,
        args.max_new_tokens,
        args.temperature
    )

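# Example invocations (a sketch; script name, paths and process counts are
# placeholders, not part of the repository):
#
#   Single process, interactive chat:
#     python generate.py --ckpt-path /path/to/ckpt --config config.json --interactive
#
#   Multi-GPU, batch mode (the checkpoint directory must contain one
#   model{rank}-mp{world_size}.safetensors shard per rank):
#     torchrun --nproc-per-node 8 generate.py --ckpt-path /path/to/ckpt \
#         --config config.json --input-file prompts.txt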