Mirror of https://github.com/deepseek-ai/DeepSeek-V3.git, synced 2025-02-22 21:58:58 -05:00
Update generate.py

commit daba5c1f78
parent b5d872ead0
@@ -11,21 +11,33 @@ from safetensors.torch import load_model
 from model import Transformer, ModelArgs
 
 
-def sample(logits, temperature: float = 1.0):
+def sample(logits, temperature: float = 1.0, top_k: int = 40, p: float = 0.9):
     """
-    Samples a token from the logits using temperature scaling.
+    Samples a token from the logits using temperature scaling, top-k, or nucleus sampling.
 
     Args:
         logits (torch.Tensor): The logits tensor for token predictions.
         temperature (float, optional): Temperature for scaling logits. Defaults to 1.0.
+        top_k (int, optional): Top-k for sampling. Defaults to 40.
+        p (float, optional): Nucleus sampling threshold. Defaults to 0.9.
 
     Returns:
         torch.Tensor: The sampled token.
     """
     logits = logits / max(temperature, 1e-5)
-    probs = torch.softmax(logits, dim=-1)
-    return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
+    if top_k > 0:
+        top_k_values, top_k_indices = torch.topk(logits, top_k)
+        top_k_probs = torch.softmax(top_k_values, dim=-1)
+        next_token = top_k_indices[torch.multinomial(top_k_probs, 1)].squeeze()
+    elif p < 1.0:
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
+        sorted_indices_to_keep = sorted_indices[cumulative_probs <= p]
+        next_token = sorted_indices_to_keep[torch.multinomial(torch.softmax(sorted_logits, dim=-1), 1)].squeeze()
+    else:
+        probs = torch.softmax(logits, dim=-1)
+        next_token = torch.multinomial(probs, 1).squeeze()
+    return next_token
 
 
 @torch.inference_mode()
 def generate(
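A note on the patched top-k branch: indexing top_k_indices directly with the result of torch.multinomial only lines up when logits is a 1-D vector. If generate() hands sample() a batched (batch_size, vocab_size) tensor, multinomial returns per-row slot indices of shape (batch, 1), and those need to be mapped back to vocabulary ids with torch.gather. A minimal sketch of that variant (sample_top_k is our name for illustration, not part of the commit):

import torch

def sample_top_k(logits: torch.Tensor, top_k: int = 40) -> torch.Tensor:
    # Keep only the k highest-scoring tokens and renormalize over them.
    top_k_values, top_k_indices = torch.topk(logits, top_k, dim=-1)
    top_k_probs = torch.softmax(top_k_values, dim=-1)
    # multinomial returns a slot index into the k kept tokens, per batch row.
    choice = torch.multinomial(top_k_probs, num_samples=1)
    # gather maps each sampled slot back to its vocabulary id.
    return torch.gather(top_k_indices, -1, choice).squeeze(-1)

The same gather idiom works unchanged for 1-D input, where it reduces to the indexing the patch uses.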
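The nucleus (elif) branch has a related pitfall: it draws a slot from the softmax over the full sorted distribution but uses that slot to index sorted_indices_to_keep, which has already been filtered down (and flattened to 1-D by the boolean mask for batched input), so the draw can land outside the kept set. The standard formulation instead masks the tail logits to -inf and renormalizes, which also guarantees the top-1 token survives even when it alone exceeds p. A sketch under the same assumptions as above (sample_top_p is our name):

import torch

def sample_top_p(logits: torch.Tensor, p: float = 0.9) -> torch.Tensor:
    # Sort tokens by probability and accumulate their mass.
    sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)
    sorted_probs = torch.softmax(sorted_logits, dim=-1)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Drop every token whose *preceding* cumulative mass already exceeds p;
    # the top-1 token is always kept since its preceding mass is zero.
    sorted_logits[cumulative - sorted_probs > p] = float("-inf")
    # Renormalize over the surviving nucleus and sample a slot per row.
    choice = torch.multinomial(torch.softmax(sorted_logits, dim=-1), num_samples=1)
    # Map the sampled slot back to a vocabulary id.
    return torch.gather(sorted_indices, -1, choice).squeeze(-1)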
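For reference, the removed return line was not a bug: probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1) is the exponential-race form of the Gumbel-max trick. Dividing each probability by i.i.d. Exponential(1) noise and taking the argmax selects index i with probability probs[i], so it is an exact categorical sample without an explicit multinomial call. A quick Monte Carlo check (ours, not from the repo):

import torch

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.7])

# probs[i] / E_i is maximal exactly when E_i / probs[i] is minimal, and the
# minimum of independent Exponential(rate=probs[i]) clocks fires at index i
# with probability probs[i] / probs.sum().
counts = torch.zeros(3)
for _ in range(100_000):
    noise = torch.empty_like(probs).exponential_(1)
    counts[(probs / noise).argmax()] += 1
print(counts / counts.sum())  # approx tensor([0.1, 0.2, 0.7])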