mirror of
https://github.com/deepseek-ai/DeepSeek-V3.git
synced 2025-04-19 18:18:57 -04:00
Update kernel.py
a.size(-1) : K是a的列数 a.numel//K : M是a的行数 b.size(0) 是行数, b.size(-1)才是b的列数。 这里是求a@b。结果应该是a的行数 X b的列数。N的值应该是b.size(-1)
This commit is contained in:
parent
592fd5daf8
commit
ebd889518d
@ -184,7 +184,7 @@ def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Ten
|
|||||||
assert a_s.is_contiguous() and b_s.is_contiguous(), 'Scaling factor tensors must be contiguous'
|
assert a_s.is_contiguous() and b_s.is_contiguous(), 'Scaling factor tensors must be contiguous'
|
||||||
K = a.size(-1)
|
K = a.size(-1)
|
||||||
M = a.numel() // K
|
M = a.numel() // K
|
||||||
N = b.size(0)
|
N = b.size(-1)
|
||||||
c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
|
c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
|
||||||
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N']))
|
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N']))
|
||||||
fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)
|
fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)
|
||||||
|
Loading…
Reference in New Issue
Block a user