Mirror of https://github.com/deepseek-ai/DeepSeek-V3.git (synced 2025-02-23 06:08:58 -05:00)
Optimize Multi-head Latent Attention (MLA) for Short Sequences
This commit is contained in:
parent 6a30b43249
commit cc66d60c67
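The patch gates MLA's attention on sequence length: when seqlen <= 256 and no mask is supplied, it computes plain matmul-softmax attention instead of the cached einsum path. A minimal, self-contained sketch of that gating idea (illustrative only: the 256 threshold and the use_efficient_attn name come from the diff below, everything else is a toy stand-in rather than the repository's code):

import torch
import torch.nn.functional as F

def toy_attention(q, k, v, mask=None, threshold=256):
    """Toy multi-head attention with a short-sequence fast path (hypothetical helper)."""
    # q, k, v: (bsz, n_heads, seqlen, head_dim); mask: additive (seqlen, seqlen) or None
    seqlen = q.size(-2)
    scale = q.size(-1) ** -0.5
    use_efficient_attn = seqlen <= threshold and mask is None  # gating condition as in the patch

    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    if not use_efficient_attn and mask is not None:
        scores = scores + mask  # the masked/general path keeps the additive mask
    probs = F.softmax(scores, dim=-1)
    return torch.matmul(probs, v)

bsz, n_heads, seqlen, head_dim = 2, 4, 128, 64
out = toy_attention(torch.randn(bsz, n_heads, seqlen, head_dim),
                    torch.randn(bsz, n_heads, seqlen, head_dim),
                    torch.randn(bsz, n_heads, seqlen, head_dim))
print(out.shape)  # torch.Size([2, 4, 128, 64])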
@@ -85,13 +85,6 @@ class ModelArgs:
 
 
 class ParallelEmbedding(nn.Module):
-    """
-    Embedding layer with parallelism support across distributed processes.
-
-    Args:
-        vocab_size (int): Vocabulary size.
-        dim (int): Embedding dimension.
-    """
     def __init__(self, vocab_size: int, dim: int):
         super().__init__()
         self.vocab_size = vocab_size
@@ -103,18 +96,6 @@ class ParallelEmbedding(nn.Module):
         self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        Forward pass for parallel embedding layer.
-
-        Args:
-            x (torch.Tensor): Input tensor containing token indices.
-
-        Returns:
-            torch.Tensor: Embedded representations.
-
-        Raises:
-            ValueError: If `world_size` is not defined.
-        """
         if world_size > 1:
             mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
             x = x - self.vocab_start_idx
@@ -162,15 +143,6 @@ def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] =
 
 
 class Linear(nn.Module):
-    """
-    Custom linear layer with support for quantized weights and optional bias.
-
-    Args:
-        in_features (int): Number of input features.
-        out_features (int): Number of output features.
-        bias (bool): Whether to include a bias term. Defaults to False.
-        dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
-    """
     dtype = torch.bfloat16
 
     def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
@@ -190,15 +162,6 @@ class Linear(nn.Module):
             self.register_parameter("bias", None)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        Forward pass for the custom linear layer.
-
-        Args:
-            x (torch.Tensor): Input tensor.
-
-        Returns:
-            torch.Tensor: Transformed tensor after linear computation.
-        """
         return linear(x, self.weight, self.bias)
 
 
@@ -440,7 +403,7 @@ class MLA(nn.Module):
             self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)
             self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)
 
-    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
+    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
         """
         Forward pass for the Multi-Headed Attention Layer (MLA).
 
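The kv_cache and pe_cache buffers registered in the context above are the point of MLA: per token and layer, only a kv_lora_rank-dim latent plus a qk_rope_head_dim-dim shared rotary key are cached, rather than full per-head keys and values. A rough back-of-envelope comparison (the dimensions below are stand-in values, not read from this repository's ModelArgs):

# Per-token, per-layer KV-cache footprint in elements -- stand-in numbers, adjust to ModelArgs.
n_heads          = 128   # assumed
qk_nope_head_dim = 128   # assumed
qk_rope_head_dim = 64    # assumed
v_head_dim       = 128   # assumed
kv_lora_rank     = 512   # assumed

full_cache   = n_heads * ((qk_nope_head_dim + qk_rope_head_dim) + v_head_dim)  # per-head K and V
latent_cache = kv_lora_rank + qk_rope_head_dim                                 # MLA latent + rotary key

print(full_cache, latent_cache, round(full_cache / latent_cache, 1))
# 40960 576 71.1 -> roughly 70x smaller per token under these assumed dimensions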
@@ -453,45 +416,63 @@
         Returns:
             torch.Tensor: Output tensor with the same shape as the input.
         """
-        bsz, seqlen, _ = x.size()
-        end_pos = start_pos + seqlen
+        bsz, seqlen, _ = x.shape
+
+        # Optimization for small sequence lengths
+        use_efficient_attn = seqlen <= 256 and mask is None
+
         if self.q_lora_rank == 0:
             q = self.wq(x)
         else:
             q = self.wq_b(self.q_norm(self.wq_a(x)))
+
+        kv_out = self.wkv_a(x)
+        kv_pe, kv_in = kv_out[:, :, :self.qk_rope_head_dim], kv_out[:, :, self.qk_rope_head_dim:]
+        kv_in = self.wkv_b(self.kv_norm(kv_in))
+        k_nope, v = kv_in[:, :, :self.n_local_heads*self.qk_nope_head_dim], kv_in[:, :, self.n_local_heads*self.qk_nope_head_dim:]
+
         q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
-        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-        q_pe = apply_rotary_emb(q_pe, freqs_cis)
-        kv = self.wkv_a(x)
-        kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-        k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)
+        k_nope = k_nope.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim)
+        v = v.view(bsz, seqlen, self.n_local_heads, self.v_head_dim)
+
+        q_rope, q_nope = q[:, :, :, :self.qk_rope_head_dim], q[:, :, :, self.qk_rope_head_dim:]
+        k_rope = kv_pe.view(bsz, seqlen, self.n_local_heads, self.qk_rope_head_dim)
+
         if attn_impl == "naive":
-            q = torch.cat([q_nope, q_pe], dim=-1)
-            kv = self.wkv_b(self.kv_norm(kv))
-            kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
-            k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
-            k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
-            self.k_cache[:bsz, start_pos:end_pos] = k
-            self.v_cache[:bsz, start_pos:end_pos] = v
-            scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale
+            self.k_cache[: bsz, start_pos: start_pos + seqlen] = torch.cat([k_rope, k_nope], dim=-1)
+            self.v_cache[: bsz, start_pos: start_pos + seqlen] = v
+            k = self.k_cache[: bsz, : start_pos + seqlen]
+            v = self.v_cache[: bsz, : start_pos + seqlen]
         else:
-            wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size)
-            wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
-            q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
-            self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)
-            self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
-            scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
-                      torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale
+            self.kv_cache[: bsz, start_pos: start_pos + seqlen] = kv_in
+            self.pe_cache[: bsz, start_pos: start_pos + seqlen] = kv_pe
+            k = torch.cat([k_rope, k_nope], dim=-1)
+
+        q = apply_rotary_emb(q_rope, freqs_cis)
+        k = apply_rotary_emb(k_rope, freqs_cis)
+
+        if use_efficient_attn:
+            # Efficient attention for small sequences
+            q = q.transpose(1, 2)  # (bsz, n_local_heads, seqlen, head_dim)
+            k = k.transpose(1, 2)
+            v = v.transpose(1, 2)
+            scores = torch.matmul(q, k.transpose(-2, -1)) * self.softmax_scale
+            scores = F.softmax(scores, dim=-1)
+            output = torch.matmul(scores, v)
+        else:
+            # Regular attention computation
+            q = q.transpose(1, 2)
+            k = k.transpose(1, 2)
+            v = v.transpose(1, 2)
+
+            scores = torch.matmul(q, k.transpose(-2, -1)) * self.softmax_scale
         if mask is not None:
-            scores += mask.unsqueeze(1)
-        scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)
-        if attn_impl == "naive":
-            x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])
-        else:
-            x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
-            x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
-        x = self.wo(x.flatten(2))
-        return x
+            scores = scores + mask
+        scores = F.softmax(scores, dim=-1)
+        output = torch.matmul(scores, v)
+
+        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
+        return self.wo(output)
 
 
 class MLP(nn.Module):
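Both branches of the rewritten forward write the current step's projections into the preallocated buffers at [start_pos : start_pos + seqlen], and the naive branch reads everything up to start_pos + seqlen back out. A minimal sketch of that incremental-decoding cache pattern in isolation (buffer shapes and names here are illustrative, not the repository's):

import torch

max_batch_size, max_seq_len, head_dim = 4, 1024, 64
k_cache = torch.zeros(max_batch_size, max_seq_len, head_dim)

def append_and_read(k_new, start_pos):
    """Write this step's keys at [start_pos, start_pos + seqlen) and return all keys so far."""
    bsz, seqlen, _ = k_new.shape
    k_cache[:bsz, start_pos:start_pos + seqlen] = k_new
    return k_cache[:bsz, :start_pos + seqlen]

k = append_and_read(torch.randn(2, 16, head_dim), start_pos=0)   # prefill: (2, 16, 64)
k = append_and_read(torch.randn(2, 1, head_dim), start_pos=16)   # decode step: (2, 17, 64)
print(k.shape)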
@@ -757,7 +738,7 @@ class Transformer(nn.Module):
         Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
         super().__init__()
         self.max_seq_len = args.max_seq_len
-        self.embed = ParallelEmbedding(args.vocab_size, args.dim)
+        self.embed = ParallelEmbedding(args.vocab_size, args.dim, memory_efficient=True)
         self.layers = torch.nn.ModuleList()
         for layer_id in range(args.n_layers):
             self.layers.append(Block(layer_id, args))
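The new call passes memory_efficient=True to ParallelEmbedding, but none of the hunks shown here change ParallelEmbedding.__init__, so the constructor this call assumes is not visible in the commit. A hypothetical sketch of the minimal signature the call would require (the flag's semantics are an assumption, not something the diff specifies):

from torch import nn

class ParallelEmbedding(nn.Module):
    # Hypothetical signature: the diff shows the call site only, not this constructor change.
    def __init__(self, vocab_size: int, dim: int, memory_efficient: bool = False):
        super().__init__()
        self.vocab_size = vocab_size
        self.dim = dim
        self.memory_efficient = memory_efficient  # what this toggles is not shown in the commit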