DeepSeek-V3/inference/kernel.py

import torch
import torch.nn.functional as F
import logging
from typing import Optional, Tuple, Union
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def top_k_top_p_filtering(logits: torch.Tensor, top_k: int = 0, top_p: float = 1.0) -> torch.Tensor:
    """
    Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    Args:
        logits (torch.Tensor): Logits of shape (1, vocab_size).
        top_k (int): Keep only the top k tokens with the highest probability (0 = no filtering).
        top_p (float): Keep the smallest set of tokens whose cumulative probability exceeds top_p
            (1.0 = no filtering).

    Returns:
        torch.Tensor: Filtered logits with removed positions set to -inf.
    """
    if top_k > 0:
        # Mask out every logit smaller than the k-th largest value.
        values, _ = torch.topk(logits, top_k)
        min_values = values[:, -1].unsqueeze(-1)
        logits = torch.where(logits < min_values, torch.full_like(logits, float('-inf')), logits)
    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Drop tokens once the cumulative probability exceeds top_p, shifting the mask
        # right by one so the first token above the threshold is always kept.
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
        sorted_indices_to_remove[:, 0] = 0
        # Note: this scatter back to the original ordering assumes batch size 1.
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[0, indices_to_remove] = float('-inf')
    return logits
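# Illustrative sketch (assumed example, not part of the original module's API): how the
# filter is typically combined with multinomial sampling on a toy (1, vocab_size) tensor.
def _example_filter_and_sample() -> torch.Tensor:
    example_logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
    filtered = top_k_top_p_filtering(example_logits, top_k=2, top_p=0.9)
    probs = F.softmax(filtered, dim=-1)             # masked positions get probability 0
    return torch.multinomial(probs, num_samples=1)  # shape: (1, 1)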
def decode(
    input_ids: torch.Tensor,
    position: int,
    model: torch.nn.Module,
    past_key_values: Optional[Tuple[torch.Tensor, ...]] = None,
    apply_softmax: bool = False,
    top_k: int = 0,
    top_p: float = 1.0,
    device: Union[str, torch.device] = 'cuda' if torch.cuda.is_available() else 'cpu',
) -> torch.Tensor:
    """
    Decode the next token's logits (or probabilities) from the model.

    Args:
        input_ids (torch.Tensor): Tokenized input sequence of shape (1, seq_len).
        position (int): The current position (token index) in generation.
        model (torch.nn.Module): Transformer model used for decoding.
        past_key_values (Tuple, optional): Cached keys/values for speedup (default: None).
        apply_softmax (bool): Whether to return softmax probabilities instead of raw logits.
        top_k (int): Top-k filtering for logits (0 = disable).
        top_p (float): Top-p (nucleus) filtering (1.0 = disable).
        device (str | torch.device): Device to run inference on.

    Returns:
        torch.Tensor: Logits or probabilities for next-token prediction, shape (1, vocab_size).
    """
    input_ids = input_ids.to(device)
    if past_key_values:
        past_key_values = tuple(pk.to(device) for pk in past_key_values)

    logger.info(f"🧠 [decode] Running inference at position: {position}")
    logger.debug(f"📥 input_ids shape: {input_ids.shape}")
    logger.debug(f"🔁 past_key_values: {'Provided' if past_key_values else 'None'}")

    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            past_key_values=past_key_values,
            use_cache=True,
        )

    logits = outputs.logits[:, -1, :]  # shape: (1, vocab_size)
    logger.debug(f"📤 Raw logits shape: {logits.shape}")

    # Apply top-k / top-p filtering before (optionally) converting to probabilities.
    logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)

    if apply_softmax:
        probs = F.softmax(logits, dim=-1)
        logger.info("✅ Returned softmax probabilities.")
        return probs

    logger.info("✅ Returned raw logits.")
    return logits
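# Illustrative sketch (assumed usage; the model, tokenizer, and sampling settings are not
# defined in this file): a minimal sampling loop built on decode() for an HF-style causal LM
# whose forward pass returns an object with a .logits attribute.
def _example_generation_loop(model: torch.nn.Module,
                             input_ids: torch.Tensor,
                             max_new_tokens: int = 16) -> torch.Tensor:
    for position in range(max_new_tokens):
        # decode() does not return the updated cache, so this sketch re-feeds the full
        # sequence every step instead of threading past_key_values through.
        probs = decode(input_ids, position, model,
                       apply_softmax=True, top_k=50, top_p=0.95)
        next_token = torch.multinomial(probs, num_samples=1)  # shape: (1, 1)
        input_ids = torch.cat([input_ids.to(next_token.device), next_token], dim=-1)
    return input_ids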
print("kernel.py loaded")
print("act_quant defined:", "act_quant" in dir())