diff --git a/inference/kernel.py b/inference/kernel.py
index b90edf8..f2277dc 100644
--- a/inference/kernel.py
+++ b/inference/kernel.py
@@ -107,6 +107,20 @@ def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> t
 
 @triton.jit
 def weight_quant_kernel(x_ptr, y_ptr, s_ptr, M, N, BLOCK_SIZE: tl.constexpr):
+    """
+    Quantizes weights in blocks and computes a scaling factor for each block.
+
+    Args:
+        x_ptr (tl.pointer): Pointer to the input weights tensor.
+        y_ptr (tl.pointer): Pointer to the output buffer for quantized weights.
+        s_ptr (tl.pointer): Pointer to the output buffer for scaling factors.
+        M (int): Number of rows in the weight matrix.
+        N (int): Number of columns in the weight matrix.
+        BLOCK_SIZE (tl.constexpr): Size of the block used for tiling.
+
+    Returns:
+        None
+    """
     pid_m = tl.program_id(axis=0)
     pid_n = tl.program_id(axis=1)
     n = tl.cdiv(N, BLOCK_SIZE)
@@ -124,6 +138,21 @@ def weight_quant_kernel(x_ptr, y_ptr, s_ptr, M, N, BLOCK_SIZE: tl.constexpr):
 
 
 def weight_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantizes a weight tensor using block-wise quantization.
+
+    Args:
+        x (torch.Tensor): The input weight tensor of shape (M, N) to be quantized.
+        block_size (int, optional): The size of the blocks used for quantization. Defaults to 128.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+            - The quantized tensor with dtype `torch.float8_e4m3fn`.
+            - A tensor of scaling factors with dtype `torch.float32`.
+
+    Raises:
+        AssertionError: If `x` is not contiguous or is not two-dimensional.
+    """
     assert x.is_contiguous()
     assert x.dim() == 2
     M, N = x.size()
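
A minimal usage sketch of the API these docstrings document, for reviewer context; it is not part of the patch. The tensor dimensions are illustrative, and the scale-tensor shape is an assumption inferred from the kernel's 2-D launch grid (`pid_m`, `pid_n` over `BLOCK_SIZE`-sized tiles), not confirmed by the diff:

```python
import torch
from inference.kernel import weight_quant  # module path taken from the diff

# Illustrative 2-D weight; weight_quant asserts contiguity and x.dim() == 2.
w = torch.randn(1024, 4096, device="cuda")

y, s = weight_quant(w, block_size=128)

print(y.dtype)  # torch.float8_e4m3fn, per the docstring
print(s.dtype)  # torch.float32, per the docstring
# Assumption: one scale per 128x128 block, given the 2-D tiling in the kernel,
# so s.shape would be (1024 // 128, 4096 // 128) == (8, 32).
print(s.shape)
```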