mirror of
https://github.com/deepseek-ai/DeepSeek-V3.git
synced 2025-02-22 21:58:58 -05:00
Update kernel.py
This commit is contained in:
parent
5ee97a83f0
commit
f10ff9c262
@ -106,7 +106,7 @@ def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> t
|
|||||||
|
|
||||||
|
|
||||||
fp8_gemm_configs = [
|
fp8_gemm_configs = [
|
||||||
Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128}, num_stages=num_stages, num_warps=8)
|
Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128, "GROUP_SIZE_M": 8}, num_stages=num_stages, num_warps=8)
|
||||||
for block_m in [16, 32, 64] for block_n in [32, 64, 128] for num_stages in [3, 4, 5, 6]
|
for block_m in [16, 32, 64] for block_n in [32, 64, 128] for num_stages in [3, 4, 5, 6]
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -137,9 +137,32 @@ def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
|
|||||||
Returns:
|
Returns:
|
||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
pid_m = tl.program_id(axis=0)
|
|
||||||
pid_n = tl.program_id(axis=1)
|
"""Kernel for computing the matmul C = A x B.
|
||||||
|
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
|
||||||
|
"""
|
||||||
|
# -----------------------------------------------------------
|
||||||
|
# Map program ids `pid` to the block of C it should compute.
|
||||||
|
# This is done in a grouped ordering to promote L2 data reuse.
|
||||||
|
# See above `L2 Cache Optimizations` section for details.
|
||||||
|
pid = tl.program_id(axis=0)
|
||||||
|
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
||||||
|
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||||
|
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||||
|
group_id = pid // num_pid_in_group
|
||||||
|
first_pid_m = group_id * GROUP_SIZE_M
|
||||||
|
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
||||||
|
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
|
||||||
|
pid_n = (pid % num_pid_in_group) // group_size_m
|
||||||
k = tl.cdiv(K, BLOCK_SIZE_K)
|
k = tl.cdiv(K, BLOCK_SIZE_K)
|
||||||
|
|
||||||
|
# ----------------------------------------------------------
|
||||||
|
# Create pointers for the first blocks of A and B.
|
||||||
|
# We will advance this pointer as we move in the K direction
|
||||||
|
# and accumulate
|
||||||
|
# `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
|
||||||
|
# `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
|
||||||
|
# See above `Pointer Arithmetic` section for details
|
||||||
offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
|
offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
|
||||||
offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
|
offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
|
||||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||||
@ -186,6 +209,7 @@ def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Ten
|
|||||||
M = a.numel() // K
|
M = a.numel() // K
|
||||||
N = b.size(0)
|
N = b.size(0)
|
||||||
c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
|
c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
|
||||||
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N']))
|
# 1D launch kernel where each block gets its own program.
|
||||||
|
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
|
||||||
fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)
|
fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)
|
||||||
return c
|
return c
|
||||||
|
Loading…
Reference in New Issue
Block a user