mirror of
https://github.com/deepseek-ai/DeepSeek-V3.git
synced 2025-04-19 01:59:01 -04:00
add docstring
This commit is contained in:
parent
0623163343
commit
630769360a
@ -107,6 +107,20 @@ def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> t
|
|||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def weight_quant_kernel(x_ptr, y_ptr, s_ptr, M, N, BLOCK_SIZE: tl.constexpr):
|
def weight_quant_kernel(x_ptr, y_ptr, s_ptr, M, N, BLOCK_SIZE: tl.constexpr):
|
||||||
|
"""
|
||||||
|
Quantizes weights in blocks and computes scaling factors for each block.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x_ptr (tl.pointer): Pointer to the input weights tensor.
|
||||||
|
y_ptr (tl.pointer): Pointer to the output buffer for quantized weights.
|
||||||
|
s_ptr (tl.pointer): Pointer to the output buffer for scaling factors.
|
||||||
|
M (int): Number of rows in the weight matrix.
|
||||||
|
N (int): Number of columns in the weight matrix.
|
||||||
|
BLOCK_SIZE (tl.constexpr): Size of the block for tiling.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
pid_m = tl.program_id(axis=0)
|
pid_m = tl.program_id(axis=0)
|
||||||
pid_n = tl.program_id(axis=1)
|
pid_n = tl.program_id(axis=1)
|
||||||
n = tl.cdiv(N, BLOCK_SIZE)
|
n = tl.cdiv(N, BLOCK_SIZE)
|
||||||
@ -124,6 +138,21 @@ def weight_quant_kernel(x_ptr, y_ptr, s_ptr, M, N, BLOCK_SIZE: tl.constexpr):
|
|||||||
|
|
||||||
|
|
||||||
def weight_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
|
def weight_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Quantizes a weight tensor using block-wise quantization.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor): The input weight tensor of shape (M, N) to be quantized.
|
||||||
|
block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
|
||||||
|
- The quantized tensor with dtype `torch.float8_e4m3fn`.
|
||||||
|
- A tensor of scaling factors with dtype `torch.float32`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssertionError: If `x` is not contiguous or if its dimensions are not 2.
|
||||||
|
"""
|
||||||
assert x.is_contiguous()
|
assert x.is_contiguous()
|
||||||
assert x.dim() == 2
|
assert x.dim() == 2
|
||||||
M, N = x.size()
|
M, N = x.size()
|
||||||
|
Loading…
Reference in New Issue
Block a user