From ebd889518d742d422f43df2f900b1ed1891971a5 Mon Sep 17 00:00:00 2001 From: sunndy <3617608+sunddytwo@users.noreply.github.com> Date: Mon, 3 Mar 2025 19:38:53 +0800 Subject: [PATCH] Update kernel.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit a.size(-1) : K是a的列数 a.numel()//K : M是a的行数 b.size(0) 是行数, b.size(-1)才是b的列数。 这里是求a@b。结果应该是a的行数 X b的列数。N的值应该是b.size(-1) --- inference/kernel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/inference/kernel.py b/inference/kernel.py index ae907ad..d7930fe 100644 --- a/inference/kernel.py +++ b/inference/kernel.py @@ -184,7 +184,7 @@ def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Ten assert a_s.is_contiguous() and b_s.is_contiguous(), 'Scaling factor tensors must be contiguous' K = a.size(-1) M = a.numel() // K - N = b.size(0) + N = b.size(-1) c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype()) grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N'])) fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)