# DeepSeek-V3/inference/cpu_kernel.py

import torch
import torch.nn.functional as F
from typing import Tuple, Optional


def act_quant_cpu(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    CPU-compatible version of act_quant. Quantizes the input tensor using block-wise quantization.

    Args:
        x (torch.Tensor): The input tensor to be quantized.
        block_size (int, optional): The size of the blocks for quantization. Default is 128.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Quantized tensor and scaling factors.
    """
    assert x.is_contiguous(), 'Input tensor must be contiguous'
    orig_shape = x.shape
    # Handle non-divisible cases gracefully: pad the last dimension so it
    # splits evenly into blocks.
    if x.size(-1) % block_size != 0:
        pad_size = block_size - (x.size(-1) % block_size)
        x = F.pad(x, (0, pad_size))
    padded_shape = x.shape
    # Reshape to (num_blocks, block_size) for efficient block-wise processing
    x_reshaped = x.reshape(-1, block_size)
    # Scaling factor per block: max absolute value mapped to the float8 e4m3 maximum (448)
    s = torch.max(torch.abs(x_reshaped), dim=1, keepdim=True)[0] / 448.0
    # Avoid division by zero for all-zero blocks
    s = torch.clamp(s, min=1e-10)
    # Quantize by dividing by the scaling factors
    y = x_reshaped / s
    # Use native float8 if available, otherwise simulate it with int8 + scaling
    if hasattr(torch, "float8_e4m3fn"):
        y = y.to(torch.float8_e4m3fn)
    else:
        # Simulate float8 with int8 quantization
        y = torch.clamp(y, -448.0, 448.0)
        y = (y / 448.0 * 127).round().to(torch.int8)
    # Reshape back and drop any padding so the output matches the input shape
    y = y.reshape(padded_shape)[..., :orig_shape[-1]]
    # One scaling factor per block of the last dimension: (*orig_shape[:-1], num_blocks)
    s = s.reshape(*orig_shape[:-1], -1)
    return y, s
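
# Usage sketch (illustrative, not part of the original module). With the
# default block_size of 128, a row of 256 activations yields two scaling
# factors; the shapes below are toy values, not DeepSeek-V3 sizes:
#
#   x = torch.randn(4, 256)
#   y, s = act_quant_cpu(x)   # y: block-quantized values, s: per-block scales
#   # y.shape == (4, 256); s.shape == (4, 2)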


def weight_dequant_cpu(weight: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """
    CPU-compatible version of weight_dequant. Dequantizes the weight tensor.

    Args:
        weight (torch.Tensor): Quantized weight tensor.
        scale (torch.Tensor): Scaling factors.

    Returns:
        torch.Tensor: Dequantized weight tensor.
    """
    # Handle the different quantization formats
    if weight.dtype == torch.int8:
        # Undo the int8-simulated float8 quantization
        weight_float = weight.to(torch.float32) / 127.0 * 448.0
    elif hasattr(torch, "float8_e4m3fn") and weight.dtype == torch.float8_e4m3fn:
        # Native float8 support
        weight_float = weight.to(torch.float32)
    else:
        # Already in a floating-point format
        weight_float = weight.to(torch.float32)
    # Reshape the scale tensor so it broadcasts over the weight blocks
    if weight.dim() == 2:
        # For linear layers: one scale per (block_size x block_size) tile
        out_features, in_features = weight.shape
        block_size = 128  # Same block size as the original GPU kernels
        scale_out = (out_features + block_size - 1) // block_size
        scale_in = (in_features + block_size - 1) // block_size
        if scale.numel() == scale_out * scale_in:
            # Reshape to match the weight block grid
            scale_reshaped = scale.reshape(scale_out, scale_in)
            # Map every weight position to its block index
            out_blocks = torch.arange(out_features).reshape(-1, 1) // block_size
            in_blocks = torch.arange(in_features).reshape(1, -1) // block_size
            # Limit to the actual block-grid dimensions
            out_blocks = torch.clamp(out_blocks, max=scale_out - 1)
            in_blocks = torch.clamp(in_blocks, max=scale_in - 1)
            # Look up the scale for each position and apply it
            scale_broadcast = scale_reshaped[out_blocks, in_blocks]
            return weight_float * scale_broadcast
    # Fallback for other tensor dimensions
    return weight_float * scale
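
# Usage sketch (illustrative, with made-up shapes): dequantize a 2-D weight
# whose scales are stored one per 128x128 tile, i.e. a grid of
# ceil(256/128) x ceil(512/128) = 2 x 4 factors here:
#
#   w_q = torch.randint(-127, 128, (256, 512), dtype=torch.int8)
#   w_s = torch.rand(2, 4)               # one scale per 128x128 tile
#   w = weight_dequant_cpu(w_q, w_s)     # float32, shape (256, 512)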


def fp8_gemm_cpu(x: torch.Tensor, x_scale: torch.Tensor, weight: torch.Tensor, weight_scale: torch.Tensor) -> torch.Tensor:
    """
    CPU-compatible version of fp8_gemm. Performs matrix multiplication with quantized tensors.

    Args:
        x (torch.Tensor): Input activations (quantized).
        x_scale (torch.Tensor): Scaling factors for input activations.
        weight (torch.Tensor): Weights (quantized).
        weight_scale (torch.Tensor): Scaling factors for weights.

    Returns:
        torch.Tensor: Result of matrix multiplication.
    """
    # Dequantize the input activations
    if x.dtype == torch.int8:
        # Undo the int8-simulated float8 quantization
        x_float = x.to(torch.float32) / 127.0 * 448.0
    elif hasattr(torch, "float8_e4m3fn") and x.dtype == torch.float8_e4m3fn:
        x_float = x.to(torch.float32)
    else:
        x_float = x
    # Apply the per-block input scales (one scale per block of the last dimension)
    if x_scale is not None:
        scale = x_scale
        if scale.dim() == x_float.dim() - 1:
            # A trailing block axis of size 1 was squeezed away; restore it
            scale = scale.unsqueeze(-1)
        # Expand per-block scales to per-element scales along the last dimension
        block_size = 128  # same block size as act_quant_cpu / weight_dequant_cpu
        scale = torch.repeat_interleave(scale, block_size, dim=-1)[..., :x_float.size(-1)]
        x_float = x_float * scale
    # Dequantize the weights
    weight_fp32 = weight_dequant_cpu(weight, weight_scale)
    # Perform the matrix multiplication
    return F.linear(x_float, weight_fp32)
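
# Usage sketch (illustrative): quantize an activation on the fly and multiply
# it by a pre-quantized weight; shapes and scale values are assumptions for
# the example only:
#
#   x = torch.randn(4, 512)
#   x_q, x_s = act_quant_cpu(x)              # x_s.shape == (4, 4)
#   w_q = torch.randint(-127, 128, (256, 512), dtype=torch.int8)
#   w_s = torch.rand(2, 4)                   # one scale per 128x128 tile
#   out = fp8_gemm_cpu(x_q, x_s, w_q, w_s)   # float32, shape (4, 256)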


# MPS (Metal Performance Shaders) optimized versions for Apple Silicon
def setup_mps_kernels():
    """
    Set up optimized MPS kernels if running on Apple Silicon.
    """
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        print("Setting up MPS optimized kernels for Apple Silicon")
        # MPS already optimizes most operations automatically;
        # additional optimizations could be added in the future.
    else:
        print("MPS not available, using CPU kernels")


# Provide a unified interface that selects the appropriate implementation
def get_optimized_kernels(device: str = "cpu") -> dict:
    """
    Returns optimized kernel functions based on the device.

    Args:
        device (str): The device to optimize for ("cpu", "mps", or "cuda").

    Returns:
        dict: Dictionary of optimized kernel functions.
    """
    if device == "mps" and hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        setup_mps_kernels()
        # For MPS, we use the CPU implementations, which PyTorch's MPS backend
        # will optimize automatically.
        return {
            "act_quant": act_quant_cpu,
            "weight_dequant": weight_dequant_cpu,
            "fp8_gemm": fp8_gemm_cpu,
        }
    elif device == "cuda" and torch.cuda.is_available():
        # For CUDA, use the original GPU implementations
        from kernel import act_quant, weight_dequant, fp8_gemm
        return {
            "act_quant": act_quant,
            "weight_dequant": weight_dequant,
            "fp8_gemm": fp8_gemm,
        }
    else:
        # Default to the CPU implementations
        return {
            "act_quant": act_quant_cpu,
            "weight_dequant": weight_dequant_cpu,
            "fp8_gemm": fp8_gemm_cpu,
        }
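
# Usage sketch (illustrative): pick kernels for the current machine and call
# them through the returned dict; "cpu" is the safe default everywhere:
#
#   kernels = get_optimized_kernels("cpu")
#   y, s = kernels["act_quant"](torch.randn(4, 256))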