From 4a65fd9221103ff03864337453c238e65d1f4a1b Mon Sep 17 00:00:00 2001
From: oyzh
Date: Sat, 15 Feb 2025 11:02:28 +0800
Subject: [PATCH] fix an args description.

---
 inference/kernel.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/inference/kernel.py b/inference/kernel.py
index ae907ad..ba18dca 100644
--- a/inference/kernel.py
+++ b/inference/kernel.py
@@ -87,7 +87,7 @@ def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> t
 
     Args:
         x (torch.Tensor): The quantized weight tensor of shape (M, N).
-        s (torch.Tensor): The scale tensor of shape (M, N).
+        s (torch.Tensor): The scale tensor of shape (M//block_size, N//block_size).
         block_size (int, optional): The block size to use for dequantization. Defaults to 128.
 
     Returns: