From daba5c1f78885750c16181cdaa56324f710e7c02 Mon Sep 17 00:00:00 2001
From: Ivan Lloyd Roquero
Date: Fri, 31 Jan 2025 01:19:43 +0800
Subject: [PATCH] Update generate.py

---
 inference/generate.py | 24 ++++++++++++++++++------
 1 file changed, 18 insertions(+), 6 deletions(-)

diff --git a/inference/generate.py b/inference/generate.py
index fbf3ab8..9ae68cb 100644
--- a/inference/generate.py
+++ b/inference/generate.py
@@ -11,21 +11,33 @@ from safetensors.torch import load_model
 from model import Transformer, ModelArgs
 
 
-def sample(logits, temperature: float = 1.0):
+def sample(logits, temperature: float = 1.0, top_k: int = 40, p: float = 0.9):
     """
-    Samples a token from the logits using temperature scaling.
-
+    Samples a token from the logits using temperature scaling, top-k, or nucleus sampling.
+
     Args:
         logits (torch.Tensor): The logits tensor for token predictions.
         temperature (float, optional): Temperature for scaling logits. Defaults to 1.0.
+        top_k (int, optional): Top-k for sampling. Defaults to 40.
+        p (float, optional): Nucleus sampling threshold. Defaults to 0.9.
 
     Returns:
         torch.Tensor: The sampled token.
     """
     logits = logits / max(temperature, 1e-5)
-    probs = torch.softmax(logits, dim=-1)
-    return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
-
+    if top_k > 0:
+        top_k_values, top_k_indices = torch.topk(logits, top_k)
+        top_k_probs = torch.softmax(top_k_values, dim=-1)
+        next_token = top_k_indices[torch.multinomial(top_k_probs, 1)].squeeze()
+    elif p < 1.0:
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
+        sorted_indices_to_keep = sorted_indices[cumulative_probs <= p]
+        next_token = sorted_indices_to_keep[torch.multinomial(torch.softmax(sorted_logits, dim=-1), 1)].squeeze()
+    else:
+        probs = torch.softmax(logits, dim=-1)
+        next_token = torch.multinomial(probs, 1).squeeze()
+    return next_token
 
 @torch.inference_mode()
 def generate(
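
The new branches appear to assume 1-D logits: indexing top_k_indices with the result of torch.multinomial(top_k_probs, 1) only picks out a single token when logits has one dimension, and the nucleus branch samples over the full sorted distribution but then indexes into the truncated sorted_indices_to_keep, which can go out of range (or be empty when the top probability already exceeds p). If generate() passes (batch, vocab) logits, as the original argmax(dim=-1) return suggests, a mask-and-renormalize formulation keeps the per-row behavior intact. The sketch below is my own rewrite under those assumptions, not code from this patch; the names mirror the patch, but the filtering details are illustrative only.

    # Sketch only: not part of the patch. Assumes logits may be 2-D (batch, vocab).
    import torch


    def sample(logits: torch.Tensor, temperature: float = 1.0,
               top_k: int = 40, p: float = 0.9) -> torch.Tensor:
        """Temperature + top-k + nucleus sampling over the last dimension."""
        logits = logits / max(temperature, 1e-5)

        if top_k > 0:
            # Mask everything below the k-th largest logit in each row.
            kth_largest = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
            logits = logits.masked_fill(logits < kth_largest, float("-inf"))

        if p < 1.0:
            # Keep the smallest prefix of sorted tokens whose mass reaches p;
            # the highest-probability token is always kept.
            sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)
            sorted_probs = torch.softmax(sorted_logits, dim=-1)
            mass_before = sorted_probs.cumsum(dim=-1) - sorted_probs
            sorted_logits = sorted_logits.masked_fill(mass_before > p, float("-inf"))
            # Undo the sort so the masked logits land back in their original positions.
            logits = torch.full_like(logits, float("-inf")).scatter(
                dim=-1, index=sorted_indices, src=sorted_logits)

        probs = torch.softmax(logits, dim=-1)
        # One token id per row; squeeze matches the (batch,) shape the
        # original argmax(dim=-1) returned.
        return torch.multinomial(probs, num_samples=1).squeeze(-1)

Unlike the elif chain in the patch, the two filters compose here; passing top_k=0 or p=1.0 disables the corresponding step.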