From a7bab5c920fb96b9adffcd913089019251d1f971 Mon Sep 17 00:00:00 2001
From: Gabriel Caetano
Date: Tue, 8 Apr 2025 22:33:48 -0300
Subject: [PATCH] Clean up and optimize Triton FP8 kernels

- Improved readability and structure of Triton kernels for FP8 weight
  dequantization and matrix multiplication (GEMM)
- Added comments for clarity
- Replaced hardcoded block sizes with configurable parameters
- Improved safety using tl.cdiv and masking
- Renamed variables and ensured consistency in naming
---
 inference/kernel.py | 131 ++++++++++++++++++++++++--------------------
 1 file changed, 72 insertions(+), 59 deletions(-)

diff --git a/inference/kernel.py b/inference/kernel.py
index a71dbe7..5c72970 100644
--- a/inference/kernel.py
+++ b/inference/kernel.py
@@ -2,35 +2,40 @@
 import torch
 import triton
 import triton.language as tl
+
+@triton.jit
 def weight_dequant_kernel(
-    q_ptr, s_ptr, out_ptr, M, N, K,
-    stride_qm, stride_qk, stride_sm, stride_sn,
+    q_ptr, s_ptr, out_ptr, M, N,
+    stride_qm, stride_qn, stride_sm, stride_sn,
     stride_om, stride_on,
-    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr
+    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr
 ):
     """
-    Kernel para desquantização de pesos FP8.
+    Triton kernel for FP8 weight dequantization.
+    out = q * s
     """
     pid = tl.program_id(axis=0)
-    pid_m = pid // (N // BLOCK_SIZE_N)
-    pid_n = pid % (N // BLOCK_SIZE_N)
-
+    num_blocks_n = tl.cdiv(N, BLOCK_SIZE_N)
+    pid_m = pid // num_blocks_n
+    pid_n = pid % num_blocks_n
+
     offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
     offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-
+
     mask_m = offs_m < M
     mask_n = offs_n < N
-
-    q_ptrs = q_ptr + offs_m[:, None] * stride_qm + offs_n[None, :] * stride_qk
+
+    q_ptrs = q_ptr + offs_m[:, None] * stride_qm + offs_n[None, :] * stride_qn
     s_ptrs = s_ptr + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn
     out_ptrs = out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
-
+
     q = tl.load(q_ptrs, mask=mask_m[:, None] & mask_n[None, :], other=0)
     s = tl.load(s_ptrs, mask=mask_m[:, None] & mask_n[None, :], other=1)
-
+
     out = q.to(tl.float32) * s.to(tl.float32)
     tl.store(out_ptrs, out, mask=mask_m[:, None] & mask_n[None, :])
 
+
 @triton.jit
 def fp8_gemm_kernel(
     a_ptr, b_ptr, c_ptr, M, N, K,
@@ -39,68 +44,76 @@ def fp8_gemm_kernel(
     BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr
 ):
     """
-    Kernel para multiplicação de matrizes com FP8.
+    Triton kernel for FP8 GEMM (General Matrix Multiply)
+    c = a @ b
     """
     pid = tl.program_id(axis=0)
-    pid_m = pid // (N // BLOCK_SIZE_N)
-    pid_n = pid % (N // BLOCK_SIZE_N)
-
+    num_blocks_n = tl.cdiv(N, BLOCK_SIZE_N)
+    pid_m = pid // num_blocks_n
+    pid_n = pid % num_blocks_n
+
     offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
     offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-    offs_k = tl.arange(0, BLOCK_SIZE_K)
-
+
     mask_m = offs_m < M
     mask_n = offs_n < N
-
-    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
-    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
-    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
-
-    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-    for k in range(0, K, BLOCK_SIZE_K):
-        a = tl.load(a_ptrs, mask=mask_m[:, None], other=0)
-        b = tl.load(b_ptrs, mask=mask_n[None, :], other=0)
-        accumulator += tl.dot(a, b)
-
-        a_ptrs += BLOCK_SIZE_K * stride_ak
-        b_ptrs += BLOCK_SIZE_K * stride_bk
-
-    tl.store(c_ptrs, accumulator, mask=mask_m[:, None] & mask_n[None, :])
-def dequantize_weights(q_weight: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
+    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
+    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+
+    for k in range(0, K, BLOCK_SIZE_K):
+        offs_k = k + tl.arange(0, BLOCK_SIZE_K)
+        mask_k = offs_k < K
+
+        a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
+        b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
+
+        a = tl.load(a_ptrs, mask=mask_m[:, None] & mask_k[None, :], other=0)
+        b = tl.load(b_ptrs, mask=mask_k[:, None] & mask_n[None, :], other=0)
+
+        acc += tl.dot(a, b)
+
+    tl.store(c_ptrs, acc, mask=mask_m[:, None] & mask_n[None, :])
+
+
+def dequantize_weights(q_weight: torch.Tensor, scale: torch.Tensor, block_size=16) -> torch.Tensor:
     """
-    Função para desquantizar pesos FP8 com segurança.
+    Dequantizes FP8 weights with scaling.
     """
-    assert q_weight.shape == scale.shape, "Dimensões incompatíveis entre peso quantizado e escala."
-
-    out = torch.empty_like(q_weight, dtype=torch.float32)
-    weight_dequant_kernel[
-        (q_weight.shape[0] // 16, q_weight.shape[1] // 16)
-    ](
-        q_weight, scale, out,
-        q_weight.shape[0], q_weight.shape[1], q_weight.shape[1],
+    assert q_weight.shape == scale.shape, "Mismatched shapes between quantized weights and scales."
+
+    M, N = q_weight.shape
+    output = torch.empty_like(q_weight, dtype=torch.float32)
+
+    grid = (triton.cdiv(M, block_size) * triton.cdiv(N, block_size),)
+    weight_dequant_kernel[grid](
+        q_weight, scale, output,
+        M, N,
         q_weight.stride(0), q_weight.stride(1),
         scale.stride(0), scale.stride(1),
-        out.stride(0), out.stride(1),
-        16, 16, 16
+        output.stride(0), output.stride(1),
+        block_size, block_size
    )
-    return out
+    return output
 
-def fp8_gemm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+
+def fp8_gemm(a: torch.Tensor, b: torch.Tensor, block_size=16) -> torch.Tensor:
     """
-    Multiplicação de matrizes FP8 segura e eficiente.
+    Performs FP8 GEMM (a @ b) with Triton.
     """
-    assert a.shape[1] == b.shape[0], "Dimensões incompatíveis para multiplicação de matrizes."
-
-    c = torch.empty((a.shape[0], b.shape[1]), dtype=torch.float32)
-    fp8_gemm_kernel[
-        (a.shape[0] // 16, b.shape[1] // 16)
-    ](
-        a, b, c,
-        a.shape[0], b.shape[1], a.shape[1],
+    assert a.shape[1] == b.shape[0], "Incompatible matrix dimensions."
+
+    M, K = a.shape
+    _, N = b.shape
+    output = torch.empty((M, N), dtype=torch.float32, device=a.device)
+
+    grid = (triton.cdiv(M, block_size) * triton.cdiv(N, block_size),)
+    fp8_gemm_kernel[grid](
+        a, b, output,
+        M, N, K,
         a.stride(0), a.stride(1),
         b.stride(0), b.stride(1),
-        c.stride(0), c.stride(1),
-        16, 16, 16
+        output.stride(0), output.stride(1),
+        block_size, block_size, block_size
     )
-    return c
+    return output
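
Usage note (not part of the patch): the sketch below shows one way the updated wrappers could be exercised end to end. It assumes a CUDA device, a PyTorch build that provides the float8_e4m3fn dtype, a Triton version whose tl.dot accepts FP8 inputs, and that the repository root is on PYTHONPATH; the tensor shapes and the block_size value are illustrative only.

    import torch
    from inference.kernel import dequantize_weights, fp8_gemm

    M, K, N = 128, 256, 64  # illustrative sizes, all multiples of the block size

    # FP8 weights plus an elementwise scale tensor of the same shape,
    # matching the shape assert in dequantize_weights.
    w_fp8 = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
    scale = torch.full((M, K), 0.5, device="cuda")
    w_fp32 = dequantize_weights(w_fp8, scale, block_size=16)  # float32, shape (M, K)

    # FP8 GEMM: (M, K) @ (K, N) -> (M, N), accumulated in float32.
    a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
    b = torch.randn(K, N, device="cuda").to(torch.float8_e4m3fn)
    c = fp8_gemm(a, b, block_size=16)

On hardware or Triton builds without FP8 support in tl.dot, the GEMM call may require casting a and b to float16 first, so this should be read as a smoke test rather than a benchmark.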