diff --git a/inference/kernel.py b/inference/kernel.py index ae907ad..d7930fe 100644 --- a/inference/kernel.py +++ b/inference/kernel.py @@ -184,7 +184,7 @@ def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Ten assert a_s.is_contiguous() and b_s.is_contiguous(), 'Scaling factor tensors must be contiguous' K = a.size(-1) M = a.numel() // K - N = b.size(0) + N = b.size(-1) c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype()) grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N'])) fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)