Mirror of https://github.com/deepseek-ai/DeepSeek-V3.git (synced 2025-04-19 10:08:59 -04:00)
Clean up and optimize Triton FP8 kernels
- Improved readability and structure of Triton kernels for FP8 weight dequantization and matrix multiplication (GEMM)
- Added comments for clarity
- Replaced hardcoded block sizes with configurable parameters
- Improved safety using tl.cdiv and masking
- Renamed variables and ensured consistency in naming
This commit is contained in:
parent 61790e1653
commit a7bab5c920
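To illustrate the "tl.cdiv and masking" point from the commit message, here is a small standalone snippet. It is not part of the commit, and the sizes are arbitrary examples; it only shows why ceiling division is needed for the launch grid when a dimension is not a multiple of the block size:

# Illustration only: grid sizing with floor vs. ceiling division.
# With floor division, the trailing rows of a 100-row tensor get no
# program instance; triton.cdiv launches one extra block, which the
# in-kernel masks (offs_m < M) then keep within bounds.
import triton

M, BLOCK_SIZE_M = 100, 16
print(M // BLOCK_SIZE_M)             # 6 blocks -> the last 4 rows are never processed
print(triton.cdiv(M, BLOCK_SIZE_M))  # 7 blocks -> a tail block is launched and masked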
@@ -2,18 +2,22 @@ import torch
 import triton
 import triton.language as tl
 
 @triton.jit
 def weight_dequant_kernel(
-    q_ptr, s_ptr, out_ptr, M, N, K,
-    stride_qm, stride_qk, stride_sm, stride_sn,
+    q_ptr, s_ptr, out_ptr, M, N,
+    stride_qm, stride_qn, stride_sm, stride_sn,
     stride_om, stride_on,
-    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr
+    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr
 ):
     """
-    Kernel para desquantização de pesos FP8.
+    Triton kernel for FP8 weight dequantization.
+
+    out = q * s
     """
     pid = tl.program_id(axis=0)
-    pid_m = pid // (N // BLOCK_SIZE_N)
-    pid_n = pid % (N // BLOCK_SIZE_N)
+    num_blocks_n = tl.cdiv(N, BLOCK_SIZE_N)
+    pid_m = pid // num_blocks_n
+    pid_n = pid % num_blocks_n
+
     offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
     offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
@@ -21,7 +25,7 @@ def weight_dequant_kernel(
     mask_m = offs_m < M
     mask_n = offs_n < N
 
-    q_ptrs = q_ptr + offs_m[:, None] * stride_qm + offs_n[None, :] * stride_qk
+    q_ptrs = q_ptr + offs_m[:, None] * stride_qm + offs_n[None, :] * stride_qn
     s_ptrs = s_ptr + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn
     out_ptrs = out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
 
@@ -31,6 +35,7 @@ def weight_dequant_kernel(
     out = q.to(tl.float32) * s.to(tl.float32)
     tl.store(out_ptrs, out, mask=mask_m[:, None] & mask_n[None, :])
 
+
 @triton.jit
 def fp8_gemm_kernel(
     a_ptr, b_ptr, c_ptr, M, N, K,
@@ -39,68 +44,76 @@ def fp8_gemm_kernel(
     BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr
 ):
     """
-    Kernel para multiplicação de matrizes com FP8.
+    Triton kernel for FP8 GEMM (General Matrix Multiply)
+
+    c = a @ b
     """
     pid = tl.program_id(axis=0)
-    pid_m = pid // (N // BLOCK_SIZE_N)
-    pid_n = pid % (N // BLOCK_SIZE_N)
+    num_blocks_n = tl.cdiv(N, BLOCK_SIZE_N)
+    pid_m = pid // num_blocks_n
+    pid_n = pid % num_blocks_n
 
     offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
     offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-    offs_k = tl.arange(0, BLOCK_SIZE_K)
 
     mask_m = offs_m < M
     mask_n = offs_n < N
 
-    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
-    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
     c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
+    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
 
-    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
     for k in range(0, K, BLOCK_SIZE_K):
-        a = tl.load(a_ptrs, mask=mask_m[:, None], other=0)
-        b = tl.load(b_ptrs, mask=mask_n[None, :], other=0)
-        accumulator += tl.dot(a, b)
+        offs_k = k + tl.arange(0, BLOCK_SIZE_K)
+        mask_k = offs_k < K
 
-        a_ptrs += BLOCK_SIZE_K * stride_ak
-        b_ptrs += BLOCK_SIZE_K * stride_bk
+        a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
+        b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
 
-    tl.store(c_ptrs, accumulator, mask=mask_m[:, None] & mask_n[None, :])
+        a = tl.load(a_ptrs, mask=mask_m[:, None] & mask_k[None, :], other=0)
+        b = tl.load(b_ptrs, mask=mask_k[:, None] & mask_n[None, :], other=0)
 
+        acc += tl.dot(a, b)
 
-def dequantize_weights(q_weight: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
+    tl.store(c_ptrs, acc, mask=mask_m[:, None] & mask_n[None, :])
+
+
+def dequantize_weights(q_weight: torch.Tensor, scale: torch.Tensor, block_size=16) -> torch.Tensor:
     """
-    Função para desquantizar pesos FP8 com segurança.
+    Dequantizes FP8 weights with scaling.
     """
-    assert q_weight.shape == scale.shape, "Dimensões incompatíveis entre peso quantizado e escala."
+    assert q_weight.shape == scale.shape, "Mismatched shapes between quantized weights and scales."
 
-    out = torch.empty_like(q_weight, dtype=torch.float32)
-    weight_dequant_kernel[
-        (q_weight.shape[0] // 16, q_weight.shape[1] // 16)
-    ](
-        q_weight, scale, out,
-        q_weight.shape[0], q_weight.shape[1], q_weight.shape[1],
+    M, N = q_weight.shape
+    output = torch.empty_like(q_weight, dtype=torch.float32)
+
+    grid = (triton.cdiv(M, block_size) * triton.cdiv(N, block_size),)
+    weight_dequant_kernel[grid](
+        q_weight, scale, output,
+        M, N,
         q_weight.stride(0), q_weight.stride(1),
         scale.stride(0), scale.stride(1),
-        out.stride(0), out.stride(1),
-        16, 16, 16
+        output.stride(0), output.stride(1),
+        block_size, block_size
     )
-    return out
+    return output
 
 
-def fp8_gemm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+def fp8_gemm(a: torch.Tensor, b: torch.Tensor, block_size=16) -> torch.Tensor:
     """
-    Multiplicação de matrizes FP8 segura e eficiente.
+    Performs FP8 GEMM (a @ b) with Triton.
     """
-    assert a.shape[1] == b.shape[0], "Dimensões incompatíveis para multiplicação de matrizes."
+    assert a.shape[1] == b.shape[0], "Incompatible matrix dimensions."
 
-    c = torch.empty((a.shape[0], b.shape[1]), dtype=torch.float32)
-    fp8_gemm_kernel[
-        (a.shape[0] // 16, b.shape[1] // 16)
-    ](
-        a, b, c,
-        a.shape[0], b.shape[1], a.shape[1],
+    M, K = a.shape
+    _, N = b.shape
+    output = torch.empty((M, N), dtype=torch.float32)
+
+    grid = (triton.cdiv(M, block_size) * triton.cdiv(N, block_size),)
+    fp8_gemm_kernel[grid](
+        a, b, output,
+        M, N, K,
         a.stride(0), a.stride(1),
         b.stride(0), b.stride(1),
-        c.stride(0), c.stride(1),
-        16, 16, 16
+        output.stride(0), output.stride(1),
+        block_size, block_size, block_size
     )
-    return c
+    return output
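Below is a minimal usage sketch of the two host-side wrappers as they read after this commit. It is not part of the commit; the module name fp8_kernels, the float16 stand-in dtype, and the tensor shapes are assumptions for illustration (real callers would pass FP8 weights on a CUDA device):

# Minimal usage sketch. fp8_kernels is a hypothetical module name for the file above,
# and float16 is a stand-in dtype so the example runs on GPUs without FP8 support.
import torch
from fp8_kernels import dequantize_weights, fp8_gemm

torch.set_default_device("cuda")  # fp8_gemm allocates its output with a plain torch.empty(...)

q = torch.randn(100, 72, dtype=torch.float16)  # 100 is deliberately not a multiple of block_size
s = torch.rand(100, 72, dtype=torch.float16)   # per-element scales, same shape as q
w = dequantize_weights(q, s, block_size=16)    # float32 result, elementwise q * s

a = torch.randn(64, 48, dtype=torch.float16)
b = torch.randn(48, 32, dtype=torch.float16)
c = fp8_gemm(a, b)                             # float32 result, approximately a @ b

print(w.dtype, w.shape)  # torch.float32 torch.Size([100, 72])
print(c.dtype, c.shape)  # torch.float32 torch.Size([64, 32])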