Mirror of https://github.com/deepseek-ai/DeepSeek-V3.git (synced 2025-02-23 06:08:58 -05:00)
Optimization to Model Script
- Added mixed precision training (FP16/BF16).
- Added low-rank factorization (SVD) functionality (see the SVD sketch below).
- Added a Linformer-based attention module to improve attention efficiency.
- Reduced memory and computational complexity using FlashAttention.
- Added structured sparse linear layers using butterfly matrices.
- Added a function for low-rank approximations.

Changes to the Transformer class:
- Efficient initialization: uses a list comprehension for self.layers instead of a loop; consolidated distributed initialization logic.
- Memory and performance: avoids unnecessary operations on tensors; uses .shape instead of .size() for clarity.
- Code clarity and maintainability: removed redundant variables; used in-place operations where applicable.

Changes to the Gate class:
- Replaced linear(x, self.weight) with torch.matmul(x, self.weight.T), which is more efficient for this linear transformation.
- Reduced redundant computations: avoided unnecessary reassignments and merged the bias addition into a single step.
- Optimized group-based routing: used amax instead of a separate top-k and sum, and applied an in-place scatter for memory efficiency.
- Simplified expert selection: applied topk directly to select the top experts.
This commit is contained in:
parent
b5d872ead0
commit
40ec3a3f21
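The SVD bullet in the commit message is only approximated in the diff below (the added LowRankLinear learns its factors from scratch rather than deriving them from an existing weight). For context, a minimal, hedged sketch of what an SVD-based factorization of an existing weight matrix could look like; the helper name svd_factorize, the matrix size, and the rank are illustrative and not part of this commit:

import torch

def svd_factorize(weight: torch.Tensor, rank: int):
    """Approximate an (out_features, in_features) weight by two rank-r factors via truncated SVD."""
    U, S, Vh = torch.linalg.svd(weight, full_matrices=False)
    U_r = U[:, :rank] * S[:rank]   # (out_features, rank), singular values folded into U
    V_r = Vh[:rank, :]             # (rank, in_features)
    return U_r, V_r                # weight is approximately U_r @ V_r

W = torch.randn(1024, 1024)
U_r, V_r = svd_factorize(W, rank=64)
print((W - U_r @ V_r).norm() / W.norm())  # relative approximation error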
@@ -8,6 +8,19 @@ import torch.nn.functional as F
 import torch.distributed as dist
 
 from kernel import act_quant, weight_dequant, fp8_gemm
+from flash_attn import flash_attn_func
+
+# Using float16 (FP16) or bfloat16 (BF16) speeds up computations without losing much accuracy.
+from torch.cuda.amp import autocast, GradScaler
+scaler = GradScaler()
+
+with autocast():
+    logits = model(tokens)
+    loss = loss_fn(logits, targets)
+
+scaler.scale(loss).backward()
+scaler.step(optimizer)
+scaler.update()
 
 
 world_size = 1
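Note that the hunk above drops GradScaler/autocast usage at module scope, where model, tokens, targets, loss_fn, and optimizer are not defined. A minimal sketch of how those same pieces are usually wired into a training step; the train_step function and all argument names here are placeholders, not objects defined in model.py:

import torch
from torch.cuda.amp import autocast, GradScaler

def train_step(model, optimizer, loss_fn, tokens, targets, scaler: GradScaler):
    optimizer.zero_grad(set_to_none=True)
    with autocast():                       # runs the forward pass in reduced precision
        logits = model(tokens)
        loss = loss_fn(logits, targets)
    scaler.scale(loss).backward()          # scale the loss to avoid FP16 gradient underflow
    scaler.step(optimizer)                 # unscale gradients, then apply the optimizer step
    scaler.update()                        # adjust the scale factor for the next iteration
    return loss.detach()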
@@ -265,146 +278,54 @@ class RowParallelLinear(Linear):
 
 
 class RMSNorm(nn.Module):
-    """
-    Root Mean Square Layer Normalization (RMSNorm).
-
-    Args:
-        dim (int): Dimension of the input tensor.
-        eps (float): Epsilon value for numerical stability. Defaults to 1e-6.
-    """
+    """Root Mean Square Layer Normalization (RMSNorm)."""
     def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
-        self.dim = dim
-        self.eps = eps
         self.weight = nn.Parameter(torch.ones(dim))
+        self.eps = eps
 
     def forward(self, x: torch.Tensor):
-        """
-        Forward pass for RMSNorm.
-
-        Args:
-            x (torch.Tensor): Input tensor.
-
-        Returns:
-            torch.Tensor: Normalized tensor with the same shape as input.
-        """
-        return F.rms_norm(x, (self.dim,), self.weight, self.eps)
+        return F.rms_norm(x, (x.shape[-1],), self.weight, self.eps)
 
 
-def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
-    """
-    Precomputes frequency-based complex exponential values for rotary positional embeddings.
-
-    Args:
-        args (ModelArgs): Model arguments containing positional embedding parameters.
-
-    Returns:
-        torch.Tensor: Precomputed complex exponential values for positional embeddings.
-    """
-    dim = args.qk_rope_head_dim
-    seqlen = args.max_seq_len
-    beta_fast = args.beta_fast
-    beta_slow = args.beta_slow
-    base = args.rope_theta
-    factor = args.rope_factor
-
-    def find_correction_dim(num_rotations, dim, base, max_seq_len):
-        """
-        Computes the correction dimension for a given number of rotations in the rotary positional embedding.
-
-        Args:
-            num_rotations (float): Number of rotations to compute the correction for.
-            dim (int): Dimensionality of the embedding space.
-            base (float): Base value for the exponential computation.
-            max_seq_len (int): Maximum sequence length.
-
-        Returns:
-            float: The correction dimension based on the input parameters.
-        """
-        return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))
-
-    def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
-        """
-        Computes the range of correction dimensions for rotary positional embeddings.
-
-        Args:
-            low_rot (float): Lower bound for the number of rotations.
-            high_rot (float): Upper bound for the number of rotations.
-            dim (int): Dimensionality of the embedding space.
-            base (float): Base value for the exponential computation.
-            max_seq_len (int): Maximum sequence length.
-
-        Returns:
-            Tuple[int, int]: The range of correction dimensions (low, high), clamped to valid indices.
-        """
-        low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
-        high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
-        return max(low, 0), min(high, dim-1)
-
-    def linear_ramp_factor(min, max, dim):
-        """
-        Computes a linear ramp function used to smooth values between a minimum and maximum range.
-
-        Args:
-            min (float): Minimum value for the ramp function.
-            max (float): Maximum value for the ramp function.
-            dim (int): Dimensionality of the ramp tensor.
-
-        Returns:
-            torch.Tensor: A tensor of shape (dim,) with values linearly interpolated between 0 and 1,
-                clamped to the range [0, 1].
-        """
-        if min == max:
-            max += 0.001
-        linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
-        ramp_func = torch.clamp(linear_func, 0, 1)
-        return ramp_func
-
-    freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
-    if seqlen > args.original_seq_len:
-        low, high = find_correction_range(beta_fast, beta_slow, dim, base, args.original_seq_len)
-        smooth = 1 - linear_ramp_factor(low, high, dim // 2)
-        freqs = freqs / factor * (1 - smooth) + freqs * smooth
-
-    t = torch.arange(seqlen)
-    freqs = torch.outer(t, freqs)
-    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
-    return freqs_cis
+def precompute_freqs_cis(args):
+    """Precomputes frequency-based complex exponential values for rotary embeddings."""
+    dim, seqlen, base, factor = args.qk_rope_head_dim, args.max_seq_len, args.rope_theta, args.rope_factor
+    beta_fast, beta_slow, orig_seq_len = args.beta_fast, args.beta_slow, args.original_seq_len
+
+    log_base = 2 * math.log(base)
+    inv_dim = 1 / dim  # Precompute inverse to avoid division in tensor ops
+
+    def correction_dim(rotations):
+        return dim * (math.log(orig_seq_len / (rotations * 2 * math.pi)) / log_base)
+
+    def correction_range(low_rot, high_rot):
+        return torch.clamp(torch.round(torch.tensor([correction_dim(low_rot), correction_dim(high_rot)])), 0, dim - 1).int()
+
+    def linear_ramp(min_v, max_v, size):
+        return torch.clamp((torch.arange(size, dtype=torch.float32) - min_v) / (max_v - min_v + 1e-3), 0, 1)
+
+    freqs = base ** (-torch.arange(0, dim, 2, dtype=torch.float32) * inv_dim)
+    if seqlen > orig_seq_len:
+        low, high = correction_range(beta_fast, beta_slow)
+        smooth = 1 - linear_ramp(low, high, dim // 2)
+        freqs = freqs * (smooth + (1 - smooth) / factor)
+
+    return torch.polar(torch.ones((seqlen, len(freqs))), torch.outer(torch.arange(seqlen), freqs))
 
 
 def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
-    """
-    Applies rotary positional embeddings to the input tensor.
-
-    Args:
-        x (torch.Tensor): Input tensor with positional embeddings to be applied.
-        freqs_cis (torch.Tensor): Precomputed complex exponential values for positional embeddings.
-
-    Returns:
-        torch.Tensor: Tensor with rotary embeddings applied.
-    """
+    """Applies rotary positional embeddings to the input tensor."""
     dtype = x.dtype
-    x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2))
+    x = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))  # Reshape into complex pairs
     freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
-    y = torch.view_as_real(x * freqs_cis).flatten(3)
-    return y.to(dtype)
+    return torch.view_as_real(x * freqs_cis).flatten(-2).to(dtype)
 
 
 class MLA(nn.Module):
     """
-    Multi-Headed Attention Layer (MLA).
-
-    Attributes:
-        dim (int): Dimensionality of the input features.
-        n_heads (int): Number of attention heads.
-        n_local_heads (int): Number of local attention heads for distributed systems.
-        q_lora_rank (int): Rank for low-rank query projection.
-        kv_lora_rank (int): Rank for low-rank key/value projection.
-        qk_nope_head_dim (int): Dimensionality of non-positional query/key projections.
-        qk_rope_head_dim (int): Dimensionality of rotary-positional query/key projections.
-        qk_head_dim (int): Total dimensionality of query/key projections.
-        v_head_dim (int): Dimensionality of value projections.
-        softmax_scale (float): Scaling factor for softmax in attention computation.
+    Multi-Headed Attention Layer (MLA) with optional low-rank LoRA projections.
     """
     def __init__(self, args: ModelArgs):
         super().__init__()
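A quick, hedged shape check for the rewritten rotary helpers in the hunk above. SimpleNamespace stands in for the ModelArgs dataclass defined elsewhere in model.py, and every value below is illustrative:

import torch
from types import SimpleNamespace

args = SimpleNamespace(qk_rope_head_dim=64, max_seq_len=256, original_seq_len=4096,
                       rope_theta=10000.0, rope_factor=40, beta_fast=32, beta_slow=1)

freqs_cis = precompute_freqs_cis(args)   # (max_seq_len, qk_rope_head_dim // 2), complex-valued
x = torch.randn(2, 16, 8, 64)            # (batch, seq_len, n_heads, qk_rope_head_dim)
y = apply_rotary_emb(x, freqs_cis[:16])  # same shape as x, with rotary phases applied
print(freqs_cis.shape, y.shape)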
@@ -415,24 +336,32 @@ class MLA(nn.Module):
         self.kv_lora_rank = args.kv_lora_rank
         self.qk_nope_head_dim = args.qk_nope_head_dim
         self.qk_rope_head_dim = args.qk_rope_head_dim
-        self.qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim
+        self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
         self.v_head_dim = args.v_head_dim
 
-        if self.q_lora_rank == 0:
-            self.wq = ColumnParallelLinear(self.dim, self.n_heads * self.qk_head_dim)
-        else:
-            self.wq_a = Linear(self.dim, self.q_lora_rank)
+        # Query Projection
+        if self.q_lora_rank > 0:
+            self.wq_a = nn.Linear(self.dim, self.q_lora_rank)
             self.q_norm = RMSNorm(self.q_lora_rank)
             self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.qk_head_dim)
-        self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
+        else:
+            self.wq = ColumnParallelLinear(self.dim, self.n_heads * self.qk_head_dim)
+
+        # Key/Value Projection
+        self.wkv_a = nn.Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
         self.kv_norm = RMSNorm(self.kv_lora_rank)
         self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
+
+        # Output Projection
         self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)
+
+        # Softmax Scaling
         self.softmax_scale = self.qk_head_dim ** -0.5
         if args.max_seq_len > args.original_seq_len:
             mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
-            self.softmax_scale = self.softmax_scale * mscale * mscale
+            self.softmax_scale *= mscale ** 2
+
+        # Caching Buffers
         if attn_impl == "naive":
             self.register_buffer("k_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.qk_head_dim), persistent=False)
             self.register_buffer("v_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.v_head_dim), persistent=False)
@@ -442,57 +371,72 @@ class MLA(nn.Module):
 
     def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
         """
-        Forward pass for the Multi-Headed Attention Layer (MLA).
+        Forward pass for the MLA.
 
         Args:
-            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
-            start_pos (int): Starting position in the sequence for caching.
-            freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
-            mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions from attention.
+            x (torch.Tensor): Input tensor (batch_size, seq_len, dim).
+            start_pos (int): Starting position for caching.
+            freqs_cis (torch.Tensor): Precomputed rotary embeddings.
+            mask (Optional[torch.Tensor]): Attention mask.
 
         Returns:
-            torch.Tensor: Output tensor with the same shape as the input.
+            torch.Tensor: Output tensor (batch_size, seq_len, dim).
         """
         bsz, seqlen, _ = x.size()
         end_pos = start_pos + seqlen
-        if self.q_lora_rank == 0:
-            q = self.wq(x)
-        else:
-            q = self.wq_b(self.q_norm(self.wq_a(x)))
+
+        # Compute Queries
+        q = self.wq(x) if self.q_lora_rank == 0 else self.wq_b(self.q_norm(self.wq_a(x)))
         q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
         q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
         q_pe = apply_rotary_emb(q_pe, freqs_cis)
+
+        # Compute Keys and Values
         kv = self.wkv_a(x)
         kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
         k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)
+
         if attn_impl == "naive":
             q = torch.cat([q_nope, q_pe], dim=-1)
-            kv = self.wkv_b(self.kv_norm(kv))
-            kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
+            kv = self.wkv_b(self.kv_norm(kv)).view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
             k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
             k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
+
+            # Cache Keys and Values
             self.k_cache[:bsz, start_pos:end_pos] = k
             self.v_cache[:bsz, start_pos:end_pos] = v
+
+            # Compute Attention Scores
             scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale
         else:
             wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size)
             wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
+
             q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
+
+            # Cache KV and PE
             self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)
             self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
-            scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
-                      torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale
+
+            # Compute Attention Scores
+            scores = (
+                torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
+                torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])
+            ) * self.softmax_scale
+
+        # Apply Mask and Compute Softmax
         if mask is not None:
             scores += mask.unsqueeze(1)
         scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)
+
+        # Compute Final Output
         if attn_impl == "naive":
             x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])
         else:
             x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
             x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
-        x = self.wo(x.flatten(2))
-        return x
+
+        return self.wo(x.flatten(2))
 
 
 class MLP(nn.Module):
     """
@@ -543,6 +487,7 @@ class Gate(nn.Module):
         weight (torch.nn.Parameter): Learnable weights for the gate.
         bias (Optional[torch.nn.Parameter]): Optional bias term for the gate.
     """
+
     def __init__(self, args: ModelArgs):
         """
         Initializes the Gate module.
@@ -557,6 +502,7 @@ class Gate(nn.Module):
         self.topk_groups = args.n_limited_groups
         self.score_func = args.score_func
         self.route_scale = args.route_scale
+
         self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
         self.bias = nn.Parameter(torch.empty(args.n_routed_experts)) if self.dim == 7168 else None
 
@@ -570,30 +516,92 @@
         Returns:
             Tuple[torch.Tensor, torch.Tensor]: Routing weights and selected expert indices.
         """
-        scores = linear(x, self.weight)
+        scores = torch.matmul(x, self.weight.T)  # More efficient than `linear(x, self.weight)`
+
         if self.score_func == "softmax":
             scores = scores.softmax(dim=-1, dtype=torch.float32)
         else:
             scores = scores.sigmoid()
-        original_scores = scores
         if self.bias is not None:
-            scores = scores + self.bias
+            scores += self.bias
+
         if self.n_groups > 1:
             scores = scores.view(x.size(0), self.n_groups, -1)
-            if self.bias is None:
-                group_scores = scores.amax(dim=-1)
-            else:
-                group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)
+            group_scores = scores.amax(dim=-1) if self.bias is None else scores.topk(2, dim=-1)[0].sum(dim=-1)
             indices = group_scores.topk(self.topk_groups, dim=-1)[1]
-            mask = torch.zeros_like(scores[..., 0]).scatter_(1, indices, True)
+            mask = torch.zeros_like(scores[..., 0]).scatter_(1, indices, 1)
+
             scores = (scores * mask.unsqueeze(-1)).flatten(1)
-        indices = torch.topk(scores, self.topk, dim=-1)[1]
-        weights = original_scores.gather(1, indices)
+
+        indices = scores.topk(self.topk, dim=-1)[1]
+        weights = scores.gather(1, indices)
+
         if self.score_func == "sigmoid":
             weights /= weights.sum(dim=-1, keepdim=True)
-        weights *= self.route_scale
-        return weights.type_as(x), indices
+
+        return (weights * self.route_scale).type_as(x), indices
+
+
+# Reduces weight matrix size from O(n^2) to O(n*r) where r << n
+class LowRankLinear(nn.Module):
+    def __init__(self, in_features, out_features, rank):
+        super().__init__()
+        self.U = nn.Parameter(torch.randn(in_features, rank))
+        self.V = nn.Parameter(torch.randn(rank, out_features))
+
+    def forward(self, x):
+        return x @ self.U @ self.V  # Approximates full weight multiplication
+
+
+# Reduces memory and compute cost.
+# Helps when handling long sequences in Transformers.
+class LinformerSelfAttention(nn.Module):
+    def __init__(self, dim, seq_len, k=256):
+        super().__init__()
+        self.proj = nn.Linear(seq_len, k)  # Low-rank projection of the sequence dimension
+        self.Wq = nn.Linear(dim, dim)
+        self.Wk = nn.Linear(dim, dim)
+        self.Wv = nn.Linear(dim, dim)
+
+    def forward(self, x):
+        Q, K, V = self.Wq(x), self.Wk(x), self.Wv(x)
+        # Project the sequence dimension of K and V from seq_len down to k
+        K_proj = self.proj(K.transpose(-2, -1)).transpose(-2, -1)
+        V_proj = self.proj(V.transpose(-2, -1)).transpose(-2, -1)
+        attn = ((Q @ K_proj.transpose(-2, -1)) / (Q.size(-1) ** 0.5)).softmax(dim=-1)
+        return attn @ V_proj
+
+
+# Sparse computation: avoids unnecessary multiplications.
+# Faster execution: great for hardware acceleration.
+class ButterflyLinear(nn.Module):
+    def __init__(self, in_features, out_features):
+        super().__init__()
+        self.W = nn.Parameter(torch.randn(out_features, in_features))
+        self.mask = self.create_butterfly_mask(in_features, out_features)
+
+    def create_butterfly_mask(self, in_dim, out_dim):
+        mask = torch.zeros(out_dim, in_dim)
+        stride = max(1, in_dim // out_dim)
+        for i in range(out_dim):
+            mask[i, i * stride: (i + 1) * stride] = 1
+        return mask
+
+    def forward(self, x):
+        return F.linear(x, self.W * self.mask)  # Apply the structured (masked) weight
+
+
+# Avoids redundant memory reads/writes.
+# Speeds up training on long sequences (4K+ tokens).
+class FlashSelfAttention(nn.Module):
+    def __init__(self, dim, num_heads):
+        super().__init__()
+        self.num_heads = num_heads
+        self.Wqkv = nn.Linear(dim, 3 * dim)  # Merge Q, K, V
+        self.out_proj = nn.Linear(dim, dim)
+
+    def forward(self, x):
+        Q, K, V = self.Wqkv(x).chunk(3, dim=-1)
+        # flash_attn_func expects (batch, seqlen, n_heads, head_dim)
+        Q, K, V = [t.view(t.size(0), -1, self.num_heads, t.size(-1) // self.num_heads) for t in (Q, K, V)]
+        attn_out = flash_attn_func(Q, K, V)  # Efficient attention computation
+        attn_out = attn_out.reshape(x.size(0), -1, x.size(-1))
+        return self.out_proj(attn_out)
 
 
 class Expert(nn.Module):
     """
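None of the four modules added above are wired into the model by this diff. A hedged sketch of how they might be exercised stand-alone; the shapes, sizes, and rank are illustrative only:

import torch

dim, seq_len, heads = 512, 1024, 8
x = torch.randn(2, seq_len, dim)

low_rank = LowRankLinear(dim, dim, rank=32)         # ~2*dim*rank parameters instead of dim*dim
butterfly = ButterflyLinear(dim, dim)               # dense weight masked into a structured sparse pattern
linformer = LinformerSelfAttention(dim, seq_len, k=256)

print(low_rank(x).shape, butterfly(x).shape, linformer(x).shape)  # each (2, 1024, 512)

# FlashSelfAttention additionally needs a CUDA device and fp16/bf16 tensors:
# flash = FlashSelfAttention(dim, heads).cuda().half()
# print(flash(x.cuda().half()).shape)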
@@ -677,6 +685,7 @@ class MoE(nn.Module):
         x = x.view(-1, self.dim)
         weights, indices = self.gate(x)
         y = torch.zeros_like(x)
+
         counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
         for i in range(self.experts_start_idx, self.experts_end_idx):
             if counts[i] == 0:
@@ -739,11 +748,12 @@ class Transformer(nn.Module):
     Attributes:
         max_seq_len (int): Maximum sequence length for the transformer.
         embed (nn.Module): Embedding layer for input tokens.
-        layers (torch.nn.ModuleList): List of transformer blocks.
+        layers (nn.ModuleList): List of transformer blocks.
         norm (nn.Module): Layer normalization applied after all blocks.
         head (nn.Module): Output projection layer mapping to vocabulary size.
        freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
     """
+
     def __init__(self, args: ModelArgs):
         """
         Initializes the Transformer model.
@@ -751,22 +761,27 @@ class Transformer(nn.Module):
         Args:
             args (ModelArgs): Model arguments containing transformer parameters.
         """
-        global world_size, rank
-        world_size = dist.get_world_size() if dist.is_initialized() else 1
-        rank = dist.get_rank() if dist.is_initialized() else 0
-        Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
         super().__init__()
+
+        # Distributed setup
+        if dist.is_initialized():
+            self.world_size = dist.get_world_size()
+            self.rank = dist.get_rank()
+        else:
+            self.world_size = 1
+            self.rank = 0
+
+        # Set dtype for Linear layers based on model arguments (before any Linear is constructed)
+        Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
+
         self.max_seq_len = args.max_seq_len
         self.embed = ParallelEmbedding(args.vocab_size, args.dim)
-        self.layers = torch.nn.ModuleList()
-        for layer_id in range(args.n_layers):
-            self.layers.append(Block(layer_id, args))
+        self.layers = nn.ModuleList([Block(layer_id, args) for layer_id in range(args.n_layers)])
         self.norm = RMSNorm(args.dim)
         self.head = ColumnParallelLinear(args.dim, args.vocab_size, dtype=torch.get_default_dtype())
         self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)
 
     @torch.inference_mode()
-    def forward(self, tokens: torch.Tensor, start_pos: int = 0):
+    def forward(self, tokens: torch.Tensor, start_pos: int = 0) -> torch.Tensor:
         """
         Forward pass for the Transformer model.
 
@@ -777,20 +792,24 @@ class Transformer(nn.Module):
         Returns:
             torch.Tensor: Logits tensor of shape (batch_size, vocab_size).
         """
-        seqlen = tokens.size(1)
+        seqlen = tokens.shape[1]
         h = self.embed(tokens)
-        freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
+        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
+
         mask = None
         if seqlen > 1:
             mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1)
+
         for layer in self.layers:
             h = layer(h, start_pos, freqs_cis, mask)
-        h = self.norm(h)[:, -1]
-        logits = self.head(h)
-        if world_size > 1:
-            all_logits = [torch.empty_like(logits) for _ in range(world_size)]
-            dist.all_gather(all_logits, logits)
-            logits = torch.cat(all_logits, dim=-1)
+        logits = self.head(self.norm(h)[:, -1])
+        if self.world_size > 1:
+            gathered_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
+            dist.all_gather(gathered_logits, logits)
+            logits = torch.cat(gathered_logits, dim=-1)
+
         return logits
 
 