Update kernel.py

a.size(-1) : K是a的列数
a.numel//K : M是a的行数
b.size(0) 是行数,
b.size(-1)才是b的列数。
这里是求a@b。结果应该是a的行数 X b的列数。N的值应该是b.size(-1)
This commit is contained in:
sunndy 2025-03-03 19:38:53 +08:00 committed by GitHub
parent 592fd5daf8
commit ebd889518d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -184,7 +184,7 @@ def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Ten
assert a_s.is_contiguous() and b_s.is_contiguous(), 'Scaling factor tensors must be contiguous'
K = a.size(-1)
M = a.numel() // K
N = b.size(0)
N = b.size(-1)
c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N']))
fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)