mirror of https://github.com/deepseek-ai/DeepSeek-V3.git (synced 2025-02-22 05:38:59 -05:00)

Merge pull request #193 from enochkan/main

Add docstrings to functions in inference modules for better clarity

This commit is contained in: commit fdbd5be754

.gitignore (vendored): 6
@@ -165,4 +165,8 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

.vscode/*

.DS_Store
@@ -31,6 +31,18 @@ mapping = {


def main(hf_ckpt_path, save_path, n_experts, mp):
    """
    Converts and saves model checkpoint files into a specified format.

    Args:
        hf_ckpt_path (str): Path to the directory containing the input checkpoint files.
        save_path (str): Path to the directory where the converted checkpoint files will be saved.
        n_experts (int): Total number of experts in the model.
        mp (int): Model parallelism factor.

    Returns:
        None
    """
    torch.set_num_threads(8)
    n_local_experts = n_experts // mp
    state_dicts = [{} for _ in range(mp)]
@@ -10,6 +10,25 @@ from safetensors.torch import load_file, save_file
from kernel import weight_dequant


def main(fp8_path, bf16_path):
    """
    Converts FP8 weights to BF16 and saves the converted weights.

    This function reads FP8 weights from the specified directory, converts them to BF16,
    and saves the converted weights to another specified directory. It also updates the
    model index file to reflect the changes.

    Args:
        fp8_path (str): The path to the directory containing the FP8 weights and model index file.
        bf16_path (str): The path to the directory where the converted BF16 weights will be saved.

    Raises:
        KeyError: If a required scale_inv tensor is missing for a weight.

    Notes:
        - The function assumes that the FP8 weights are stored in safetensor files.
        - The function caches loaded safetensor files to optimize memory usage.
        - The function updates the model index file to remove references to scale_inv tensors.
    """
    torch.set_default_dtype(torch.bfloat16)
    os.makedirs(bf16_path, exist_ok=True)
    model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
@@ -23,6 +42,18 @@ def main(fp8_path, bf16_path):

    # Helper function to get tensor from the correct file
    def get_tensor(tensor_name):
        """
        Retrieves a tensor from the cached safetensor files or loads it from disk if not cached.

        Args:
            tensor_name (str): The name of the tensor to retrieve.

        Returns:
            torch.Tensor: The retrieved tensor.

        Raises:
            KeyError: If the tensor does not exist in the safetensor file.
        """
        file_name = weight_map[tensor_name]
        if file_name not in loaded_files:
            file_path = os.path.join(fp8_path, file_name)
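For context, a minimal sketch of the conversion loop these docstrings describe (a hedged reconstruction, not the verbatim code of this commit; `weight_map`, `loaded_files`, and `get_tensor` come from the surrounding function):

    # Sketch only: iterate every shard, dequantize FP8 weights with their block-wise
    # scale_inv tensors, and save BF16 shards to bf16_path.
    for file_name in sorted(set(weight_map.values())):
        state = load_file(os.path.join(fp8_path, file_name))
        new_state = {}
        for name, tensor in state.items():
            if name.endswith("_scale_inv"):
                continue                      # scales are consumed below, not copied
            if tensor.element_size() == 1:    # FP8 weight
                scale_inv = get_tensor(name + "_scale_inv")   # KeyError if missing
                new_state[name] = weight_dequant(tensor, scale_inv)
            else:
                new_state[name] = tensor      # already BF16/FP32, copy as-is
        save_file(new_state, os.path.join(bf16_path, file_name))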
@@ -12,6 +12,16 @@ from model import Transformer, ModelArgs


def sample(logits, temperature: float = 1.0):
    """
    Samples a token from the logits using temperature scaling.

    Args:
        logits (torch.Tensor): The logits tensor for token predictions.
        temperature (float, optional): Temperature for scaling logits. Defaults to 1.0.

    Returns:
        torch.Tensor: The sampled token.
    """
    logits = logits / max(temperature, 1e-5)
    probs = torch.softmax(logits, dim=-1)
    return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
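The one-liner returned by `sample` is the exponential-race (Gumbel-max style) trick: dividing the probabilities by i.i.d. Exp(1) noise and taking the argmax selects index i with probability p_i. A plain-PyTorch equivalent for reference (a sketch, assuming 2-D logits; not part of the commit):

    probs = torch.softmax(logits / max(temperature, 1e-5), dim=-1)
    next_token = torch.multinomial(probs, num_samples=1).squeeze(-1)  # same distribution as sample()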
@@ -25,6 +35,19 @@ def generate(
    eos_id: int,
    temperature: float = 1.0
) -> List[List[int]]:
    """
    Generates new tokens based on the given prompt tokens using the specified model.

    Args:
        model (Transformer): The transformer model used for token generation.
        prompt_tokens (List[List[int]]): A list of lists containing the prompt tokens for each sequence.
        max_new_tokens (int): The maximum number of new tokens to generate.
        eos_id (int): The end-of-sequence token ID.
        temperature (float, optional): The temperature value for sampling. Defaults to 1.0.

    Returns:
        List[List[int]]: A list of lists containing the generated tokens for each sequence.
    """
    prompt_lens = [len(t) for t in prompt_tokens]
    assert max(prompt_lens) <= model.max_seq_len
    total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
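The hunk stops before the decoding loop itself. A hedged sketch of how generation typically proceeds from here (a reconstruction based on the docstring, not the literal lines of this diff):

    # Fill a (batch, total_len) buffer with the prompts, then decode one position at a
    # time, feeding only the new slice to the model and advancing prev_pos.
    tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda")
    for i, t in enumerate(prompt_tokens):
        tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
    prev_pos = 0
    for cur_pos in range(min(prompt_lens), total_len):
        logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
        next_token = sample(logits, temperature) if temperature > 0 else logits.argmax(dim=-1)
        # keep prompt tokens where they already exist, otherwise take the sampled token
        next_token = torch.where(tokens[:, cur_pos] != -1, tokens[:, cur_pos], next_token)
        tokens[:, cur_pos] = next_token
        prev_pos = cur_pos
        # decoding stops early once every sequence has produced eos_id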
@@ -63,6 +86,17 @@ def main(
    max_new_tokens: int = 100,
    temperature: float = 1.0,
) -> None:
    """
    Main function to load the model and perform interactive or batch text generation.

    Args:
        ckpt_path (str): Path to the model checkpoint directory.
        config (str): Path to the model configuration file.
        input_file (str, optional): Path to a file containing input prompts. Defaults to "".
        interactive (bool, optional): Whether to run in interactive mode. Defaults to True.
        max_new_tokens (int, optional): Maximum number of new tokens to generate. Defaults to 100.
        temperature (float, optional): Temperature for sampling. Defaults to 1.0.
    """
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    rank = int(os.getenv("RANK", "0"))
    local_rank = int(os.getenv("LOCAL_RANK", "0"))
@@ -125,6 +159,20 @@ def main(


if __name__ == "__main__":
    """
    Command-line interface for distributed text generation.

    Arguments:
        --ckpt-path (str): Path to the model checkpoint directory.
        --config (str): Path to the model configuration file.
        --input-file (str, optional): File containing prompts for batch processing.
        --interactive (bool, optional): Enable interactive mode for generating text.
        --max-new-tokens (int, optional): Maximum number of new tokens to generate. Defaults to 200.
        --temperature (float, optional): Temperature for sampling. Defaults to 0.2.

    Raises:
        AssertionError: If neither input-file nor interactive mode is specified.
    """
    parser = ArgumentParser()
    parser.add_argument("--ckpt-path", type=str, required=True)
    parser.add_argument("--config", type=str, required=True)
@@ -8,6 +8,18 @@ from triton import Config

@triton.jit
def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
    """
    Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`.

    Args:
        x_ptr (triton.Pointer): Pointer to the input tensor.
        y_ptr (triton.Pointer): Pointer to the output tensor where quantized values will be stored.
        s_ptr (triton.Pointer): Pointer to the output tensor where scaling factors will be stored.
        BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each program instance.

    Returns:
        None
    """
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    x = tl.load(x_ptr + offs).to(tl.float32)
@@ -19,6 +31,18 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):


def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantizes the input tensor `x` using block-wise quantization.

    Args:
        x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
        block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
            - The quantized tensor with dtype `torch.float8_e4m3fn`.
            - A tensor of scaling factors with dtype `torch.float32`.
    """
    assert x.is_contiguous()
    assert x.size(-1) % block_size == 0
    y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
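For context, a rough PyTorch equivalent of what the Triton kernel computes per 128-element block (a hedged sketch, not part of the commit; 448.0 is the largest normal value of `float8_e4m3fn`):

    def act_quant_ref(x: torch.Tensor, block_size: int = 128):
        blocks = x.float().view(-1, block_size)
        s = blocks.abs().amax(dim=-1, keepdim=True) / 448.0        # one scale per block
        y = (blocks / s).to(torch.float8_e4m3fn).view_as(x)        # quantized activations
        return y, s.view(*x.size()[:-1], -1)                       # float32 scales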
@@ -30,6 +54,20 @@ def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:

@triton.jit
def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
    """
    Dequantizes weights using the provided scaling factors and stores the result.

    Args:
        x_ptr (tl.pointer): Pointer to the quantized weights.
        s_ptr (tl.pointer): Pointer to the scaling factors.
        y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights.
        M (int): Number of rows in the weight matrix.
        N (int): Number of columns in the weight matrix.
        BLOCK_SIZE (tl.constexpr): Size of the block for tiling.

    Returns:
        None
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    n = tl.cdiv(N, BLOCK_SIZE)
@@ -44,6 +82,20 @@ def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):


def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    """
    Dequantizes the given weight tensor using the provided scale tensor.

    Args:
        x (torch.Tensor): The quantized weight tensor of shape (M, N).
        s (torch.Tensor): The scale tensor, with one value per (block_size, block_size) tile of `x`.
        block_size (int, optional): The block size to use for dequantization. Defaults to 128.

    Returns:
        torch.Tensor: The dequantized weight tensor of the same shape as `x`.

    Raises:
        AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2.
    """
    assert x.is_contiguous() and s.is_contiguous()
    assert x.dim() == 2 and s.dim() == 2
    M, N = x.size()
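A small usage sketch (shapes are illustrative assumptions, not taken from the diff); the scale tensor carries one entry per 128x128 tile of the weight:

    w_fp8 = torch.randn(4096, 4096, device="cuda").to(torch.float8_e4m3fn)
    s = torch.rand(4096 // 128, 4096 // 128, device="cuda", dtype=torch.float32)
    w = weight_dequant(w_fp8, s)   # (4096, 4096), in the default dtype (BF16 in fp8_cast_bf16.py)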
@@ -66,6 +118,25 @@ def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
                    BLOCK_SIZE_M: tl.constexpr,
                    BLOCK_SIZE_N: tl.constexpr,
                    BLOCK_SIZE_K: tl.constexpr):
    """
    Performs a matrix multiplication operation on FP8 matrices with scaling factors.

    Args:
        a_ptr (tl.tensor): Pointer to the first input matrix A.
        b_ptr (tl.tensor): Pointer to the second input matrix B.
        c_ptr (tl.tensor): Pointer to the output matrix C.
        a_s_ptr (tl.tensor): Pointer to the scaling factors for matrix A.
        b_s_ptr (tl.tensor): Pointer to the scaling factors for matrix B.
        M (int): Number of rows in matrix A and C.
        N (tl.constexpr): Number of columns in matrix B and C.
        K (tl.constexpr): Number of columns in matrix A and rows in matrix B.
        BLOCK_SIZE_M (tl.constexpr): Block size for the M dimension.
        BLOCK_SIZE_N (tl.constexpr): Block size for the N dimension.
        BLOCK_SIZE_K (tl.constexpr): Block size for the K dimension.

    Returns:
        None
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    k = tl.cdiv(K, BLOCK_SIZE_K)
@@ -97,6 +168,18 @@ def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,


def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor):
    """
    Perform a matrix multiplication using FP8 precision.

    Args:
        a (torch.Tensor): The first input matrix, must be contiguous.
        a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous.
        b (torch.Tensor): The second input matrix, must be contiguous.
        b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous.

    Returns:
        torch.Tensor: The result of the matrix multiplication.
    """
    assert a.is_contiguous() and b.is_contiguous()
    assert a_s.is_contiguous() and b_s.is_contiguous()
    K = a.size(-1)
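A hedged usage sketch mirroring the FP8 path of `linear` in model.py: the activation is quantized per 128-element block on the fly, while the weight reuses its stored block-wise scales (`w_fp8` and `w_scale` are assumed names for a quantized weight and its scales):

    x_q, x_s = act_quant(x, block_size=128)    # FP8 activation plus per-block scales
    y = fp8_gemm(x_q, x_s, w_fp8, w_scale)     # result in the default (higher-precision) dtype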
@@ -18,6 +18,39 @@ attn_impl: Literal["naive", "absorb"] = "absorb"

@dataclass
class ModelArgs:
    """
    Data class for defining model arguments and hyperparameters.

    Attributes:
        max_batch_size (int): Maximum batch size.
        max_seq_len (int): Maximum sequence length.
        dtype (Literal["bf16", "fp8"]): Data type for computations.
        vocab_size (int): Vocabulary size.
        dim (int): Model dimension.
        inter_dim (int): Intermediate dimension for MLP layers.
        moe_inter_dim (int): Intermediate dimension for MoE layers.
        n_layers (int): Number of transformer layers.
        n_dense_layers (int): Number of dense layers in the model.
        n_heads (int): Number of attention heads.
        n_routed_experts (int): Number of routed experts for MoE layers.
        n_shared_experts (int): Number of shared experts for MoE layers.
        n_activated_experts (int): Number of activated experts in MoE layers.
        n_expert_groups (int): Number of expert groups.
        n_limited_groups (int): Number of limited groups for MoE routing.
        score_func (Literal["softmax", "sigmoid"]): Scoring function for MoE routing.
        route_scale (float): Scaling factor for routing scores.
        q_lora_rank (int): LoRA rank for query projections.
        kv_lora_rank (int): LoRA rank for key-value projections.
        qk_nope_head_dim (int): Dimension for query-key projections without positional embeddings.
        qk_rope_head_dim (int): Dimension for query-key projections with rotary embeddings.
        v_head_dim (int): Dimension for value projections.
        original_seq_len (int): Original sequence length.
        rope_theta (float): Base for rotary positional encoding.
        rope_factor (float): Scaling factor for extended sequence lengths.
        beta_fast (int): Fast beta correction factor.
        beta_slow (int): Slow beta correction factor.
        mscale (float): Scaling factor for extended attention.
    """
    max_batch_size: int = 8
    max_seq_len: int = 4096 * 4
    dtype: Literal["bf16", "fp8"] = "bf16"
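These hyperparameters are normally populated from a JSON config rather than set by hand. A minimal, hypothetical example (the file name is illustrative, not taken from this diff):

    import json
    with open("configs/config_671B.json") as f:
        args = ModelArgs(**json.load(f))
    print(args.dim, args.n_routed_experts, args.n_activated_experts)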
@@ -52,6 +85,13 @@ class ModelArgs:


class ParallelEmbedding(nn.Module):
    """
    Embedding layer with parallelism support across distributed processes.

    Args:
        vocab_size (int): Vocabulary size.
        dim (int): Embedding dimension.
    """
    def __init__(self, vocab_size: int, dim: int):
        super().__init__()
        self.vocab_size = vocab_size
@@ -63,6 +103,18 @@ class ParallelEmbedding(nn.Module):
        self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for parallel embedding layer.

        Args:
            x (torch.Tensor): Input tensor containing token indices.

        Returns:
            torch.Tensor: Embedded representations.

        Raises:
            ValueError: If `world_size` is not defined.
        """
        if world_size > 1:
            mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
            x = x - self.vocab_start_idx
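The hunk cuts off mid-forward. A hedged sketch of how the vocabulary-parallel lookup usually completes (a reconstruction for context, not the literal diff lines): out-of-shard token ids are masked, embedded locally, zeroed, then summed across ranks.

    def parallel_embedding_forward_sketch(self, x: torch.Tensor) -> torch.Tensor:
        if world_size > 1:
            mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
            x = torch.where(mask, torch.zeros_like(x), x - self.vocab_start_idx)
        y = F.embedding(x, self.weight)
        if world_size > 1:
            y[mask] = 0                  # foreign ids contribute zeros locally
            dist.all_reduce(y)           # sum partial embeddings across ranks
        return y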
@@ -75,6 +127,27 @@ class ParallelEmbedding(nn.Module):


def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Applies a linear transformation to the incoming data: y = xA^T + b.
    This function supports specialized implementations based on quantization
    and tensor formats.

    Args:
        x (torch.Tensor): The input tensor.
        weight (torch.Tensor): The weight tensor. It may be quantized and
            requires dequantization for certain cases.
        bias (Optional[torch.Tensor]): The bias tensor to be added. Default is None.

    Returns:
        torch.Tensor: The result of the linear transformation, which may involve
        quantization-aware computations depending on the input parameters.

    Notes:
        - If `weight` is not quantized (i.e. `weight.element_size() > 1`), a standard
          `F.linear` call is used.
        - If `gemm_impl == "bf16"`, the weight is dequantized and a `bf16` GEMM is applied.
        - Otherwise, `x` is quantized and `fp8_gemm` is used for the computation.
    """
    if weight.element_size() > 1:
        return F.linear(x, weight, bias)
    elif gemm_impl == "bf16":
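A hedged reconstruction of the branches the hunk truncates, following the Notes above (`gemm_impl`, `block_size`, and `weight.scale` are names from the full module, assumed here for illustration):

    def linear_sketch(x, weight, bias=None):
        if weight.element_size() > 1:                   # unquantized (BF16/FP32) weight
            return F.linear(x, weight, bias)
        if gemm_impl == "bf16":                         # dequantize once, then BF16 GEMM
            return F.linear(x, weight_dequant(weight, weight.scale), bias)
        x_q, s = act_quant(x, block_size)               # FP8 GEMM with per-block scales
        y = fp8_gemm(x_q, s, weight, weight.scale)
        return y if bias is None else y + bias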
@@ -89,6 +162,15 @@ def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:


class Linear(nn.Module):
    """
    Custom linear layer with support for quantized weights and optional bias.

    Args:
        in_features (int): Number of input features.
        out_features (int): Number of output features.
        bias (bool): Whether to include a bias term. Defaults to False.
        dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
    """
    dtype = torch.bfloat16

    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
@@ -108,27 +190,72 @@ class Linear(nn.Module):
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the custom linear layer.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Transformed tensor after linear computation.
        """
        return linear(x, self.weight, self.bias)


class ColumnParallelLinear(Linear):
    """
    Linear layer with column parallelism, splitting output features across distributed processes.

    Args:
        in_features (int): Number of input features.
        out_features (int): Total number of output features.
        bias (bool): Whether to include a bias term. Defaults to False.
        dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
    """
    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
        assert out_features % world_size == 0
        self.part_out_features = out_features // world_size
        super().__init__(in_features, self.part_out_features, bias, dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for column parallel linear layer.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Transformed tensor with column-parallel computation.
        """
        y = linear(x, self.weight, self.bias)
        return y


class RowParallelLinear(Linear):
    """
    Linear layer with row parallelism, splitting input features across distributed processes.

    Args:
        in_features (int): Total number of input features.
        out_features (int): Number of output features.
        bias (bool): Whether to include a bias term. Defaults to False.
        dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
    """
    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
        assert in_features % world_size == 0
        self.part_in_features = in_features // world_size
        super().__init__(self.part_in_features, out_features, bias, dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for row parallel linear layer.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Transformed tensor with row-parallel computation.
        """
        y = linear(x, self.weight)
        if world_size > 1:
            dist.all_reduce(y)
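The two layers are meant to be composed: a column-parallel projection followed by a row-parallel projection needs only the single all-reduce at the end of the row-parallel forward, which is how the MLP and the MoE experts later in this diff use them. Illustrative sizes only (not from the diff):

    w_up   = ColumnParallelLinear(2048, 8192)   # each rank holds 8192 // world_size output columns
    w_down = RowParallelLinear(8192, 2048)      # each rank consumes its local slice of the 8192 inputs
    # y = w_down(F.silu(w_up(x)))  -> partial sums are combined by dist.all_reduce inside w_down.forward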
@@ -138,6 +265,13 @@ class RowParallelLinear(Linear):


class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization (RMSNorm).

    Args:
        dim (int): Dimension of the input tensor.
        eps (float): Epsilon value for numerical stability. Defaults to 1e-6.
    """
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
@@ -145,10 +279,28 @@ class RMSNorm(nn.Module):
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor):
        """
        Forward pass for RMSNorm.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Normalized tensor with the same shape as input.
        """
        return F.rms_norm(x, (self.dim,), self.weight, self.eps)


def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
    """
    Precomputes frequency-based complex exponential values for rotary positional embeddings.

    Args:
        args (ModelArgs): Model arguments containing positional embedding parameters.

    Returns:
        torch.Tensor: Precomputed complex exponential values for positional embeddings.
    """
    dim = args.qk_rope_head_dim
    seqlen = args.max_seq_len
    beta_fast = args.beta_fast
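For reference, the `F.rms_norm(x, (self.dim,), self.weight, self.eps)` call above is equivalent to the usual RMSNorm formula; a hedged reference implementation:

    def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        rms = torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + eps)   # 1 / RMS(x)
        return (x.float() * rms).type_as(x) * weight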
@@ -157,14 +309,51 @@ def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
    factor = args.rope_factor

    def find_correction_dim(num_rotations, dim, base, max_seq_len):
        """
        Computes the correction dimension for a given number of rotations in the rotary positional embedding.

        Args:
            num_rotations (float): Number of rotations to compute the correction for.
            dim (int): Dimensionality of the embedding space.
            base (float): Base value for the exponential computation.
            max_seq_len (int): Maximum sequence length.

        Returns:
            float: The correction dimension based on the input parameters.
        """
        return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))

    def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
        """
        Computes the range of correction dimensions for rotary positional embeddings.

        Args:
            low_rot (float): Lower bound for the number of rotations.
            high_rot (float): Upper bound for the number of rotations.
            dim (int): Dimensionality of the embedding space.
            base (float): Base value for the exponential computation.
            max_seq_len (int): Maximum sequence length.

        Returns:
            Tuple[int, int]: The range of correction dimensions (low, high), clamped to valid indices.
        """
        low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
        high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
        return max(low, 0), min(high, dim-1)

    def linear_ramp_factor(min, max, dim):
        """
        Computes a linear ramp function used to smooth values between a minimum and maximum range.

        Args:
            min (float): Minimum value for the ramp function.
            max (float): Maximum value for the ramp function.
            dim (int): Dimensionality of the ramp tensor.

        Returns:
            torch.Tensor: A tensor of shape (dim,) with values linearly interpolated between 0 and 1,
                clamped to the range [0, 1].
        """
        if min == max:
            max += 0.001
        linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
@@ -184,6 +373,16 @@ def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:


def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """
    Applies rotary positional embeddings to the input tensor.

    Args:
        x (torch.Tensor): Input tensor with positional embeddings to be applied.
        freqs_cis (torch.Tensor): Precomputed complex exponential values for positional embeddings.

    Returns:
        torch.Tensor: Tensor with rotary embeddings applied.
    """
    dtype = x.dtype
    x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2))
    freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
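The hunk ends just before the rotation itself. For context, a hedged sketch of the complete operation (pairs of features are treated as complex numbers, rotated by the precomputed exponentials, then flattened back; not the literal diff lines):

    def apply_rotary_emb_sketch(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        dtype = x.dtype
        x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2))
        freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
        y = torch.view_as_real(x * freqs_cis).flatten(3)   # rotate, then back to interleaved reals
        return y.to(dtype)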
@@ -192,6 +391,21 @@ def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:


class MLA(nn.Module):
    """
    Multi-head Latent Attention (MLA) layer.

    Attributes:
        dim (int): Dimensionality of the input features.
        n_heads (int): Number of attention heads.
        n_local_heads (int): Number of local attention heads for distributed systems.
        q_lora_rank (int): Rank for low-rank query projection.
        kv_lora_rank (int): Rank for low-rank key/value projection.
        qk_nope_head_dim (int): Dimensionality of non-positional query/key projections.
        qk_rope_head_dim (int): Dimensionality of rotary-positional query/key projections.
        qk_head_dim (int): Total dimensionality of query/key projections.
        v_head_dim (int): Dimensionality of value projections.
        softmax_scale (float): Scaling factor for softmax in attention computation.
    """
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.dim = args.dim
@@ -227,6 +441,18 @@ class MLA(nn.Module):
        self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        """
        Forward pass for the Multi-head Latent Attention (MLA) layer.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
            start_pos (int): Starting position in the sequence for caching.
            freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
            mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions from attention.

        Returns:
            torch.Tensor: Output tensor with the same shape as the input.
        """
        bsz, seqlen, _ = x.size()
        end_pos = start_pos + seqlen
        if self.q_lora_rank == 0:
@@ -269,18 +495,61 @@ class MLA(nn.Module):


class MLP(nn.Module):
    """
    Multi-Layer Perceptron (MLP) used as a feed-forward layer.

    Attributes:
        w1 (nn.Module): Linear layer for input-to-hidden transformation.
        w2 (nn.Module): Linear layer for hidden-to-output transformation.
        w3 (nn.Module): Additional linear layer for feature transformation.
    """
    def __init__(self, dim: int, inter_dim: int):
        """
        Initializes the MLP layer.

        Args:
            dim (int): Input and output dimensionality.
            inter_dim (int): Hidden layer dimensionality.
        """
        super().__init__()
        self.w1 = ColumnParallelLinear(dim, inter_dim)
        self.w2 = RowParallelLinear(inter_dim, dim)
        self.w3 = ColumnParallelLinear(dim, inter_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the MLP layer.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor after MLP computation.
        """
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class Gate(nn.Module):
    """
    Gating mechanism for routing inputs in a mixture-of-experts (MoE) model.

    Attributes:
        dim (int): Dimensionality of input features.
        topk (int): Number of top experts activated for each input.
        n_groups (int): Number of groups for routing.
        topk_groups (int): Number of groups to route inputs to.
        score_func (str): Scoring function ('softmax' or 'sigmoid').
        route_scale (float): Scaling factor for routing weights.
        weight (torch.nn.Parameter): Learnable weights for the gate.
        bias (Optional[torch.nn.Parameter]): Optional bias term for the gate.
    """
    def __init__(self, args: ModelArgs):
        """
        Initializes the Gate module.

        Args:
            args (ModelArgs): Model arguments containing gating parameters.
        """
        super().__init__()
        self.dim = args.dim
        self.topk = args.n_activated_experts
@@ -292,6 +561,15 @@ class Gate(nn.Module):
        self.bias = nn.Parameter(torch.empty(args.n_routed_experts)) if self.dim == 7168 else None

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass for the gating mechanism.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Routing weights and selected expert indices.
        """
        scores = linear(x, self.weight)
        if self.score_func == "softmax":
            scores = scores.softmax(dim=-1, dtype=torch.float32)
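A hedged sketch of how the routing typically finishes after the scoring shown above (group-limited top-k selection followed by rescaling; `original_scores` and the group masking are assumptions based on the attribute list in the class docstring, not the literal diff lines):

    original_scores = scores                       # keep pre-bias scores for the final weights
    if self.bias is not None:
        scores = scores + self.bias
    # group-limited routing would keep only the best `topk_groups` expert groups here
    indices = torch.topk(scores, self.topk, dim=-1)[1]
    weights = original_scores.gather(1, indices)
    if self.score_func == "sigmoid":
        weights /= weights.sum(dim=-1, keepdim=True)
    weights *= self.route_scale
    return weights.type_as(x), indices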
@@ -318,18 +596,60 @@ class Gate(nn.Module):


class Expert(nn.Module):
    """
    Expert layer for Mixture-of-Experts (MoE) models.

    Attributes:
        w1 (nn.Module): Linear layer for input-to-hidden transformation.
        w2 (nn.Module): Linear layer for hidden-to-output transformation.
        w3 (nn.Module): Additional linear layer for feature transformation.
    """
    def __init__(self, dim: int, inter_dim: int):
        """
        Initializes the Expert layer.

        Args:
            dim (int): Input and output dimensionality.
            inter_dim (int): Hidden layer dimensionality.
        """
        super().__init__()
        self.w1 = Linear(dim, inter_dim)
        self.w2 = Linear(inter_dim, dim)
        self.w3 = Linear(dim, inter_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the Expert layer.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor after expert computation.
        """
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class MoE(nn.Module):
    """
    Mixture-of-Experts (MoE) module.

    Attributes:
        dim (int): Dimensionality of input features.
        n_routed_experts (int): Total number of experts in the model.
        n_local_experts (int): Number of experts handled locally in distributed systems.
        n_activated_experts (int): Number of experts activated for each input.
        gate (nn.Module): Gating mechanism to route inputs to experts.
        experts (nn.ModuleList): List of expert modules.
        shared_experts (nn.Module): Shared experts applied to all inputs.
    """
    def __init__(self, args: ModelArgs):
        """
        Initializes the MoE module.

        Args:
            args (ModelArgs): Model arguments containing MoE parameters.
        """
        super().__init__()
        self.dim = args.dim
        assert args.n_routed_experts % world_size == 0
@@ -344,6 +664,15 @@ class MoE(nn.Module):
        self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the MoE module.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor after expert routing and computation.
        """
        shape = x.size()
        x = x.view(-1, self.dim)
        weights, indices = self.gate(x)
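A hedged sketch of the expert dispatch that follows: tokens are grouped per expert, only locally hosted experts are executed, and the shared experts are applied to every token (`experts_start_idx` and `experts_end_idx` are assumed attributes of the full class, not shown in this diff):

    y = torch.zeros_like(x)
    counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
    for i in range(self.experts_start_idx, self.experts_end_idx):
        if counts[i] == 0:
            continue                              # no token routed to this local expert
        top_x, top_k = torch.where(indices == i)
        y[top_x] += self.experts[i](x[top_x]) * weights[top_x, top_k, None]
    z = self.shared_experts(x)
    if world_size > 1:
        dist.all_reduce(y)                        # gather contributions from remote experts
    return (y + z).view(shape)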
@@ -362,7 +691,23 @@ class MoE(nn.Module):


class Block(nn.Module):
    """
    Transformer block combining attention and feed-forward layers.

    Attributes:
        attn (nn.Module): Attention layer (MLA).
        ffn (nn.Module): Feed-forward network (MLP or MoE).
        attn_norm (nn.Module): Layer normalization for attention.
        ffn_norm (nn.Module): Layer normalization for feed-forward network.
    """
    def __init__(self, layer_id: int, args: ModelArgs):
        """
        Initializes the Transformer block.

        Args:
            layer_id (int): Layer index in the transformer.
            args (ModelArgs): Model arguments containing block parameters.
        """
        super().__init__()
        self.attn = MLA(args)
        self.ffn = MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args)
@@ -370,13 +715,42 @@
        self.ffn_norm = RMSNorm(args.dim)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
        """
        Forward pass for the Transformer block.

        Args:
            x (torch.Tensor): Input tensor.
            start_pos (int): Starting position in the sequence.
            freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
            mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions from attention.

        Returns:
            torch.Tensor: Output tensor after block computation.
        """
        x = x + self.attn(self.attn_norm(x), start_pos, freqs_cis, mask)
        x = x + self.ffn(self.ffn_norm(x))
        return x


class Transformer(nn.Module):
    """
    Transformer model with positional embeddings, multiple layers, and output projection.

    Attributes:
        max_seq_len (int): Maximum sequence length for the transformer.
        embed (nn.Module): Embedding layer for input tokens.
        layers (torch.nn.ModuleList): List of transformer blocks.
        norm (nn.Module): Layer normalization applied after all blocks.
        head (nn.Module): Output projection layer mapping to vocabulary size.
        freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
    """
    def __init__(self, args: ModelArgs):
        """
        Initializes the Transformer model.

        Args:
            args (ModelArgs): Model arguments containing transformer parameters.
        """
        global world_size, rank
        world_size = dist.get_world_size() if dist.is_initialized() else 1
        rank = dist.get_rank() if dist.is_initialized() else 0
@@ -393,6 +767,16 @@ class Transformer(nn.Module):

    @torch.inference_mode()
    def forward(self, tokens: torch.Tensor, start_pos: int = 0):
        """
        Forward pass for the Transformer model.

        Args:
            tokens (torch.Tensor): Input tensor of token IDs with shape (batch_size, seq_len).
            start_pos (int, optional): Starting position in the sequence for rotary embeddings. Defaults to 0.

        Returns:
            torch.Tensor: Logits tensor of shape (batch_size, vocab_size).
        """
        seqlen = tokens.size(1)
        h = self.embed(tokens)
        freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]