Update generate.py

Ivan Lloyd Roquero 2025-01-31 01:19:43 +08:00 committed by GitHub
parent b5d872ead0
commit daba5c1f78


@@ -11,21 +11,33 @@ from safetensors.torch import load_model
 from model import Transformer, ModelArgs
-def sample(logits, temperature: float = 1.0):
+def sample(logits, temperature: float = 1.0, top_k: int = 40, p: float = 0.9):
"""
Samples a token from the logits using temperature scaling.
Samples a token from the logits using temperature scaling, top-k, or nucleus sampling.
Args:
logits (torch.Tensor): The logits tensor for token predictions.
temperature (float, optional): Temperature for scaling logits. Defaults to 1.0.
top_k (int, optional): Top-k for sampling. Defaults to 40.
p (float, optional): Nucleus sampling threshold. Defaults to 0.9.
Returns:
torch.Tensor: The sampled token.
"""
     logits = logits / max(temperature, 1e-5)
-    probs = torch.softmax(logits, dim=-1)
-    return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
+    if top_k > 0:
+        top_k_values, top_k_indices = torch.topk(logits, top_k, dim=-1)
+        top_k_probs = torch.softmax(top_k_values, dim=-1)
+        # map the position sampled within the top-k set back to a vocabulary index
+        next_token = torch.gather(top_k_indices, -1, torch.multinomial(top_k_probs, 1)).squeeze(-1)
+    elif p < 1.0:
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
+        sorted_probs = torch.softmax(sorted_logits, dim=-1)
+        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
+        # zero out tokens whose preceding cumulative probability exceeds p, so the
+        # most likely token is always kept; multinomial renormalizes the weights
+        sorted_probs[cumulative_probs - sorted_probs > p] = 0.0
+        next_token = torch.gather(sorted_indices, -1, torch.multinomial(sorted_probs, 1)).squeeze(-1)
+    else:
+        probs = torch.softmax(logits, dim=-1)
+        next_token = torch.multinomial(probs, 1).squeeze(-1)
+    return next_token
 @torch.inference_mode()
 def generate(
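
For reference, a minimal smoke-test sketch of the revised sample; the vocabulary size, dummy logits, and argument combinations below are illustrative assumptions, not part of this commit:

import torch

# dummy logits: batch of 2 sequences over a hypothetical 50-token vocabulary
logits = torch.randn(2, 50)

temp_only = sample(logits, temperature=0.7, top_k=0, p=1.0)  # plain temperature sampling (else branch)
top_k_tok = sample(logits, temperature=0.7, top_k=40)        # top-k branch (the default path)
nucleus   = sample(logits, temperature=0.7, top_k=0, p=0.9)  # nucleus branch
print(temp_only.shape, top_k_tok.shape, nucleus.shape)       # torch.Size([2]) for each

Each call returns one token id per batch row. Note that with the default top_k=40 the top-k branch always runs and the nucleus threshold p is never consulted; p only takes effect when top_k is set to 0.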