diff --git a/inference/model.py b/inference/model.py
index ad09129..62f4caa 100644
--- a/inference/model.py
+++ b/inference/model.py
@@ -418,8 +418,8 @@ class MLA(nn.Module):
         """
         bsz, seqlen, _ = x.shape
 
-        # Optimization for small sequence lengths
-        use_efficient_attn = seqlen <= 256 and mask is None
+        # Fast path for short sequences without masks
+        use_fast_path = seqlen <= 256 and mask is None
 
         if self.q_lora_rank == 0:
             q = self.wq(x)
@@ -451,16 +451,20 @@ class MLA(nn.Module):
         q = apply_rotary_emb(q_rope, freqs_cis)
         k = apply_rotary_emb(k_rope, freqs_cis)
 
-        if use_efficient_attn:
-            # Efficient attention for small sequences
-            q = q.transpose(1, 2)  # (bsz, n_local_heads, seqlen, head_dim)
+        if use_fast_path:
+            # Optimized path for short sequences
+            q = q.transpose(1, 2)  # [bsz, n_local_heads, seqlen, head_dim]
             k = k.transpose(1, 2)
             v = v.transpose(1, 2)
+
+            # Single matmul for attention scores
             scores = torch.matmul(q, k.transpose(-2, -1)) * self.softmax_scale
-            scores = F.softmax(scores, dim=-1)
+            # Softmax in fp32 for stability, then cast back to the activation dtype
+            scores = F.softmax(scores, dim=-1, dtype=torch.float32).type_as(x)
+
             output = torch.matmul(scores, v)
         else:
-            # Regular attention computation
+            # Standard path for longer sequences or when a mask is needed
            q = q.transpose(1, 2)
             k = k.transpose(1, 2)
             v = v.transpose(1, 2)
@@ -468,7 +472,7 @@ class MLA(nn.Module):
             scores = torch.matmul(q, k.transpose(-2, -1)) * self.softmax_scale
             if mask is not None:
                 scores = scores + mask
-            scores = F.softmax(scores, dim=-1)
+            scores = F.softmax(scores, dim=-1, dtype=torch.float32).type_as(x)
             output = torch.matmul(scores, v)
 
         output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
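
Note on the fp32 softmax change: without a cast back to the activation dtype, torch.matmul(scores, v) would raise a dtype mismatch when the model runs in bf16, hence the .type_as(x) above. As a standalone sanity sketch (not part of the patch), the snippet below mirrors the patched attention computation and compares it against PyTorch's built-in scaled_dot_product_attention in full fp32. The attend helper, the tensor shapes, and the softmax_scale value are illustrative assumptions, not values taken from the model config.

import torch
import torch.nn.functional as F

def attend(q, k, v, softmax_scale, mask=None):
    # Mirrors both branches of the patched forward: optional additive mask,
    # fp32 softmax cast back to the activation dtype, then the value matmul.
    scores = torch.matmul(q, k.transpose(-2, -1)) * softmax_scale
    if mask is not None:
        scores = scores + mask
    scores = F.softmax(scores, dim=-1, dtype=torch.float32).type_as(v)
    return torch.matmul(scores, v)

# Assumed shapes and scale; the real values come from the model config.
bsz, n_heads, seqlen, head_dim = 2, 4, 128, 64
softmax_scale = head_dim ** -0.5

q = torch.randn(bsz, n_heads, seqlen, head_dim, dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# fp32 reference via PyTorch's built-in attention
# (the scale= keyword needs PyTorch >= 2.1).
ref = F.scaled_dot_product_attention(q.float(), k.float(), v.float(),
                                     scale=softmax_scale)

# seqlen <= 256 and mask is None, so this corresponds to the fast path.
out = attend(q, k, v, softmax_scale)
torch.testing.assert_close(out.float(), ref, rtol=5e-2, atol=5e-2)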