Mirror of https://github.com/deepseek-ai/DeepSeek-V3.git, synced 2025-02-23 06:08:58 -05:00
Optimize Multi-head Latent Attention (MLA) with Fast Path for Short Sequences
commit 79d72ecd8d
parent f8b7c3b6e7
@@ -418,8 +418,8 @@ class MLA(nn.Module):
         """
         bsz, seqlen, _ = x.shape

-        # Optimization for small sequence lengths
-        use_efficient_attn = seqlen <= 256 and mask is None
+        # Fast path for short sequences without masks
+        use_fast_path = seqlen <= 256 and mask is None

         if self.q_lora_rank == 0:
             q = self.wq(x)
@@ -451,16 +451,20 @@ class MLA(nn.Module):
         q = apply_rotary_emb(q_rope, freqs_cis)
         k = apply_rotary_emb(k_rope, freqs_cis)

-        if use_efficient_attn:
-            # Efficient attention for small sequences
-            q = q.transpose(1, 2)  # (bsz, n_local_heads, seqlen, head_dim)
+        if use_fast_path:
+            # Optimized path for short sequences
+            q = q.transpose(1, 2)  # [bsz, n_local_heads, seqlen, head_dim]
             k = k.transpose(1, 2)
             v = v.transpose(1, 2)
+
+            # Single matmul for attention scores
             scores = torch.matmul(q, k.transpose(-2, -1)) * self.softmax_scale
-            scores = F.softmax(scores, dim=-1)
+            scores = F.softmax(scores, dim=-1, dtype=torch.float32)
+
+            # Single matmul for output computation
             output = torch.matmul(scores, v)
         else:
-            # Regular attention computation
+            # Standard path for longer sequences or when mask is needed
             q = q.transpose(1, 2)
             k = k.transpose(1, 2)
             v = v.transpose(1, 2)
@@ -468,7 +472,7 @@ class MLA(nn.Module):
             scores = torch.matmul(q, k.transpose(-2, -1)) * self.softmax_scale
             if mask is not None:
                 scores = scores + mask
-            scores = F.softmax(scores, dim=-1)
+            scores = F.softmax(scores, dim=-1, dtype=torch.float32)
             output = torch.matmul(scores, v)

         output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
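For reference, below is a minimal standalone sketch of the unmasked fast-path computation this commit names: scores from a single matmul, a float32 softmax, and a single matmul with the values. The helper name fast_path_attention, the tensor shapes, and the cast of the float32 probabilities back to the value dtype are illustrative assumptions for this sketch, not part of the commit's diff (the diff leaves the softmax output in float32).

import torch
import torch.nn.functional as F

def fast_path_attention(q, k, v, softmax_scale):
    # Hypothetical helper, not part of the DeepSeek-V3 code base.
    # Expects q, k, v shaped (bsz, n_heads, seqlen, head_dim).
    # Single matmul for attention scores
    scores = torch.matmul(q, k.transpose(-2, -1)) * softmax_scale
    # Softmax in float32 for numerical stability; cast back to the value
    # dtype so the final matmul also works in half-precision settings
    # (an assumption of this sketch, the diff keeps float32 here)
    probs = F.softmax(scores, dim=-1, dtype=torch.float32).to(v.dtype)
    # Single matmul for the output
    return torch.matmul(probs, v)

# Illustrative shapes only: 2 sequences of 128 tokens (under the 256-token
# threshold), 16 heads, head_dim 64
q = torch.randn(2, 16, 128, 64)
k = torch.randn(2, 16, 128, 64)
v = torch.randn(2, 16, 128, 64)
out = fast_path_attention(q, k, v, softmax_scale=64 ** -0.5)
print(out.shape)  # torch.Size([2, 16, 128, 64])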