add docstring

This commit is contained in:
黄石 2025-04-06 00:31:32 +08:00
parent 0623163343
commit 630769360a


@@ -107,6 +107,20 @@ def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> t
@triton.jit
def weight_quant_kernel(x_ptr, y_ptr, s_ptr, M, N, BLOCK_SIZE: tl.constexpr):
    """
    Quantizes weights in blocks and computes scaling factors for each block.

    Args:
        x_ptr (tl.pointer): Pointer to the input weights tensor.
        y_ptr (tl.pointer): Pointer to the output buffer for quantized weights.
        s_ptr (tl.pointer): Pointer to the output buffer for scaling factors.
        M (int): Number of rows in the weight matrix.
        N (int): Number of columns in the weight matrix.
        BLOCK_SIZE (tl.constexpr): Size of the block used for tiling.

    Returns:
        None
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    n = tl.cdiv(N, BLOCK_SIZE)
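
The wrapper that launches this kernel is not shown in this hunk. Purely to illustrate the 2-D block tiling the docstring describes, a launch could look like the sketch below; the helper name `launch_weight_quant_kernel`, the output allocations, and the scale-tensor layout are assumptions, not part of this commit.

import torch
import triton

def launch_weight_quant_kernel(x: torch.Tensor, block_size: int = 128):
    # Hypothetical helper (not from this commit): allocate the output buffers
    # and launch weight_quant_kernel over a 2-D grid, one program per block.
    M, N = x.size()
    y = torch.empty_like(x, dtype=torch.float8_e4m3fn)      # quantized weights
    s = torch.empty((triton.cdiv(M, block_size), triton.cdiv(N, block_size)),
                    dtype=torch.float32, device=x.device)   # one scale per block (layout assumed)
    grid = (triton.cdiv(M, block_size), triton.cdiv(N, block_size))
    weight_quant_kernel[grid](x, y, s, M, N, BLOCK_SIZE=block_size)
    return y, s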
@@ -124,6 +138,21 @@ def weight_quant_kernel(x_ptr, y_ptr, s_ptr, M, N, BLOCK_SIZE: tl.constexpr):
def weight_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantizes a weight tensor using block-wise quantization.

    Args:
        x (torch.Tensor): The input weight tensor of shape (M, N) to be quantized.
        block_size (int, optional): The size of the blocks used for quantization. Default is 128.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
            - The quantized tensor with dtype `torch.float8_e4m3fn`.
            - A tensor of scaling factors with dtype `torch.float32`.

    Raises:
        AssertionError: If `x` is not contiguous or is not 2-dimensional.
    """
    assert x.is_contiguous()
    assert x.dim() == 2
    M, N = x.size()
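
Based only on the docstring above, a minimal usage sketch of `weight_quant`; the device, shape, and input dtype of `w` are assumptions, and the shape of the returned scale tensor is not checked because this hunk does not document it.

import torch

w = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")  # assumed input
y, s = weight_quant(w)                    # block_size defaults to 128
assert y.dtype == torch.float8_e4m3fn     # quantized weights, per the docstring
assert s.dtype == torch.float32           # per-block scaling factors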