diff --git a/inference/kernel.py b/inference/kernel.py index dec8639..db6a554 100644 --- a/inference/kernel.py +++ b/inference/kernel.py @@ -13,9 +13,9 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr): Args: x_ptr (triton.Pointer): Pointer to the input tensor. - y_ptr (triton.Pointer): Pointer to the output tensor where quantized values will be stored. - s_ptr (triton.Pointer): Pointer to the output tensor where scaling factors will be stored. - BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each program instance. + y_ptr (triton.Pointer): Pointer to the output tensor for quantized values. + s_ptr (triton.Pointer): Pointer to the output tensor for scaling factors. + BLOCK_SIZE (tl.constexpr): The size of the block to be processed. Returns: None @@ -23,9 +23,12 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr): pid = tl.program_id(axis=0) offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) x = tl.load(x_ptr + offs).to(tl.float32) - s = tl.max(tl.abs(x)) / 448. + + # Calculate scaling factor + s = tl.max(tl.abs(x)) / 448.0 y = x / s y = y.to(y_ptr.dtype.element_ty) + tl.store(y_ptr + offs, y) tl.store(s_ptr + pid, s) @@ -35,16 +38,22 @@ def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, tor Quantizes the input tensor `x` using block-wise quantization. Args: - x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`. - block_size (int, optional): The size of the blocks to be used for quantization. Default is 128. + x (torch.Tensor): Input tensor to be quantized. Must be contiguous and last dimension size divisible by `block_size`. + block_size (int, optional): The size of blocks for quantization. Default is 128. Returns: Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - The quantized tensor with dtype `torch.float8_e4m3fn`. - A tensor of scaling factors with dtype `torch.float32`. + + Raises: + ValueError: If input tensor is not contiguous or last dimension size is not divisible by block size. """ - assert x.is_contiguous() - assert x.size(-1) % block_size == 0 + if not x.is_contiguous(): + raise ValueError("Input tensor must be contiguous.") + if x.size(-1) % block_size != 0: + raise ValueError("Last dimension of input tensor must be divisible by block_size.") + y = torch.empty_like(x, dtype=torch.float8_e4m3fn) s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32) grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), ) @@ -58,9 +67,9 @@ def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr): Dequantizes weights using the provided scaling factors and stores the result. Args: - x_ptr (tl.pointer): Pointer to the quantized weights. - s_ptr (tl.pointer): Pointer to the scaling factors. - y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights. + x_ptr (triton.Pointer): Pointer to the quantized weights. + s_ptr (triton.Pointer): Pointer to the scaling factors. + y_ptr (triton.Pointer): Pointer to the output buffer for dequantized weights. M (int): Number of rows in the weight matrix. N (int): Number of columns in the weight matrix. BLOCK_SIZE (tl.constexpr): Size of the block for tiling. @@ -74,9 +83,11 @@ def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr): offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) offs = offs_m[:, None] * N + offs_n[None, :] + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) x = tl.load(x_ptr + offs, mask=mask).to(tl.float32) s = tl.load(s_ptr + pid_m * n + pid_n) + y = x * s tl.store(y_ptr + offs, y, mask=mask) @@ -94,17 +105,21 @@ def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> t torch.Tensor: The dequantized weight tensor of the same shape as `x`. Raises: - AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2. + ValueError: If `x` or `s` are not contiguous or if their dimensions are not 2. """ - assert x.is_contiguous() and s.is_contiguous() - assert x.dim() == 2 and s.dim() == 2 + if not x.is_contiguous() or not s.is_contiguous(): + raise ValueError("Both x and s must be contiguous.") + if x.dim() != 2 or s.dim() != 2: + raise ValueError("Both x and s must be 2-dimensional tensors.") + M, N = x.size() y = torch.empty_like(x, dtype=torch.get_default_dtype()) - grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE'])) + grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']), triton.cdiv(N, meta['BLOCK_SIZE_N'])) weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size) return y +# Configuration for autotuning FP8 GEMM implementation fp8_gemm_configs = [ Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128}, num_stages=num_stages, num_warps=8) for block_m in [16, 32, 64] for block_n in [32, 64, 128] for num_stages in [3, 4, 5, 6] @@ -122,11 +137,11 @@ def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr, Performs a matrix multiplication operation on FP8 matrices with scaling factors. Args: - a_ptr (tl.tensor): Pointer to the first input matrix A. - b_ptr (tl.tensor): Pointer to the second input matrix B. - c_ptr (tl.tensor): Pointer to the output matrix C. - a_s_ptr (tl.tensor): Pointer to the scaling factors for matrix A. - b_s_ptr (tl.tensor): Pointer to the scaling factors for matrix B. + a_ptr (triton.Pointer): Pointer to the first input matrix A. + b_ptr (triton.Pointer): Pointer to the second input matrix B. + c_ptr (triton.Pointer): Pointer to the output matrix C. + a_s_ptr (triton.Pointer): Pointer to the scaling factors for matrix A. + b_s_ptr (triton.Pointer): Pointer to the scaling factors for matrix B. M (int): Number of rows in matrix A and C. N (tl.constexpr): Number of columns in matrix B and C. K (tl.constexpr): Number of columns in matrix A and rows in matrix B. @@ -143,22 +158,29 @@ def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr, offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :] b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None] a_s_ptrs = a_s_ptr + offs_m * k b_s_ptrs = b_s_ptr + (offs_n // BLOCK_SIZE_K) * k accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for i in range(k): a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0) b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0) a_s = tl.load(a_s_ptrs) b_s = tl.load(b_s_ptrs) + + # Update accumulator with scaled dot product accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :] + + # Move pointers for the next iteration a_ptrs += BLOCK_SIZE_K b_ptrs += BLOCK_SIZE_K a_s_ptrs += 1 b_s_ptrs += 1 + c = accumulator.to(c_ptr.dtype.element_ty) offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) @@ -167,24 +189,31 @@ def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr, tl.store(c_ptrs, c, mask=mask) -def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor): +def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor) -> torch.Tensor: """ - Perform a matrix multiplication using FP8 precision. + Perform matrix multiplication using FP8 precision. Args: - a (torch.Tensor): The first input matrix, must be contiguous. - a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous. - b (torch.Tensor): The second input matrix, must be contiguous. - b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous. + a (torch.Tensor): First input matrix, must be contiguous. + a_s (torch.Tensor): Scaling factor for the first input matrix, must be contiguous. + b (torch.Tensor): Second input matrix, must be contiguous. + b_s (torch.Tensor): Scaling factor for the second input matrix, must be contiguous. Returns: - torch.Tensor: The result of the matrix multiplication. + torch.Tensor: Result of the matrix multiplication. + + Raises: + ValueError: If any input tensor is not contiguous. """ - assert a.is_contiguous() and b.is_contiguous() - assert a_s.is_contiguous() and b_s.is_contiguous() + if not a.is_contiguous() or not b.is_contiguous(): + raise ValueError("Matrices a and b must be contiguous.") + if not a_s.is_contiguous() or not b_s.is_contiguous(): + raise ValueError("Scaling factors a_s and b_s must be contiguous.") + K = a.size(-1) M = a.numel() // K N = b.size(0) + c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype()) grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N'])) fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)