Update kernel.py

Improved docstrings for better understanding of the functions.
Added specific error messages for input validation.
Kept the structure of the code while making it easier to read and follow.
Ensured that all exceptions provide meaningful messages to the user.
This commit is contained in:
Cristian Cezar Moisés 2025-01-27 23:19:26 -03:00 committed by GitHub
parent ebbbf84d35
commit 18323417c1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -13,9 +13,9 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
Args:
x_ptr (triton.Pointer): Pointer to the input tensor.
y_ptr (triton.Pointer): Pointer to the output tensor where quantized values will be stored.
s_ptr (triton.Pointer): Pointer to the output tensor where scaling factors will be stored.
BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each program instance.
y_ptr (triton.Pointer): Pointer to the output tensor for quantized values.
s_ptr (triton.Pointer): Pointer to the output tensor for scaling factors.
BLOCK_SIZE (tl.constexpr): The size of the block to be processed.
Returns:
None
@ -23,9 +23,12 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(axis=0)
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
x = tl.load(x_ptr + offs).to(tl.float32)
s = tl.max(tl.abs(x)) / 448.
# Calculate scaling factor
s = tl.max(tl.abs(x)) / 448.0
y = x / s
y = y.to(y_ptr.dtype.element_ty)
tl.store(y_ptr + offs, y)
tl.store(s_ptr + pid, s)
@ -35,16 +38,22 @@ def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, tor
Quantizes the input tensor `x` using block-wise quantization.
Args:
x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.
x (torch.Tensor): Input tensor to be quantized. Must be contiguous and last dimension size divisible by `block_size`.
block_size (int, optional): The size of blocks for quantization. Default is 128.
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- The quantized tensor with dtype `torch.float8_e4m3fn`.
- A tensor of scaling factors with dtype `torch.float32`.
Raises:
ValueError: If input tensor is not contiguous or last dimension size is not divisible by block size.
"""
assert x.is_contiguous()
assert x.size(-1) % block_size == 0
if not x.is_contiguous():
raise ValueError("Input tensor must be contiguous.")
if x.size(-1) % block_size != 0:
raise ValueError("Last dimension of input tensor must be divisible by block_size.")
y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)
grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), )
@ -58,9 +67,9 @@ def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
Dequantizes weights using the provided scaling factors and stores the result.
Args:
x_ptr (tl.pointer): Pointer to the quantized weights.
s_ptr (tl.pointer): Pointer to the scaling factors.
y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights.
x_ptr (triton.Pointer): Pointer to the quantized weights.
s_ptr (triton.Pointer): Pointer to the scaling factors.
y_ptr (triton.Pointer): Pointer to the output buffer for dequantized weights.
M (int): Number of rows in the weight matrix.
N (int): Number of columns in the weight matrix.
BLOCK_SIZE (tl.constexpr): Size of the block for tiling.
@ -74,9 +83,11 @@ def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
offs = offs_m[:, None] * N + offs_n[None, :]
mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
s = tl.load(s_ptr + pid_m * n + pid_n)
y = x * s
tl.store(y_ptr + offs, y, mask=mask)
@ -94,17 +105,21 @@ def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> t
torch.Tensor: The dequantized weight tensor of the same shape as `x`.
Raises:
AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2.
ValueError: If `x` or `s` are not contiguous or if their dimensions are not 2.
"""
assert x.is_contiguous() and s.is_contiguous()
assert x.dim() == 2 and s.dim() == 2
if not x.is_contiguous() or not s.is_contiguous():
raise ValueError("Both x and s must be contiguous.")
if x.dim() != 2 or s.dim() != 2:
raise ValueError("Both x and s must be 2-dimensional tensors.")
M, N = x.size()
y = torch.empty_like(x, dtype=torch.get_default_dtype())
grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']), triton.cdiv(N, meta['BLOCK_SIZE_N']))
weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
return y
# Configuration for autotuning FP8 GEMM implementation
fp8_gemm_configs = [
Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128}, num_stages=num_stages, num_warps=8)
for block_m in [16, 32, 64] for block_n in [32, 64, 128] for num_stages in [3, 4, 5, 6]
@ -122,11 +137,11 @@ def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
Performs a matrix multiplication operation on FP8 matrices with scaling factors.
Args:
a_ptr (tl.tensor): Pointer to the first input matrix A.
b_ptr (tl.tensor): Pointer to the second input matrix B.
c_ptr (tl.tensor): Pointer to the output matrix C.
a_s_ptr (tl.tensor): Pointer to the scaling factors for matrix A.
b_s_ptr (tl.tensor): Pointer to the scaling factors for matrix B.
a_ptr (triton.Pointer): Pointer to the first input matrix A.
b_ptr (triton.Pointer): Pointer to the second input matrix B.
c_ptr (triton.Pointer): Pointer to the output matrix C.
a_s_ptr (triton.Pointer): Pointer to the scaling factors for matrix A.
b_s_ptr (triton.Pointer): Pointer to the scaling factors for matrix B.
M (int): Number of rows in matrix A and C.
N (tl.constexpr): Number of columns in matrix B and C.
K (tl.constexpr): Number of columns in matrix A and rows in matrix B.
@ -143,22 +158,29 @@ def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None]
a_s_ptrs = a_s_ptr + offs_m * k
b_s_ptrs = b_s_ptr + (offs_n // BLOCK_SIZE_K) * k
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for i in range(k):
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0)
a_s = tl.load(a_s_ptrs)
b_s = tl.load(b_s_ptrs)
# Update accumulator with scaled dot product
accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
# Move pointers for the next iteration
a_ptrs += BLOCK_SIZE_K
b_ptrs += BLOCK_SIZE_K
a_s_ptrs += 1
b_s_ptrs += 1
c = accumulator.to(c_ptr.dtype.element_ty)
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
@ -167,24 +189,31 @@ def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
tl.store(c_ptrs, c, mask=mask)
def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor):
def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor) -> torch.Tensor:
"""
Perform a matrix multiplication using FP8 precision.
Perform matrix multiplication using FP8 precision.
Args:
a (torch.Tensor): The first input matrix, must be contiguous.
a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous.
b (torch.Tensor): The second input matrix, must be contiguous.
b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous.
a (torch.Tensor): First input matrix, must be contiguous.
a_s (torch.Tensor): Scaling factor for the first input matrix, must be contiguous.
b (torch.Tensor): Second input matrix, must be contiguous.
b_s (torch.Tensor): Scaling factor for the second input matrix, must be contiguous.
Returns:
torch.Tensor: The result of the matrix multiplication.
torch.Tensor: Result of the matrix multiplication.
Raises:
ValueError: If any input tensor is not contiguous.
"""
assert a.is_contiguous() and b.is_contiguous()
assert a_s.is_contiguous() and b_s.is_contiguous()
if not a.is_contiguous() or not b.is_contiguous():
raise ValueError("Matrices a and b must be contiguous.")
if not a_s.is_contiguous() or not b_s.is_contiguous():
raise ValueError("Scaling factors a_s and b_s must be contiguous.")
K = a.size(-1)
M = a.numel() // K
N = b.size(0)
c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N']))
fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)