mirror of
https://github.com/deepseek-ai/DeepSeek-V3.git
synced 2025-04-19 10:08:59 -04:00
add docstring
This commit is contained in:
parent
0623163343
commit
630769360a
@ -107,6 +107,20 @@ def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> t
|
||||
|
||||
@triton.jit
|
||||
def weight_quant_kernel(x_ptr, y_ptr, s_ptr, M, N, BLOCK_SIZE: tl.constexpr):
|
||||
"""
|
||||
Quantizes weights in blocks and computes scaling factors for each block.
|
||||
|
||||
Args:
|
||||
x_ptr (tl.pointer): Pointer to the input weights tensor.
|
||||
y_ptr (tl.pointer): Pointer to the output buffer for quantized weights.
|
||||
s_ptr (tl.pointer): Pointer to the output buffer for scaling factors.
|
||||
M (int): Number of rows in the weight matrix.
|
||||
N (int): Number of columns in the weight matrix.
|
||||
BLOCK_SIZE (tl.constexpr): Size of the block for tiling.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
pid_m = tl.program_id(axis=0)
|
||||
pid_n = tl.program_id(axis=1)
|
||||
n = tl.cdiv(N, BLOCK_SIZE)
|
||||
@ -124,6 +138,21 @@ def weight_quant_kernel(x_ptr, y_ptr, s_ptr, M, N, BLOCK_SIZE: tl.constexpr):
|
||||
|
||||
|
||||
def weight_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Quantizes a weight tensor using block-wise quantization.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The input weight tensor of shape (M, N) to be quantized.
|
||||
block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
|
||||
- The quantized tensor with dtype `torch.float8_e4m3fn`.
|
||||
- A tensor of scaling factors with dtype `torch.float32`.
|
||||
|
||||
Raises:
|
||||
AssertionError: If `x` is not contiguous or if its dimensions are not 2.
|
||||
"""
|
||||
assert x.is_contiguous()
|
||||
assert x.dim() == 2
|
||||
M, N = x.size()
|
||||
|
Loading…
Reference in New Issue
Block a user