DeepSeek-V3/inference/generate.py
Utsav-pal 38333fb817
Update generate.py: Add parallel processing for token generation
This update introduces parallel processing for token generation using torch.multiprocessing.Pool.
The new implementation improves inference speed by processing multiple sequences concurrently.
- Added the generate_parallel() function for parallel token generation.
- Used multiprocessing to distribute the workload across multiple processes, allowing for faster generation of tokens for multiple prompts.
- The generate_single_sequence() function was added to handle individual sequence generation logic, which is called by each worker in parallel.
- The num_workers parameter is introduced to control the number of worker processes (default is 4).
- Model is shared across processes for efficient memory usage.

These changes are particularly beneficial for batch processing or multi-prompt generation scenarios where multiple sequences need to be generated simultaneously.
2025-01-28 23:54:11 +05:30
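
For illustration, the new batch entry point could be driven roughly like this (a minimal sketch; the checkpoint path and prompts are made up, and `model` is assumed to have been loaded as in main() below):

    tokenizer = AutoTokenizer.from_pretrained("/path/to/ckpt")
    prompts = ["Explain MoE routing.", "Summarize multi-head latent attention."]
    prompt_tokens = [tokenizer.apply_chat_template([{"role": "user", "content": p}], add_generation_prompt=True) for p in prompts]
    completion_tokens = generate_parallel(model, prompt_tokens, max_new_tokens=100, eos_id=tokenizer.eos_token_id, temperature=0.7, num_workers=2)
    print(tokenizer.batch_decode(completion_tokens, skip_special_tokens=True))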


import os
import json
from argparse import ArgumentParser
from typing import List
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from transformers import AutoTokenizer
from safetensors.torch import load_model
from model import Transformer, ModelArgs


def sample(logits, temperature: float = 1.0):
    """
    Samples a token from the logits using temperature scaling.

    Args:
        logits (torch.Tensor): The logits tensor for token predictions.
        temperature (float, optional): Temperature for scaling logits. Defaults to 1.0.

    Returns:
        torch.Tensor: The sampled token.
    """
    logits = logits / max(temperature, 1e-5)
    probs = torch.softmax(logits, dim=-1)
    # Gumbel-max trick: dividing the probabilities by Exponential(1) noise and taking the
    # argmax draws a sample from the categorical distribution defined by probs.
    return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)


@torch.inference_mode()
def generate_single_sequence(args):
    """
    Generates tokens for a single sequence.

    Args:
        args: Tuple containing (model, tokens, max_new_tokens, eos_id, temperature).

    Returns:
        List of generated tokens.
    """
    model, tokens, max_new_tokens, eos_id, temperature = args
    prompt_len = tokens.shape[1]
    total_len = min(model.max_seq_len, max_new_tokens + prompt_len)
    # Pad the prompt out to total_len with -1 placeholders that are overwritten as tokens are generated.
    tokens = torch.cat([tokens, torch.full((1, total_len - prompt_len), -1, dtype=torch.long, device="cuda")], dim=1)
    prev_pos = 0
    finished = torch.tensor([False], device="cuda")
    for cur_pos in range(prompt_len, total_len):
        # Feed only the tokens produced since the last step; prev_pos tracks the cache offset.
        logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
        next_token = sample(logits, temperature) if temperature > 0 else logits.argmax(dim=-1)
        tokens[:, cur_pos] = next_token
        finished |= next_token == eos_id
        prev_pos = cur_pos
        if finished.all():
            break
    generated_tokens = tokens.tolist()[0][prompt_len:]
    # Truncate at the first end-of-sequence token so padding and post-EOS tokens are not decoded.
    if eos_id in generated_tokens:
        generated_tokens = generated_tokens[:generated_tokens.index(eos_id)]
    return generated_tokens


@torch.inference_mode()
def generate_parallel(
    model: Transformer,
    prompt_tokens: List[List[int]],
    max_new_tokens: int,
    eos_id: int,
    temperature: float = 1.0,
    num_workers: int = 4,
) -> List[List[int]]:
    """
    Parallelized token generation using multiprocessing.

    Args:
        model (Transformer): The transformer model used for token generation.
        prompt_tokens (List[List[int]]): A list of lists containing the prompt tokens for each sequence.
        max_new_tokens (int): The maximum number of new tokens to generate.
        eos_id (int): The end-of-sequence token ID.
        temperature (float, optional): Temperature for sampling. Defaults to 1.0.
        num_workers (int, optional): Number of worker processes for parallel generation. Defaults to 4.

    Returns:
        List[List[int]]: A list of lists containing the generated tokens for each sequence.
    """
    model.share_memory()  # Make the model's parameters shareable across worker processes.
    tokens_list = [torch.tensor(t, dtype=torch.long, device="cuda").unsqueeze(0) for t in prompt_tokens]
    args_list = [(model, tokens, max_new_tokens, eos_id, temperature) for tokens in tokens_list]
    # Each worker generates one sequence; note that sharing CUDA tensors with subprocesses
    # requires the "spawn" start method (see the torch.multiprocessing documentation).
    with mp.Pool(num_workers) as pool:
        results = pool.map(generate_single_sequence, args_list)
    return results


def main(
    ckpt_path: str,
    config: str,
    input_file: str = "",
    interactive: bool = True,
    max_new_tokens: int = 100,
    temperature: float = 1.0,
) -> None:
    """
    Main function to load the model and perform interactive or batch text generation.

    Args:
        ckpt_path (str): Path to the model checkpoint directory.
        config (str): Path to the model configuration file.
        input_file (str, optional): Path to a file containing input prompts. Defaults to "".
        interactive (bool, optional): Whether to run in interactive mode. Defaults to True.
        max_new_tokens (int, optional): Maximum number of new tokens to generate. Defaults to 100.
        temperature (float, optional): Temperature for sampling. Defaults to 1.0.
    """
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    rank = int(os.getenv("RANK", "0"))
    local_rank = int(os.getenv("LOCAL_RANK", "0"))
    if world_size > 1:
        dist.init_process_group("nccl")
    global print
    if rank != 0:
        # Silence output on non-zero ranks so only rank 0 prints.
        print = lambda *_, **__: None
    torch.cuda.set_device(local_rank)
    torch.set_default_dtype(torch.bfloat16)
    torch.set_num_threads(8)
    torch.manual_seed(965)
    with open(config) as f:
        args = ModelArgs(**json.load(f))
    print(args)
    with torch.device("cuda"):
        model = Transformer(args)
    tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
    # Warm-up generation pass using a single worker.
    tokenizer.decode(generate_parallel(model, [tokenizer.encode("DeepSeek")], 2, -1, 1., num_workers=1)[0])
    load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"))
    if interactive:
        messages = []
        while True:
            # Rank 0 reads the prompt and broadcasts it to the other ranks.
            if world_size == 1:
                prompt = input(">>> ")
            elif rank == 0:
                prompt = input(">>> ")
                objects = [prompt]
                dist.broadcast_object_list(objects, 0)
            else:
                objects = [None]
                dist.broadcast_object_list(objects, 0)
                prompt = objects[0]
            if prompt == "/exit":
                break
            elif prompt == "/clear":
                messages.clear()
                continue
            messages.append({"role": "user", "content": prompt})
            prompt_tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
            completion_tokens = generate_parallel(model, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature, num_workers=1)
            completion = tokenizer.decode(completion_tokens[0], skip_special_tokens=True)
            print(completion)
            messages.append({"role": "assistant", "content": completion})
    else:
        with open(input_file) as f:
            prompts = [line.strip() for line in f.readlines()]
        assert len(prompts) <= args.max_batch_size
        prompt_tokens = [tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True) for prompt in prompts]
        completion_tokens = generate_parallel(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature, num_workers=4)
        completions = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True)
        for prompt, completion in zip(prompts, completions):
            print("Prompt:", prompt)
            print("Completion:", completion)
            print()
    if world_size > 1:
        dist.destroy_process_group()


if __name__ == "__main__":
    """
    Command-line interface for distributed text generation.

    Arguments:
        --ckpt-path (str): Path to the model checkpoint directory.
        --config (str): Path to the model configuration file.
        --input-file (str, optional): File containing prompts for batch processing.
        --interactive (bool, optional): Enable interactive mode for generating text.
        --max-new-tokens (int, optional): Maximum number of new tokens to generate. Defaults to 200.
        --temperature (float, optional): Temperature for sampling. Defaults to 0.2.

    Raises:
        AssertionError: If neither input-file nor interactive mode is specified.
    """
    parser = ArgumentParser()
    parser.add_argument("--ckpt-path", type=str, required=True)
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--input-file", type=str, default="")
    parser.add_argument("--interactive", action="store_true")
    parser.add_argument("--max-new-tokens", type=int, default=200)
    parser.add_argument("--temperature", type=float, default=0.2)
    args = parser.parse_args()
    assert args.input_file or args.interactive
    main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)