Mirror of https://github.com/deepseek-ai/DeepSeek-V3.git, synced 2025-04-20 02:28:57 -04:00.
Improved docstrings for better understanding of the functions. Added specific error messages for input validation. Kept the structure of the code while making it easier to read and follow. Ensured that all exceptions provide meaningful messages to the user.

from typing import Tuple

import torch
import triton
import triton.language as tl
from triton import Config

@triton.jit
def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
    """
    Quantizes the input tensor `x_ptr` block-wise, storing the quantized values in `y_ptr` and the per-block scaling factors in `s_ptr`.

    Args:
        x_ptr (triton.Pointer): Pointer to the input tensor.
        y_ptr (triton.Pointer): Pointer to the output tensor for quantized values.
        s_ptr (triton.Pointer): Pointer to the output tensor for scaling factors.
        BLOCK_SIZE (tl.constexpr): The number of elements each program processes.

    Returns:
        None
    """
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    x = tl.load(x_ptr + offs).to(tl.float32)

    # Scale each block by its absolute maximum; 448.0 is the largest finite
    # value representable in float8_e4m3fn.
    s = tl.max(tl.abs(x)) / 448.0
    y = x / s
    y = y.to(y_ptr.dtype.element_ty)

    tl.store(y_ptr + offs, y)
    tl.store(s_ptr + pid, s)


def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantizes the input tensor `x` using block-wise quantization.

    Args:
        x (torch.Tensor): Input tensor to be quantized. Must be contiguous, and its last dimension must be divisible by `block_size`.
        block_size (int, optional): The size of the blocks used for quantization. Default is 128.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
            - The quantized tensor with dtype `torch.float8_e4m3fn`.
            - A tensor of scaling factors with dtype `torch.float32`.

    Raises:
        ValueError: If the input tensor is not contiguous or its last dimension is not divisible by `block_size`.
    """
    if not x.is_contiguous():
        raise ValueError("Input tensor must be contiguous.")
    if x.size(-1) % block_size != 0:
        raise ValueError("Last dimension of input tensor must be divisible by block_size.")

    y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
    s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)
    grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), )
    act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size)
    return y, s
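

# A minimal usage sketch for act_quant (illustrative shapes; assumes a CUDA
# device plus PyTorch/Triton builds with float8_e4m3fn support). Wrapped in a
# function so importing this module stays side-effect free.
def _demo_act_quant():
    x = torch.randn(4, 1024, device='cuda', dtype=torch.bfloat16)
    x_q, x_s = act_quant(x)
    # x_q keeps x's shape in float8_e4m3fn; x_s holds one float32 scale per
    # 128-element block of the last dimension, i.e. shape (4, 8) here.
    assert x_q.shape == x.shape and x_q.dtype == torch.float8_e4m3fn
    assert x_s.shape == (4, 1024 // 128)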


@triton.jit
def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
    """
    Dequantizes weights using the provided scaling factors and stores the result.

    Args:
        x_ptr (triton.Pointer): Pointer to the quantized weights.
        s_ptr (triton.Pointer): Pointer to the scaling factors.
        y_ptr (triton.Pointer): Pointer to the output buffer for dequantized weights.
        M (int): Number of rows in the weight matrix.
        N (int): Number of columns in the weight matrix.
        BLOCK_SIZE (tl.constexpr): Size of the block for tiling.

    Returns:
        None
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    n = tl.cdiv(N, BLOCK_SIZE)
    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs = offs_m[:, None] * N + offs_n[None, :]

    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
    # One scaling factor per (BLOCK_SIZE x BLOCK_SIZE) tile.
    s = tl.load(s_ptr + pid_m * n + pid_n)

    y = x * s
    tl.store(y_ptr + offs, y, mask=mask)


def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    """
    Dequantizes the given weight tensor using the provided scale tensor.

    Args:
        x (torch.Tensor): The quantized weight tensor of shape (M, N).
        s (torch.Tensor): The scale tensor, one scale per (block_size x block_size) tile of `x`, of shape (ceil(M / block_size), ceil(N / block_size)).
        block_size (int, optional): The block size to use for dequantization. Defaults to 128.

    Returns:
        torch.Tensor: The dequantized weight tensor of the same shape as `x`.

    Raises:
        ValueError: If `x` or `s` is not contiguous, or if either is not 2-dimensional.
    """
    if not x.is_contiguous() or not s.is_contiguous():
        raise ValueError("Both x and s must be contiguous.")
    if x.dim() != 2 or s.dim() != 2:
        raise ValueError("Both x and s must be 2-dimensional tensors.")

    M, N = x.size()
    y = torch.empty_like(x, dtype=torch.get_default_dtype())
    # The kernel takes a single BLOCK_SIZE constexpr, so the grid must key on
    # 'BLOCK_SIZE' for both dimensions.
    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
    weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
    return y
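

# A small usage sketch for weight_dequant (illustrative values; assumes a CUDA
# device with FP8 support): dequantize a 256x256 FP8 weight that carries one
# scale per 128x128 tile, i.e. a 2x2 scale tensor.
def _demo_weight_dequant():
    w_q = torch.randn(256, 256, device='cuda').to(torch.float8_e4m3fn)
    w_s = torch.rand(2, 2, device='cuda', dtype=torch.float32)
    w = weight_dequant(w_q, w_s)
    # The output has the same shape as w_q, in the default dtype.
    assert w.shape == w_q.shape and w.dtype == torch.get_default_dtype()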


# Configurations for autotuning the FP8 GEMM kernel. BLOCK_SIZE_K is fixed at
# 128 to match the quantization block size.
fp8_gemm_configs = [
    Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128}, num_stages=num_stages, num_warps=8)
    for block_m in [16, 32, 64] for block_n in [32, 64, 128] for num_stages in [3, 4, 5, 6]
]


@triton.autotune(configs=fp8_gemm_configs, key=['N', 'K'])
@triton.jit
def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
                    a_s_ptr, b_s_ptr,
                    M, N: tl.constexpr, K: tl.constexpr,
                    BLOCK_SIZE_M: tl.constexpr,
                    BLOCK_SIZE_N: tl.constexpr,
                    BLOCK_SIZE_K: tl.constexpr):
    """
    Performs a matrix multiplication operation on FP8 matrices with scaling factors.

    Args:
        a_ptr (triton.Pointer): Pointer to the first input matrix A.
        b_ptr (triton.Pointer): Pointer to the second input matrix B.
        c_ptr (triton.Pointer): Pointer to the output matrix C.
        a_s_ptr (triton.Pointer): Pointer to the scaling factors for matrix A.
        b_s_ptr (triton.Pointer): Pointer to the scaling factors for matrix B.
        M (int): Number of rows in matrices A and C.
        N (tl.constexpr): Number of columns in matrices B and C.
        K (tl.constexpr): Number of columns in matrix A and rows in matrix B.
        BLOCK_SIZE_M (tl.constexpr): Block size for the M dimension.
        BLOCK_SIZE_N (tl.constexpr): Block size for the N dimension.
        BLOCK_SIZE_K (tl.constexpr): Block size for the K dimension.

    Returns:
        None
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    k = tl.cdiv(K, BLOCK_SIZE_K)
    offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
    b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None]
    a_s_ptrs = a_s_ptr + offs_m * k
    b_s_ptrs = b_s_ptr + (offs_n // BLOCK_SIZE_K) * k

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for i in range(k):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0)
        a_s = tl.load(a_s_ptrs)
        b_s = tl.load(b_s_ptrs)

        # Update accumulator with scaled dot product
        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]

        # Move pointers for the next iteration
        a_ptrs += BLOCK_SIZE_K
        b_ptrs += BLOCK_SIZE_K
        a_s_ptrs += 1
        b_s_ptrs += 1

    c = accumulator.to(c_ptr.dtype.element_ty)
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(c_ptrs, c, mask=mask)


def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor) -> torch.Tensor:
    """
    Performs matrix multiplication using FP8 precision.

    Args:
        a (torch.Tensor): First input matrix; must be contiguous.
        a_s (torch.Tensor): Scaling factors for the first input matrix; must be contiguous.
        b (torch.Tensor): Second input matrix; must be contiguous.
        b_s (torch.Tensor): Scaling factors for the second input matrix; must be contiguous.

    Returns:
        torch.Tensor: The result of the matrix multiplication.

    Raises:
        ValueError: If any input tensor is not contiguous.
    """
    if not a.is_contiguous() or not b.is_contiguous():
        raise ValueError("Matrices a and b must be contiguous.")
    if not a_s.is_contiguous() or not b_s.is_contiguous():
        raise ValueError("Scaling factors a_s and b_s must be contiguous.")

    K = a.size(-1)
    M = a.numel() // K
    N = b.size(0)

    c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N']))
    fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)
    return c
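

# An end-to-end sketch of the calling convention, inferred from the kernel's
# scale indexing (a_s laid out per (row, K-block); b_s per (N-block, K-block)).
# Illustrative only: it assumes a GPU and Triton build where tl.dot accepts
# float8_e4m3fn inputs. Activations go through act_quant; the 128x128
# block-wise weight quantization below is plain PyTorch written for this demo,
# not the reference quantization pipeline.
def _demo_fp8_gemm():
    M, K, N = 64, 256, 512
    block = 128
    a = torch.randn(M, K, device='cuda', dtype=torch.bfloat16)
    b = torch.randn(N, K, device='cuda', dtype=torch.bfloat16)

    # Per-block (1 x 128) activation quantization via act_quant.
    a_q, a_s = act_quant(a, block_size=block)

    # Block-wise (128 x 128) weight quantization; 448.0 is the float8_e4m3fn max.
    b_blocks = b.view(N // block, block, K // block, block)
    b_s = b_blocks.abs().amax(dim=(1, 3)).float() / 448.0  # shape (N/128, K/128)
    b_q = (b_blocks / b_s[:, None, :, None]).view(N, K).to(torch.float8_e4m3fn)

    c = fp8_gemm(a_q, a_s, b_q, b_s)
    assert c.shape == (M, N)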