# Mirror of https://github.com/deepseek-ai/DeepSeek-V3.git (synced 2025-04-19)
from typing import Tuple

import torch
import triton
import triton.language as tl
from triton import Config


@triton.jit
def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
    """
    Quantizes the input tensor `x_ptr` block-wise, storing the quantized values in `y_ptr`
    and one scaling factor per block in `s_ptr`.

    Args:
        x_ptr (triton.Pointer): Pointer to the input tensor.
        y_ptr (triton.Pointer): Pointer to the output tensor where quantized values will be stored.
        s_ptr (triton.Pointer): Pointer to the output tensor where scaling factors will be stored.
        BLOCK_SIZE (tl.constexpr): The number of elements processed by each program instance.

    Returns:
        None
    """
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    x = tl.load(x_ptr + offs).to(tl.float32)
    # 448.0 is the largest finite value representable in float8_e4m3fn, so the
    # rescaled block fits the FP8 range.
    s = tl.max(tl.abs(x)) / 448.0
    y = x / s
    y = y.to(y_ptr.dtype.element_ty)
    tl.store(y_ptr + offs, y)
    tl.store(s_ptr + pid, s)


def act_quant(
    x: torch.Tensor, block_size: int = 128
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantizes the input tensor `x` using block-wise quantization.

    Args:
        x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
        block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
            - The quantized tensor with dtype `torch.float8_e4m3fn`.
            - A tensor of scaling factors with dtype `torch.float32`.

    Raises:
        ValueError: If the tensor is not contiguous or the last dimension is not divisible by `block_size`.
    """
    if not x.is_contiguous():
        raise ValueError("Input tensor must be contiguous in memory.")
    if x.size(-1) % block_size != 0:
        raise ValueError(
            f"The last dimension of the tensor must be divisible by {block_size}."
        )

    y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
    s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)
    grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),)
    act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size)
    return y, s
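

# Illustrative usage sketch: quantize a bfloat16 activation whose last dimension
# is a multiple of the 128-element block size, then reconstruct it from the
# per-block scales. Assumes a CUDA device with Triton available; the tensors and
# the `_example_act_quant` name are hypothetical, not part of the kernels above.
def _example_act_quant() -> None:
    x = torch.randn(4, 1024, dtype=torch.bfloat16, device="cuda")
    y, s = act_quant(x, block_size=128)
    # y: float8_e4m3fn tensor with the same shape as x
    # s: float32 scales of shape (4, 1024 // 128) == (4, 8), one per block
    x_rec = (y.to(torch.float32).view(4, -1, 128) * s.unsqueeze(-1)).view_as(x)
    print(y.dtype, s.shape, (x_rec - x.float()).abs().max().item())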


@triton.jit
def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
    """
    Dequantizes weights using the provided scaling factors and stores the result.

    Args:
        x_ptr (tl.pointer): Pointer to the quantized weights.
        s_ptr (tl.pointer): Pointer to the scaling factors, one per BLOCK_SIZE x BLOCK_SIZE tile.
        y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights.
        M (int): Number of rows in the weight matrix.
        N (int): Number of columns in the weight matrix.
        BLOCK_SIZE (tl.constexpr): Size of the block for tiling.

    Returns:
        None
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    n = tl.cdiv(N, BLOCK_SIZE)  # number of scale columns
    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs = offs_m[:, None] * N + offs_n[None, :]
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
    # Each program rescales one BLOCK_SIZE x BLOCK_SIZE tile by its single scale.
    s = tl.load(s_ptr + pid_m * n + pid_n)
    y = x * s
    tl.store(y_ptr + offs, y, mask=mask)


def weight_dequant(
    x: torch.Tensor, s: torch.Tensor, block_size: int = 128
) -> torch.Tensor:
    """
    Dequantizes the given weight tensor using the provided scale tensor.

    Args:
        x (torch.Tensor): The quantized weight tensor of shape (M, N).
        s (torch.Tensor): The scale tensor of shape (ceil(M / block_size), ceil(N / block_size)),
            holding one scaling factor per block_size x block_size tile of `x`.
        block_size (int, optional): The block size to use for dequantization. Defaults to 128.

    Returns:
        torch.Tensor: The dequantized weight tensor of the same shape as `x`.

    Raises:
        ValueError: If `x` or `s` are not contiguous or if their dimensions are not 2.
    """
    if not x.is_contiguous():
        raise ValueError("Input tensor `x` must be contiguous.")
    if not s.is_contiguous():
        raise ValueError("Input tensor `s` must be contiguous.")
    if x.dim() != 2 or s.dim() != 2:
        raise ValueError("Both `x` and `s` must be 2D tensors.")

    M, N = x.size()
    y = torch.empty_like(x, dtype=torch.get_default_dtype())
    # The kernel exposes a single `BLOCK_SIZE` meta-parameter, used for both grid axes.
    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_SIZE"]),
        triton.cdiv(N, meta["BLOCK_SIZE"]),
    )
    weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
    return y
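

# Illustrative usage sketch: recover a higher-precision copy of an FP8 weight
# from its per-128x128-block scales. Assumes a CUDA device; the random tensors
# and the `_example_weight_dequant` name are hypothetical.
def _example_weight_dequant() -> None:
    M, N, bs = 512, 1024, 128
    w_fp8 = torch.randn(M, N, device="cuda").to(torch.float8_e4m3fn)
    scales = torch.rand(M // bs, N // bs, dtype=torch.float32, device="cuda")
    w = weight_dequant(w_fp8, scales, block_size=bs)
    # w has shape (M, N) and torch.get_default_dtype() as its dtype.
    print(w.shape, w.dtype)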


fp8_gemm_configs = [
    Config(
        {"BLOCK_SIZE_M": block_m, "BLOCK_SIZE_N": block_n, "BLOCK_SIZE_K": 128},
        num_stages=num_stages,
        num_warps=8,
    )
    for block_m in [16, 32, 64]
    for block_n in [32, 64, 128]
    for num_stages in [3, 4, 5, 6]
]
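

# BLOCK_SIZE_K is pinned to 128 in every autotune config above so that one K-step
# of the main loop below consumes exactly one activation scale from `a_s` and one
# weight-scale column from `b_s`, matching the 128-wide quantization blocks.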
@triton.autotune(configs=fp8_gemm_configs, key=["N", "K"])
@triton.jit
def fp8_gemm_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    a_s_ptr,
    b_s_ptr,
    M,
    N: tl.constexpr,
    K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    """
    Performs a matrix multiplication operation on FP8 matrices with scaling factors.

    Args:
        a_ptr (tl.tensor): Pointer to the first input matrix A of shape (M, K).
        b_ptr (tl.tensor): Pointer to the second input matrix B, stored row-major with
            shape (N, K) (i.e. the transposed operand, as in `nn.Linear.weight`).
        c_ptr (tl.tensor): Pointer to the output matrix C of shape (M, N).
        a_s_ptr (tl.tensor): Pointer to the per-block scaling factors for matrix A.
        b_s_ptr (tl.tensor): Pointer to the per-block scaling factors for matrix B.
        M (int): Number of rows in matrices A and C.
        N (tl.constexpr): Number of columns in matrix C (rows of B as stored).
        K (tl.constexpr): Shared inner dimension (columns of A and of B as stored).
        BLOCK_SIZE_M (tl.constexpr): Block size for the M dimension.
        BLOCK_SIZE_N (tl.constexpr): Block size for the N dimension.
        BLOCK_SIZE_K (tl.constexpr): Block size for the K dimension.

    Returns:
        None
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    k = tl.cdiv(K, BLOCK_SIZE_K)  # number of K-blocks (and of scale columns)
    offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
    b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None]
    a_s_ptrs = a_s_ptr + offs_m * k
    b_s_ptrs = b_s_ptr + (offs_n // BLOCK_SIZE_K) * k

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for i in range(k):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0)
        a_s = tl.load(a_s_ptrs)
        b_s = tl.load(b_s_ptrs)
        # Rescale each partial FP8 product by the scales of the current K-block.
        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
        a_ptrs += BLOCK_SIZE_K
        b_ptrs += BLOCK_SIZE_K
        a_s_ptrs += 1
        b_s_ptrs += 1
    c = accumulator.to(c_ptr.dtype.element_ty)
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(c_ptrs, c, mask=mask)


def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor):
    """
    Perform a matrix multiplication using FP8 precision.

    Args:
        a (torch.Tensor): The first input matrix of shape (M, K); must be contiguous.
        a_s (torch.Tensor): The per-block scaling factors for matrix `a`, as returned by `act_quant`.
        b (torch.Tensor): The second input matrix of shape (N, K), i.e. stored transposed
            like `nn.Linear.weight`; must be contiguous.
        b_s (torch.Tensor): The per-block scaling factors for matrix `b`.

    Returns:
        torch.Tensor: The resulting (M, N) matrix after multiplication.

    Raises:
        ValueError: If `a`, `b`, `a_s`, or `b_s` are not contiguous.
    """
    if not a.is_contiguous():
        raise ValueError("Matrix `a` must be contiguous.")
    if not b.is_contiguous():
        raise ValueError("Matrix `b` must be contiguous.")
    if not a_s.is_contiguous():
        raise ValueError("Scaling factors `a_s` must be contiguous.")
    if not b_s.is_contiguous():
        raise ValueError("Scaling factors `b_s` must be contiguous.")

    M, K = a.size()
    # `b` is stored as (N, K), matching the kernel's indexing of `b_ptr`.
    N, K_ = b.size()
    assert K == K_, "Matrix dimensions do not match for multiplication."

    c = torch.empty((M, N), dtype=torch.float32, device=a.device)
    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_SIZE_M"]),
        triton.cdiv(N, meta["BLOCK_SIZE_N"]),
    )
    # The block sizes are selected by the autotuner from `fp8_gemm_configs`,
    # so they are not passed explicitly at the launch.
    fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)
    return c
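

# Illustrative usage sketch: multiply an FP8 activation produced by `act_quant`
# with an FP8 weight stored as (N, K) and scaled per 128x128 block. Assumes a
# CUDA device whose Triton build supports FP8 `tl.dot`; the random tensors and
# the `_example_fp8_gemm` name are hypothetical.
def _example_fp8_gemm() -> None:
    M, K, N = 64, 512, 256
    a_fp8, a_scale = act_quant(torch.randn(M, K, dtype=torch.bfloat16, device="cuda"))
    w_fp8 = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)
    w_scale = torch.rand(N // 128, K // 128, dtype=torch.float32, device="cuda")
    c = fp8_gemm(a_fp8, a_scale, w_fp8, w_scale)
    # c has shape (M, N); each partial product was rescaled by a_scale * w_scale.
    print(c.shape, c.dtype)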