import torch
import torch.nn.functional as F
import logging
from typing import Optional, Tuple, Union

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def top_k_top_p_filtering(logits: torch.Tensor, top_k: int = 0, top_p: float = 1.0) -> torch.Tensor:
    """
    Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    Args:
        logits (torch.Tensor): The logits distribution of shape (1, vocab_size).
        top_k (int): Keep only the top k tokens with highest probability (0 = no filtering).
        top_p (float): Keep the smallest set of tokens whose cumulative probability exceeds top_p (1.0 = no filtering).

    Returns:
        torch.Tensor: Filtered logits, with removed tokens set to -inf.
    """
    if top_k > 0:
        # Never ask torch.topk for more tokens than the vocabulary holds.
        top_k = min(top_k, logits.size(-1))
        values, _ = torch.topk(logits, top_k)
        # Smallest logit that survives top-k filtering, per row.
        min_values = values[:, -1].unsqueeze(-1)
        logits = torch.where(logits < min_values, torch.tensor(float('-inf'), device=logits.device), logits)

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Mark tokens once the cumulative probability exceeds top_p ...
        sorted_indices_to_remove = cumulative_probs > top_p
        # ... then shift the mask right so the first token crossing the threshold is kept.
        sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
        sorted_indices_to_remove[:, 0] = False

        # Note: this indexing assumes a batch size of 1, i.e. logits of shape (1, vocab_size).
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[0, indices_to_remove] = float('-inf')

    return logits
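

# A minimal usage sketch (not part of the original file): sample one next-token id
# from a (1, vocab_size) logits row after top-k / top-p filtering. The helper name
# is hypothetical and exists only for illustration.
def _example_sample_from_logits(logits: torch.Tensor, top_k: int = 50, top_p: float = 0.9) -> int:
    filtered = top_k_top_p_filtering(logits.clone(), top_k=top_k, top_p=top_p)
    probs = F.softmax(filtered, dim=-1)  # tokens masked to -inf get probability 0
    next_token = torch.multinomial(probs, num_samples=1)  # shape (1, 1)
    return int(next_token.item())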


def decode(
    input_ids: torch.Tensor,
    position: int,
    model: torch.nn.Module,
    past_key_values: Optional[Tuple[torch.Tensor]] = None,
    apply_softmax: bool = False,
    top_k: int = 0,
    top_p: float = 1.0,
    device: Union[str, torch.device] = 'cuda' if torch.cuda.is_available() else 'cpu'
) -> torch.Tensor:
    """
    Decodes the next token's logits (or probabilities) from the model.

    Args:
        input_ids (torch.Tensor): Tokenized input sequence of shape (1, seq_len).
        position (int): The current position (token index) in generation.
        model (torch.nn.Module): Transformer model used for decoding.
        past_key_values (Tuple, optional): Cached keys/values for speedup (default: None).
        apply_softmax (bool): Whether to return softmax probabilities instead of raw logits.
        top_k (int): Top-k filtering for logits (0 = disable).
        top_p (float): Top-p (nucleus) filtering (1.0 = disable).
        device (str | torch.device): Device to run inference on.

    Returns:
        torch.Tensor: Logits or probabilities for next-token prediction.
    """
    input_ids = input_ids.to(device)
    if past_key_values:
        # Move any cached key/value tensors to the same device as the inputs.
        past_key_values = tuple(pk.to(device) for pk in past_key_values)

    logger.info(f"🧠 [decode] Running inference at position: {position}")
    logger.debug(f"📥 input_ids shape: {input_ids.shape}")
    logger.debug(f"🔁 past_key_values: {'Provided' if past_key_values else 'None'}")

    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            past_key_values=past_key_values,
            use_cache=True,
        )

    # Keep only the logits for the last position in the sequence.
    logits = outputs.logits[:, -1, :]  # shape: (1, vocab_size)

    logger.debug(f"📤 Raw logits shape: {logits.shape}")

    # Apply top-k / top-p filtering before (optionally) converting to probabilities.
    logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)

    if apply_softmax:
        probs = F.softmax(logits, dim=-1)
        logger.info("✅ Returned softmax probabilities.")
        return probs

    logger.info("✅ Returned raw logits.")
    return logits
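

# A minimal generation sketch (illustration only, not part of the original file).
# It assumes a Hugging Face-style causal LM whose forward pass accepts
# `input_ids`, `past_key_values` and `use_cache` and returns an object with
# `.logits`, as `decode` above already expects. Because `decode` does not return
# the updated cache, this loop simply re-encodes the full prefix at every step.
def _example_generate(model: torch.nn.Module, input_ids: torch.Tensor,
                      max_new_tokens: int = 16, top_k: int = 50, top_p: float = 0.9) -> torch.Tensor:
    generated = input_ids
    for _ in range(max_new_tokens):
        probs = decode(
            generated,
            position=generated.size(1) - 1,
            model=model,
            apply_softmax=True,
            top_k=top_k,
            top_p=top_p,
        )
        next_token = torch.multinomial(probs, num_samples=1)  # shape (1, 1)
        generated = torch.cat([generated, next_token.to(generated.device)], dim=-1)
    return generated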
print("kernel.py loaded")
|
|
print("act_quant defined:", "act_quant" in dir())
|