mirror of https://github.com/deepseek-ai/DeepSeek-V3.git (synced 2025-02-22 21:58:58 -05:00)
Optimize Multi-head Latent Attention (MLA) with Fast Path for Short Sequences
parent f8b7c3b6e7
commit 79d72ecd8d
@@ -418,8 +418,8 @@ class MLA(nn.Module):
         """
         bsz, seqlen, _ = x.shape

-        # Optimization for small sequence lengths
-        use_efficient_attn = seqlen <= 256 and mask is None
+        # Fast path for short sequences without masks
+        use_fast_path = seqlen <= 256 and mask is None

         if self.q_lora_rank == 0:
             q = self.wq(x)
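For context, a minimal sketch (not from this commit) of when the flag above evaluates to true. It assumes the common calling convention in which a single-token decode step passes mask=None while multi-token prefill passes an additive causal mask; takes_fast_path is a hypothetical helper, not a function in the repository.

import torch

def takes_fast_path(seqlen, mask):
    # Mirrors `use_fast_path = seqlen <= 256 and mask is None` from the diff.
    return seqlen <= 256 and mask is None

# Single-token decode step: short and unmasked, so the fast path applies.
print(takes_fast_path(seqlen=1, mask=None))  # True

# Prefill with an additive causal mask: falls through to the standard path.
causal = torch.full((512, 512), float("-inf")).triu_(1)
print(takes_fast_path(seqlen=512, mask=causal))  # False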
@@ -451,16 +451,20 @@ class MLA(nn.Module):
         q = apply_rotary_emb(q_rope, freqs_cis)
         k = apply_rotary_emb(k_rope, freqs_cis)

-        if use_efficient_attn:
-            # Efficient attention for small sequences
-            q = q.transpose(1, 2)  # (bsz, n_local_heads, seqlen, head_dim)
+        if use_fast_path:
+            # Optimized path for short sequences
+            q = q.transpose(1, 2)  # [bsz, n_local_heads, seqlen, head_dim]
             k = k.transpose(1, 2)
             v = v.transpose(1, 2)
+
+            # Single matmul for attention scores
             scores = torch.matmul(q, k.transpose(-2, -1)) * self.softmax_scale
-            scores = F.softmax(scores, dim=-1)
+            scores = F.softmax(scores, dim=-1, dtype=torch.float32)
+
+            # Single matmul for output computation
             output = torch.matmul(scores, v)
         else:
-            # Regular attention computation
+            # Standard path for longer sequences or when mask is needed
             q = q.transpose(1, 2)
             k = k.transpose(1, 2)
             v = v.transpose(1, 2)
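As a reading aid, here is a self-contained sketch of the fast path above with illustrative shapes (bsz=2, seqlen=16, n_local_heads=4, head_dim=32); these values and the standalone tensors are assumptions, not code from the repository. Note that after a float32 softmax, a cast back to the activation dtype (e.g. .type_as(q), which the diff itself omits) is needed before the matmul with v when q, k, v are in half precision.

import torch
import torch.nn.functional as F

# Illustrative shapes and dtype (assumptions, not repository values).
bsz, seqlen, n_local_heads, head_dim = 2, 16, 4, 32
softmax_scale = head_dim ** -0.5
dtype = torch.bfloat16

q = torch.randn(bsz, seqlen, n_local_heads, head_dim, dtype=dtype)
k = torch.randn(bsz, seqlen, n_local_heads, head_dim, dtype=dtype)
v = torch.randn(bsz, seqlen, n_local_heads, head_dim, dtype=dtype)

# (bsz, seqlen, heads, dim) -> (bsz, heads, seqlen, dim), as in the fast path.
q, k, v = (t.transpose(1, 2) for t in (q, k, v))

# Single matmul for attention scores, softmax accumulated in float32.
scores = torch.matmul(q, k.transpose(-2, -1)) * softmax_scale
scores = F.softmax(scores, dim=-1, dtype=torch.float32).type_as(q)  # cast back: assumption, not in the diff

# Single matmul for the output, then merge heads back to (bsz, seqlen, heads * dim).
output = torch.matmul(scores, v)
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
print(output.shape)  # torch.Size([2, 16, 128])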
@@ -468,7 +472,7 @@ class MLA(nn.Module):
             scores = torch.matmul(q, k.transpose(-2, -1)) * self.softmax_scale
             if mask is not None:
                 scores = scores + mask
-            scores = F.softmax(scores, dim=-1)
+            scores = F.softmax(scores, dim=-1, dtype=torch.float32)
             output = torch.matmul(scores, v)

         output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
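Because the fast path only fires when mask is None, both branches compute the same attention on short unmasked inputs; a small equivalence check along these lines (hypothetical helper functions, float32 tensors to sidestep the dtype cast discussed above) shows the split is a performance refactor rather than a behavioural change.

import torch
import torch.nn.functional as F

def standard_attention(q, k, v, softmax_scale, mask=None):
    # Standard path from the diff: optional additive mask, float32 softmax.
    scores = torch.matmul(q, k.transpose(-2, -1)) * softmax_scale
    if mask is not None:
        scores = scores + mask
    scores = F.softmax(scores, dim=-1, dtype=torch.float32)
    return torch.matmul(scores, v)

def fast_attention(q, k, v, softmax_scale):
    # Fast path from the diff: identical math with the mask handling dropped.
    scores = torch.matmul(q, k.transpose(-2, -1)) * softmax_scale
    scores = F.softmax(scores, dim=-1, dtype=torch.float32)
    return torch.matmul(scores, v)

# Illustrative (bsz, heads, seqlen, head_dim) tensors in float32.
q, k, v = (torch.randn(2, 4, 16, 32) for _ in range(3))
scale = 32 ** -0.5
assert torch.allclose(standard_attention(q, k, v, scale), fast_attention(q, k, v, scale))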