Mirror of https://github.com/deepseek-ai/DeepSeek-V3.git (synced 2025-04-20 02:28:57 -04:00)

Merge 1e40f4a73c into 4cc6253d5c
This commit is contained in: 2a47abbf2c

The change adds type annotations (parameter and return types) to function and method signatures across the inference scripts; the hunks below show each signature before and after.
inference/convert.py
@@ -30,7 +30,7 @@ mapping = {
 }
 
 
-def main(hf_ckpt_path, save_path, n_experts, mp):
+def main(hf_ckpt_path: str, save_path: str, n_experts: int, mp: int) -> None:
     """
     Converts and saves model checkpoint files into a specified format.
 
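(The original hunk annotated n_experts as str; since it is a count parsed as an integer, int is used above.) For context on the two integer arguments: conversion shards the routed experts evenly across model-parallel ranks, so n_experts must be divisible by mp. A minimal sketch of that arithmetic with illustrative values (the repository README uses 256 experts and mp=16 for DeepSeek-V3):

n_experts, mp = 256, 16            # illustrative values
assert n_experts % mp == 0, "experts must shard evenly across ranks"
n_local_experts = n_experts // mp  # 16 experts per model-parallel rank
owner = [i // n_local_experts for i in range(n_experts)]  # expert index -> owning rank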
inference/fp8_cast_bf16.py
@@ -9,7 +9,7 @@ from safetensors.torch import load_file, save_file
 
 from kernel import weight_dequant
 
-def main(fp8_path, bf16_path):
+def main(fp8_path: str, bf16_path: str) -> None:
     """
     Converts FP8 weights to BF16 and saves the converted weights.
 
@@ -41,7 +41,7 @@ def main(fp8_path, bf16_path):
     fp8_weight_names = []
 
     # Helper function to get tensor from the correct file
-    def get_tensor(tensor_name):
+    def get_tensor(tensor_name: str) -> torch.Tensor:
         """
         Retrieves a tensor from the cached safetensor files or loads it from disk if not cached.
 
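A self-contained sketch of the caching pattern the annotated helper implies; weight_map (tensor name -> shard file) and the module-level cache are assumptions based on how safetensors index files are typically consumed, not code copied from the repository:

import torch
from safetensors.torch import load_file

weight_map: dict[str, str] = {}                        # assumed: populated from model.safetensors.index.json
loaded_files: dict[str, dict[str, torch.Tensor]] = {}  # shard file -> its tensors

def get_tensor(tensor_name: str) -> torch.Tensor:
    file_name = weight_map[tensor_name]
    if file_name not in loaded_files:
        loaded_files[file_name] = load_file(file_name)  # read each shard once, keep it cached
    return loaded_files[file_name][tensor_name]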
inference/generate.py
@@ -11,7 +11,7 @@ from safetensors.torch import load_model
 from model import Transformer, ModelArgs
 
 
-def sample(logits, temperature: float = 1.0):
+def sample(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
     """
     Samples a token from the logits using temperature scaling.
 
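A minimal, runnable sketch of temperature sampling consistent with the annotated signature (an illustration of the technique, not the repository's exact body; assumes 2-D [batch, vocab] logits):

import torch

def sample(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    # Dividing logits by the temperature flattens (T > 1) or sharpens (T < 1)
    # the distribution before sampling; the epsilon guards against T == 0.
    probs = torch.softmax(logits / max(temperature, 1e-5), dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)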
inference/kernel.py
@@ -7,7 +7,7 @@ from triton import Config
 
 
 @triton.jit
-def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
+def act_quant_kernel(x_ptr: tl.tensor, y_ptr: tl.tensor, s_ptr: tl.tensor, BLOCK_SIZE: tl.constexpr) -> None:
     """
     Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`.
 
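(The original hunk annotated the pointer arguments as triton.Pointer, which is not a Triton type; tl.tensor is used above for consistency with the later hunks.) For readers who don't know Triton, a reference (non-Triton) sketch of the blockwise quantization this kernel performs: each contiguous block along the last dimension is scaled so its absmax maps to the FP8 E4M3 maximum of 448, then cast to FP8. Requires torch >= 2.1 for float8; illustrative only:

import torch

def act_quant_ref(x: torch.Tensor, block_size: int = 128):
    # Assumes the last dimension is divisible by block_size (the real
    # act_quant wrapper asserts the same before launching the kernel).
    xb = x.view(*x.shape[:-1], -1, block_size).float()
    s = xb.abs().amax(dim=-1) / 448.0                   # one scale per block
    y = (xb / s.unsqueeze(-1)).to(torch.float8_e4m3fn)  # quantize
    return y.view_as(x), s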
@@ -53,7 +53,7 @@ def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
 
 
 @triton.jit
-def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
+def weight_dequant_kernel(x_ptr: tl.tensor, s_ptr: tl.tensor, y_ptr: tl.tensor, M: int, N: int, BLOCK_SIZE: tl.constexpr) -> None:
     """
     Dequantizes weights using the provided scaling factors and stores the result.
 
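An equivalent plain-PyTorch sketch of what the kernel computes: every BLOCK_SIZE x BLOCK_SIZE tile of the FP8 weight matrix is multiplied by its own scale from s. A reference illustration, not performance code:

import torch

def weight_dequant_ref(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    M, N = x.shape
    out = x.to(torch.float32)  # upcast the FP8 values before scaling
    for i in range(0, M, block_size):
        for j in range(0, N, block_size):
            out[i:i + block_size, j:j + block_size] *= s[i // block_size, j // block_size]
    return out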
@@ -112,12 +112,12 @@ fp8_gemm_configs = [
 
 @triton.autotune(configs=fp8_gemm_configs, key=['N', 'K'])
 @triton.jit
-def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
-                    a_s_ptr, b_s_ptr,
+def fp8_gemm_kernel(a_ptr: tl.tensor, b_ptr: tl.tensor, c_ptr: tl.tensor,
+                    a_s_ptr: tl.tensor, b_s_ptr: tl.tensor,
                     M, N: tl.constexpr, K: tl.constexpr,
                     BLOCK_SIZE_M: tl.constexpr,
                     BLOCK_SIZE_N: tl.constexpr,
-                    BLOCK_SIZE_K: tl.constexpr):
+                    BLOCK_SIZE_K: tl.constexpr) -> None:
     """
     Performs a matrix multiplication operation on FP8 matrices with scaling factors.
 
@@ -167,7 +167,7 @@ def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
     tl.store(c_ptrs, c, mask=mask)
 
 
-def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor):
+def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor) -> torch.Tensor:
     """
     Perform a matrix multiplication using FP8 precision.
 
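The scaled-GEMM math this wrapper exposes, as a reference sketch: dequantize both operands with their per-block scales, then multiply. The shapes (per-token scales for a, per-tile scales for b stored as [N, K]) and the 128 block size are stated assumptions about this codebase, not part of the diff:

import torch

def fp8_gemm_ref(a: torch.Tensor, a_s: torch.Tensor,
                 b: torch.Tensor, b_s: torch.Tensor, block: int = 128) -> torch.Tensor:
    # a: [M, K] FP8 with a_s: [M, K/block]; b: [N, K] FP8 with b_s: [N/block, K/block]
    a_f = a.float() * a_s.repeat_interleave(block, dim=-1)[..., :a.shape[-1]]
    b_f = b.float() * b_s.repeat_interleave(block, dim=0).repeat_interleave(block, dim=1)[:b.shape[0], :b.shape[1]]
    return a_f @ b_f.t()  # [M, N] result in full precision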
inference/model.py
@@ -92,7 +92,7 @@ class ParallelEmbedding(nn.Module):
         vocab_size (int): Vocabulary size.
         dim (int): Embedding dimension.
     """
-    def __init__(self, vocab_size: int, dim: int):
+    def __init__(self, vocab_size: int, dim: int) -> None:
         super().__init__()
         self.vocab_size = vocab_size
         self.dim = dim
@@ -173,7 +173,7 @@ class Linear(nn.Module):
     """
     dtype = torch.bfloat16
 
-    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
+    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None) -> None:
         super().__init__()
         self.in_features = in_features
         self.out_features = out_features
@@ -212,7 +212,7 @@ class ColumnParallelLinear(Linear):
         bias (bool): Whether to include a bias term. Defaults to False.
         dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
     """
-    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
+    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None) -> None:
         assert out_features % world_size == 0, f"Output features must be divisible by world size (world_size={world_size})"
         self.part_out_features = out_features // world_size
         super().__init__(in_features, self.part_out_features, bias, dtype)
@@ -241,7 +241,7 @@ class RowParallelLinear(Linear):
         bias (bool): Whether to include a bias term. Defaults to False.
         dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
     """
-    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
+    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None) -> None:
         assert in_features % world_size == 0, f"Input features must be divisible by world size (world_size={world_size})"
         self.part_in_features = in_features // world_size
         super().__init__(self.part_in_features, out_features, bias, dtype)
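Both parallel-linear hunks guard the same invariant: the partitioned dimension must split evenly across ranks. A toy illustration of that arithmetic (numbers invented for the example):

world_size = 8
out_features = 16384
assert out_features % world_size == 0
part_out_features = out_features // world_size  # each rank owns 2048 output columns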
@@ -406,7 +406,7 @@ class MLA(nn.Module):
         v_head_dim (int): Dimensionality of value projections.
         softmax_scale (float): Scaling factor for softmax in attention computation.
     """
-    def __init__(self, args: ModelArgs):
+    def __init__(self, args: ModelArgs) -> None:
         super().__init__()
         self.dim = args.dim
         self.n_heads = args.n_heads
@@ -440,7 +440,7 @@ class MLA(nn.Module):
         self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)
         self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)
 
-    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
+    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
        """
        Forward pass for the Multi-Head Latent Attention (MLA) Layer.
 
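The two buffers registered above plus the start_pos argument are what make incremental decoding work. A minimal, runnable sketch with hypothetical sizes (not the MLA internals):

import torch

max_bsz, max_seq, kv_rank = 2, 32, 512
kv_cache = torch.zeros(max_bsz, max_seq, kv_rank)      # stands in for the persistent=False buffer
bsz, seqlen, start_pos = 1, 4, 8
kv_new = torch.randn(bsz, seqlen, kv_rank)
kv_cache[:bsz, start_pos:start_pos + seqlen] = kv_new  # write this step's latent KV
context = kv_cache[:bsz, :start_pos + seqlen]          # attention reads the whole prefix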
@@ -503,7 +503,7 @@ class MLP(nn.Module):
         w2 (nn.Module): Linear layer for hidden-to-output transformation.
         w3 (nn.Module): Additional linear layer for feature transformation.
     """
-    def __init__(self, dim: int, inter_dim: int):
+    def __init__(self, dim: int, inter_dim: int) -> None:
         """
         Initializes the MLP layer.
 
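The w1/w2/w3 naming in the docstring is the standard gated (SwiGLU-style) MLP layout; a self-contained sketch of the forward pass those three layers imply (an illustration, not the file's exact code):

import torch
import torch.nn.functional as F
from torch import nn

class GatedMLP(nn.Module):
    def __init__(self, dim: int, inter_dim: int) -> None:
        super().__init__()
        self.w1 = nn.Linear(dim, inter_dim, bias=False)  # input -> hidden
        self.w2 = nn.Linear(inter_dim, dim, bias=False)  # hidden -> output
        self.w3 = nn.Linear(dim, inter_dim, bias=False)  # gating branch

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))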
@@ -543,7 +543,7 @@ class Gate(nn.Module):
         weight (torch.nn.Parameter): Learnable weights for the gate.
         bias (Optional[torch.nn.Parameter]): Optional bias term for the gate.
     """
-    def __init__(self, args: ModelArgs):
+    def __init__(self, args: ModelArgs) -> None:
         """
         Initializes the Gate module.
 
@@ -604,7 +604,7 @@ class Expert(nn.Module):
         w2 (nn.Module): Linear layer for hidden-to-output transformation.
         w3 (nn.Module): Additional linear layer for feature transformation.
     """
-    def __init__(self, dim: int, inter_dim: int):
+    def __init__(self, dim: int, inter_dim: int) -> None:
         """
         Initializes the Expert layer.
 
@@ -643,7 +643,7 @@ class MoE(nn.Module):
         experts (nn.ModuleList): List of expert modules.
         shared_experts (nn.Module): Shared experts applied to all inputs.
     """
-    def __init__(self, args: ModelArgs):
+    def __init__(self, args: ModelArgs) -> None:
         """
         Initializes the MoE module.
 
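For orientation, the Gate feeding this MoE module performs a top-k selection per token; a toy version of that routing step (sizes invented; the real Gate adds grouping and bias-correction logic):

import torch

scores = torch.randn(4, 64).softmax(dim=-1)  # [tokens, n_routed_experts]
weights, indices = scores.topk(6, dim=-1)    # keep 6 activated experts per token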
@@ -700,7 +700,7 @@ class Block(nn.Module):
         attn_norm (nn.Module): Layer normalization for attention.
         ffn_norm (nn.Module): Layer normalization for feed-forward network.
     """
-    def __init__(self, layer_id: int, args: ModelArgs):
+    def __init__(self, layer_id: int, args: ModelArgs) -> None:
         """
         Initializes the Transformer block.
 
@@ -744,7 +744,7 @@ class Transformer(nn.Module):
         head (nn.Module): Output projection layer mapping to vocabulary size.
         freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
     """
-    def __init__(self, args: ModelArgs):
+    def __init__(self, args: ModelArgs) -> None:
         """
         Initializes the Transformer model.
 
@@ -766,7 +766,7 @@ class Transformer(nn.Module):
         self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)
 
     @torch.inference_mode()
-    def forward(self, tokens: torch.Tensor, start_pos: int = 0):
+    def forward(self, tokens: torch.Tensor, start_pos: int = 0) -> torch.Tensor:
         """
         Forward pass for the Transformer model.
 
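End-to-end, the annotated forward is consumed by the generation loop roughly like this (a hypothetical sketch; model and sample stand for the objects defined in this repository, the prompt ids are invented):

tokens = torch.tensor([[0, 6891, 1152]], device="cuda")  # made-up prompt ids
logits = model(tokens, start_pos=0)                      # logits for the last position
next_tok = sample(logits, temperature=0.7)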