Commit b23692c394 by Abdur Rahman, 2025-04-09 09:34:59 +06:00, committed via GitHub.


@@ -1,178 +1,122 @@
import os
import json
from argparse import ArgumentParser
from typing import List, Optional

import torch
import torch.distributed as dist
from transformers import AutoTokenizer
from safetensors.torch import load_model

from model import Transformer, ModelArgs

def sample(logits: torch.Tensor, temperature: float = 1.0, top_k: Optional[int] = None, top_p: Optional[float] = None) -> torch.Tensor:
    """
    Samples one token id per sequence from `logits` using temperature scaling,
    with optional top-k and top-p (nucleus) filtering. A temperature at or
    below 1e-5 falls back to greedy decoding.
    """
    if temperature <= 1e-5:
        return logits.argmax(dim=-1)
    logits = logits / temperature
    if top_k is not None:
        # Keep only the k largest logits; mask the rest out.
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        logits[logits < v[:, [-1]]] = -float('Inf')
    if top_p is not None and top_p < 1.0:
        # Nucleus filtering: drop tokens outside the smallest set whose
        # cumulative probability exceeds top_p (the top token is always kept).
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
        remove_mask = cum_probs > top_p
        remove_mask[..., 1:] = remove_mask[..., :-1].clone()
        remove_mask[..., 0] = False
        remove_indices = remove_mask.scatter(-1, sorted_indices, remove_mask)
        logits[remove_indices] = -float('Inf')
    # Gumbel-max trick: argmax over logits plus Gumbel noise is equivalent to
    # sampling from softmax(logits).
    gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits) + 1e-10))
    return (logits + gumbel_noise).argmax(dim=-1)
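
# Illustrative call (hypothetical vocabulary size): sample() expects a
# [batch, vocab_size] logits tensor and returns one token id per row, e.g.
#   next_ids = sample(torch.randn(2, 32000, device="cuda"), temperature=0.7, top_k=50, top_p=0.9)
#   # next_ids has shape (2,); with temperature=0.0 the call reduces to plain argmax.
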
@torch.inference_mode()
def generate(
    model: Transformer,
    prompt_tokens: List[List[int]],
    max_new_tokens: int,
    eos_id: int,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
) -> List[List[int]]:
    """
    Generates new tokens for each prompt sequence, returning one list of
    generated ids per prompt, truncated at the first end-of-sequence token.
    """
    model.reset_cache()
    prompt_lens = [len(t) for t in prompt_tokens]
    assert max(prompt_lens) <= model.max_seq_len, f"Prompt length exceeds model maximum sequence length (max_seq_len={model.max_seq_len})"
    total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
    tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda")
    for i, t in enumerate(prompt_tokens):
        tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
    prev_pos = 0
    finished = torch.zeros(len(prompt_tokens), dtype=torch.bool, device="cuda")
    prompt_mask = tokens != -1
    for cur_pos in range(min(prompt_lens), total_len):
        logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
        next_token = sample(logits, temperature, top_k, top_p)
        # Keep prompt tokens untouched; only fill positions past each prompt.
        next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
        tokens[:, cur_pos] = next_token
        finished |= (~prompt_mask[:, cur_pos] & (next_token == eos_id))
        prev_pos = cur_pos
        if finished.all():
            break
    completions = []
    for i, seq in enumerate(tokens.tolist()):
        seq = seq[prompt_lens[i]:prompt_lens[i] + max_new_tokens]
        completions.append(seq[:seq.index(eos_id)] if eos_id in seq else seq)
    return completions
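
# Usage sketch (illustrative token ids; eos_id normally comes from the tokenizer):
#   out = generate(model, [[1, 2, 3], [1, 4]], max_new_tokens=16, eos_id=2, temperature=0.0)
#   # len(out) == 2; each completion holds at most 16 ids and is cut at the first eos_id.
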
def main(
    ckpt_path: str,
    config: str,
    input_file: str = "",
    interactive: bool = True,
    max_new_tokens: int = 100,
    temperature: float = 0.2,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
) -> None:
    """
    Loads the model and runs interactive or batch text generation.
    """
    if not os.path.isdir(ckpt_path):
        raise FileNotFoundError(f"Checkpoint directory missing: {ckpt_path}")
    if not os.path.isfile(config):
        raise FileNotFoundError(f"Config file missing: {config}")
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    rank = int(os.getenv("RANK", "0"))
    local_rank = int(os.getenv("LOCAL_RANK", "0"))
    if world_size > 1:
        dist.init_process_group("nccl", init_method="env://")
    global print
    if rank != 0:
        # Silence printing on non-zero ranks.
        print = lambda *_, **__: None
    torch.cuda.set_device(local_rank)
    torch.set_default_dtype(torch.bfloat16)
    torch.set_num_threads(8)
    torch.manual_seed(965)
    with open(config) as f:
        model_args = ModelArgs(**json.load(f))
    model = Transformer(model_args).to(torch.bfloat16).cuda()
    load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"))
    tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
    if interactive:
        messages = []
        while True:
            prompt = get_input(rank, world_size)
            if prompt == "/exit":
                break
            if prompt == "/clear":
                messages.clear()
                continue
            messages.append({"role": "user", "content": prompt})
            try:
                prompt_tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
            except Exception as e:
                print(f"Tokenization error: {e}")
                continue
            completion_tokens = generate(model, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature, top_k, top_p)
            completion = tokenizer.decode(completion_tokens[0], skip_special_tokens=True)
            print(completion)
            messages.append({"role": "assistant", "content": completion})
    else:
        with open(input_file) as f:
            prompts = [line.strip() for line in f if line.strip()]
        batch_size = model_args.max_batch_size
        completions = []
        # Process prompts in chunks no larger than the model's max batch size.
        for i in range(0, len(prompts), batch_size):
            batch_prompts = prompts[i:i + batch_size]
            batch_tokens = [tokenizer.apply_chat_template([{"role": "user", "content": p}], add_generation_prompt=True) for p in batch_prompts]
            completion_tokens = generate(model, batch_tokens, max_new_tokens, tokenizer.eos_token_id, temperature, top_k, top_p)
            completions.extend(tokenizer.batch_decode(completion_tokens, skip_special_tokens=True))
        for prompt, completion in zip(prompts, completions):
            print(f"Prompt: {prompt}\nCompletion: {completion}\n{'-' * 50}")

    if world_size > 1:
        dist.destroy_process_group()

def get_input(rank: int, world_size: int) -> str:
    # Rank 0 reads from stdin and broadcasts the prompt so every rank stays in sync.
    if world_size == 1 or rank == 0:
        prompt = input(">>> ")
        if world_size > 1:
            dist.broadcast_object_list([prompt], src=0)
        return prompt
    else:
        res = [None]
        dist.broadcast_object_list(res, src=0)
        return res[0]

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--ckpt-path", type=str, required=True)
    parser.add_argument("--config", type=str, required=True)
@@ -180,6 +124,8 @@ if __name__ == "__main__":
parser.add_argument("--interactive", action="store_true") parser.add_argument("--interactive", action="store_true")
parser.add_argument("--max-new-tokens", type=int, default=200) parser.add_argument("--max-new-tokens", type=int, default=200)
parser.add_argument("--temperature", type=float, default=0.2) parser.add_argument("--temperature", type=float, default=0.2)
parser.add_argument("--top-k", type=int, default=None)
parser.add_argument("--top-p", type=float, default=None)
args = parser.parse_args() args = parser.parse_args()
assert args.input_file or args.interactive, "Either input-file or interactive mode must be specified" assert args.input_file or args.interactive, "Either input-file or interactive mode must be specified"
main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature) main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature, args.top_k, args.top_p)
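
# Example launch (illustrative paths; torchrun supplies WORLD_SIZE/RANK/LOCAL_RANK):
#   torchrun --nnodes 1 --nproc-per-node 8 generate.py \
#       --ckpt-path /path/to/converted-checkpoint --config /path/to/config.json \
#       --interactive --temperature 0.7 --top-k 50 --top-p 0.9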