diff --git a/inference/kernel.py b/inference/kernel.py index ae907ad..ba18dca 100644 --- a/inference/kernel.py +++ b/inference/kernel.py @@ -87,7 +87,7 @@ def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> t Args: x (torch.Tensor): The quantized weight tensor of shape (M, N). - s (torch.Tensor): The scale tensor of shape (M, N). + s (torch.Tensor): The scale tensor of shape (M//block_size, N//block_size). block_size (int, optional): The block size to use for dequantization. Defaults to 128. Returns: