Optimize Multi-head Latent Attention (MLA) with Fast Path for Short Sequences

XxAlonexX 2025-02-19 10:35:28 +05:30
parent f8b7c3b6e7
commit 79d72ecd8d


@@ -418,8 +418,8 @@ class MLA(nn.Module):
         """
         bsz, seqlen, _ = x.shape
-        # Optimization for small sequence lengths
-        use_efficient_attn = seqlen <= 256 and mask is None
+        # Fast path for short sequences without masks
+        use_fast_path = seqlen <= 256 and mask is None
         if self.q_lora_rank == 0:
             q = self.wq(x)
@@ -451,16 +451,20 @@ class MLA(nn.Module):
         q = apply_rotary_emb(q_rope, freqs_cis)
         k = apply_rotary_emb(k_rope, freqs_cis)
-        if use_efficient_attn:
-            # Efficient attention for small sequences
-            q = q.transpose(1, 2)  # (bsz, n_local_heads, seqlen, head_dim)
+        if use_fast_path:
+            # Optimized path for short sequences
+            q = q.transpose(1, 2)  # [bsz, n_local_heads, seqlen, head_dim]
             k = k.transpose(1, 2)
             v = v.transpose(1, 2)
+            # Single matmul for attention scores
             scores = torch.matmul(q, k.transpose(-2, -1)) * self.softmax_scale
-            scores = F.softmax(scores, dim=-1)
+            scores = F.softmax(scores, dim=-1, dtype=torch.float32)
+            # Single matmul for output computation
             output = torch.matmul(scores, v)
         else:
-            # Regular attention computation
+            # Standard path for longer sequences or when mask is needed
             q = q.transpose(1, 2)
             k = k.transpose(1, 2)
             v = v.transpose(1, 2)
@@ -468,7 +472,7 @@ class MLA(nn.Module):
             scores = torch.matmul(q, k.transpose(-2, -1)) * self.softmax_scale
             if mask is not None:
                 scores = scores + mask
-            scores = F.softmax(scores, dim=-1)
+            scores = F.softmax(scores, dim=-1, dtype=torch.float32)
             output = torch.matmul(scores, v)
         output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
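
For readers outside the diff context, here is a minimal, self-contained sketch of the attention core this commit leaves behind. The function name mla_core_attention, the tensor shapes, and the final cast of the float32 softmax output back to v's dtype are assumptions for illustration only: the commit itself leaves scores in float32 after softmax, which implicitly assumes fp32 activations, so the cast is added here to keep the sketch runnable with any dtype.

import torch
import torch.nn.functional as F


def mla_core_attention(q, k, v, softmax_scale, mask=None):
    """Sketch of the dispatch introduced by this commit (names hypothetical).

    q, k, v: [bsz, seqlen, n_local_heads, head_dim]
    mask:    optional additive mask broadcastable to [bsz, n, seqlen, seqlen]
    Returns: [bsz, seqlen, n_local_heads * head_dim]
    """
    bsz, seqlen, _, _ = q.shape
    use_fast_path = seqlen <= 256 and mask is None

    if use_fast_path:
        # Fast path: skip the mask branch entirely for short sequences
        q = q.transpose(1, 2)  # [bsz, n_local_heads, seqlen, head_dim]
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        scores = torch.matmul(q, k.transpose(-2, -1)) * softmax_scale
        probs = F.softmax(scores, dim=-1, dtype=torch.float32)
    else:
        # Standard path: identical math plus the additive mask
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        scores = torch.matmul(q, k.transpose(-2, -1)) * softmax_scale
        if mask is not None:
            scores = scores + mask
        probs = F.softmax(scores, dim=-1, dtype=torch.float32)

    # Cast back so the matmul against v works for non-fp32 activations
    # (an addition in this sketch, not part of the commit)
    output = torch.matmul(probs.to(v.dtype), v)
    return output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

A quick way to confirm the two branches agree when no masking is needed: an all-zeros mask forces the standard path but changes nothing numerically, so both outputs should match exactly.

q = torch.randn(2, 128, 4, 64)
k = torch.randn(2, 128, 4, 64)
v = torch.randn(2, 128, 4, 64)
scale = 64 ** -0.5
fast = mla_core_attention(q, k, v, scale)                 # seqlen 128 -> fast path
slow = mla_core_attention(q, k, v, scale,
                          mask=torch.zeros(128, 128))     # forces standard path
assert torch.allclose(fast, slow)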