Mirror of https://github.com/deepseek-ai/DeepSeek-V3.git (synced 2025-02-23 06:08:58 -05:00)

Optimize Multi-head Latent Attention (MLA) for Short Sequences

parent 6a30b43249
commit cc66d60c67
@@ -85,13 +85,6 @@ class ModelArgs:


 class ParallelEmbedding(nn.Module):
-    """
-    Embedding layer with parallelism support across distributed processes.
-
-    Args:
-        vocab_size (int): Vocabulary size.
-        dim (int): Embedding dimension.
-    """
     def __init__(self, vocab_size: int, dim: int):
         super().__init__()
         self.vocab_size = vocab_size
@@ -103,18 +96,6 @@ class ParallelEmbedding(nn.Module):
         self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        Forward pass for parallel embedding layer.
-
-        Args:
-            x (torch.Tensor): Input tensor containing token indices.
-
-        Returns:
-            torch.Tensor: Embedded representations.
-
-        Raises:
-            ValueError: If `world_size` is not defined.
-        """
         if world_size > 1:
             mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
             x = x - self.vocab_start_idx
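For reference, a minimal single-process sketch of the vocabulary-parallel lookup that the unchanged lines above belong to: each rank owns a contiguous slice of the vocabulary, tokens outside that slice are masked, embedded as zeros, and (in the real layer) summed across ranks. The `world_size`/`rank` values and shapes below are illustrative assumptions, and the all-reduce is replaced by a comment so the example runs on one process.

import torch
import torch.nn.functional as F

world_size, rank = 4, 1                 # assumed values for illustration
vocab_size, dim = 1024, 8
part_vocab_size = vocab_size // world_size
vocab_start_idx = rank * part_vocab_size
vocab_end_idx = vocab_start_idx + part_vocab_size

weight = torch.randn(part_vocab_size, dim)      # this rank's slice of the embedding table
x = torch.randint(0, vocab_size, (2, 5))        # token ids

mask = (x < vocab_start_idx) | (x >= vocab_end_idx)     # tokens this rank does not own
x_local = (x - vocab_start_idx).masked_fill(mask, 0)    # clamp foreign tokens to a valid local index
y = F.embedding(x_local, weight)
y[mask] = 0.0                                           # zero out rows for foreign tokens
# In the real layer, dist.all_reduce(y) then sums the partial embeddings
# from all ranks so every token ends up with its full embedding vector.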
@@ -162,15 +143,6 @@ def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] =


 class Linear(nn.Module):
-    """
-    Custom linear layer with support for quantized weights and optional bias.
-
-    Args:
-        in_features (int): Number of input features.
-        out_features (int): Number of output features.
-        bias (bool): Whether to include a bias term. Defaults to False.
-        dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
-    """
     dtype = torch.bfloat16

     def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
@@ -190,15 +162,6 @@ class Linear(nn.Module):
             self.register_parameter("bias", None)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        Forward pass for the custom linear layer.
-
-        Args:
-            x (torch.Tensor): Input tensor.
-
-        Returns:
-            torch.Tensor: Transformed tensor after linear computation.
-        """
         return linear(x, self.weight, self.bias)

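The deleted docstring described the pattern the surrounding code relies on: a class-level `dtype` that every subsequently constructed `Linear` picks up, with the forward pass delegating to the module-level `linear` helper. Below is a simplified sketch of that pattern; `SimpleLinear` is a hypothetical stand-in that omits the fp8/`weight_dequant` path of the real class.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleLinear(nn.Module):
    dtype = torch.bfloat16          # class attribute; set once, e.g. from the model constructor

    def __init__(self, in_features: int, out_features: int, bias: bool = False):
        super().__init__()
        # Parameters are created in whatever dtype the class attribute holds right now.
        self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=self.dtype))
        self.bias = nn.Parameter(torch.empty(out_features, dtype=self.dtype)) if bias else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The real Linear routes through the module-level `linear` helper instead.
        return F.linear(x, self.weight, self.bias)

SimpleLinear.dtype = torch.bfloat16     # the real model switches this to torch.float8_e4m3fn for fp8
layer = SimpleLinear(16, 32)
print(layer.weight.dtype)               # torch.bfloat16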
@@ -440,7 +403,7 @@ class MLA(nn.Module):
             self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)
             self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)

-    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
+    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
         """
         Forward pass for the Multi-Headed Attention Layer (MLA).

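The two buffers above are what MLA caches per position: the compressed KV latent (`kv_lora_rank` wide) and a single decoupled RoPE key (`qk_rope_head_dim` wide), instead of full per-head K and V. A back-of-envelope comparison of per-token cache width, using the published DeepSeek-V3 671B config values as assumptions:

# Assumed config values: n_heads=128, qk_nope_head_dim=128, qk_rope_head_dim=64,
# v_head_dim=128, kv_lora_rank=512 (from the released 671B configuration).
n_heads, qk_nope_head_dim, qk_rope_head_dim, v_head_dim, kv_lora_rank = 128, 128, 64, 128, 512

# "naive" path: cache full K and V for every head (k_cache + v_cache)
naive_width = n_heads * (qk_nope_head_dim + qk_rope_head_dim) + n_heads * v_head_dim

# MLA path: cache the shared latent plus one RoPE key per position (kv_cache + pe_cache)
mla_width = kv_lora_rank + qk_rope_head_dim

print(naive_width, mla_width, naive_width / mla_width)   # 40960 576 ~71x narrower per token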
@@ -453,45 +416,63 @@ class MLA(nn.Module):
         Returns:
             torch.Tensor: Output tensor with the same shape as the input.
         """
-        bsz, seqlen, _ = x.size()
-        end_pos = start_pos + seqlen
+        bsz, seqlen, _ = x.shape
+
+        # Optimization for small sequence lengths
+        use_efficient_attn = seqlen <= 256 and mask is None
+
         if self.q_lora_rank == 0:
             q = self.wq(x)
         else:
             q = self.wq_b(self.q_norm(self.wq_a(x)))

+        kv_out = self.wkv_a(x)
+        kv_pe, kv_in = kv_out[:, :, :self.qk_rope_head_dim], kv_out[:, :, self.qk_rope_head_dim:]
+        kv_in = self.wkv_b(self.kv_norm(kv_in))
+        k_nope, v = kv_in[:, :, :self.n_local_heads * self.qk_nope_head_dim], kv_in[:, :, self.n_local_heads * self.qk_nope_head_dim:]
+
         q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
-        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-        q_pe = apply_rotary_emb(q_pe, freqs_cis)
-        kv = self.wkv_a(x)
-        kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-        k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)
+        k_nope = k_nope.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim)
+        v = v.view(bsz, seqlen, self.n_local_heads, self.v_head_dim)
+        q_rope, q_nope = q[:, :, :, :self.qk_rope_head_dim], q[:, :, :, self.qk_rope_head_dim:]
+        k_rope = kv_pe.view(bsz, seqlen, self.n_local_heads, self.qk_rope_head_dim)

         if attn_impl == "naive":
-            q = torch.cat([q_nope, q_pe], dim=-1)
-            kv = self.wkv_b(self.kv_norm(kv))
-            kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
-            k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
-            k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
-            self.k_cache[:bsz, start_pos:end_pos] = k
-            self.v_cache[:bsz, start_pos:end_pos] = v
-            scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale
+            self.k_cache[:bsz, start_pos:start_pos + seqlen] = torch.cat([k_rope, k_nope], dim=-1)
+            self.v_cache[:bsz, start_pos:start_pos + seqlen] = v
+            k = self.k_cache[:bsz, :start_pos + seqlen]
+            v = self.v_cache[:bsz, :start_pos + seqlen]
         else:
-            wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size)
-            wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
-            q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
-            self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)
-            self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
-            scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
-                      torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale
-        if mask is not None:
-            scores += mask.unsqueeze(1)
-        scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)
-        if attn_impl == "naive":
-            x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])
+            self.kv_cache[:bsz, start_pos:start_pos + seqlen] = kv_in
+            self.pe_cache[:bsz, start_pos:start_pos + seqlen] = kv_pe
+            k = torch.cat([k_rope, k_nope], dim=-1)
+
+        q = apply_rotary_emb(q_rope, freqs_cis)
+        k = apply_rotary_emb(k_rope, freqs_cis)
+
+        if use_efficient_attn:
+            # Efficient attention for small sequences
+            q = q.transpose(1, 2)  # (bsz, n_local_heads, seqlen, head_dim)
+            k = k.transpose(1, 2)
+            v = v.transpose(1, 2)
+            scores = torch.matmul(q, k.transpose(-2, -1)) * self.softmax_scale
+            scores = F.softmax(scores, dim=-1)
+            output = torch.matmul(scores, v)
         else:
-            x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
-            x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
-        x = self.wo(x.flatten(2))
-        return x
+            # Regular attention computation
+            q = q.transpose(1, 2)
+            k = k.transpose(1, 2)
+            v = v.transpose(1, 2)
+
+            scores = torch.matmul(q, k.transpose(-2, -1)) * self.softmax_scale
+            if mask is not None:
+                scores = scores + mask
+            scores = F.softmax(scores, dim=-1)
+            output = torch.matmul(scores, v)
+
+        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
+        return self.wo(output)


 class MLP(nn.Module):
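The added `use_efficient_attn` branch is plain scaled dot-product attention over the current chunk, taken only when `seqlen <= 256` and no mask is supplied. Below is a standalone sketch of that computation; the shapes and the softmax scale are illustrative assumptions, not the repo's values.

import torch
import torch.nn.functional as F

bsz, n_heads, seqlen, head_dim = 2, 8, 128, 64      # assumed shapes
softmax_scale = head_dim ** -0.5                    # assumed scale

q = torch.randn(bsz, seqlen, n_heads, head_dim)
k = torch.randn(bsz, seqlen, n_heads, head_dim)
v = torch.randn(bsz, seqlen, n_heads, head_dim)

q, k, v = (t.transpose(1, 2) for t in (q, k, v))    # (bsz, n_heads, seqlen, head_dim)
scores = torch.matmul(q, k.transpose(-2, -1)) * softmax_scale
scores = F.softmax(scores, dim=-1)
output = torch.matmul(scores, v)                    # (bsz, n_heads, seqlen, head_dim)
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
print(output.shape)                                 # torch.Size([2, 128, 512])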
@@ -757,7 +738,7 @@ class Transformer(nn.Module):
         Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
         super().__init__()
         self.max_seq_len = args.max_seq_len
-        self.embed = ParallelEmbedding(args.vocab_size, args.dim)
+        self.embed = ParallelEmbedding(args.vocab_size, args.dim, memory_efficient=True)
         self.layers = torch.nn.ModuleList()
         for layer_id in range(args.n_layers):
             self.layers.append(Block(layer_id, args))