From 57582c60f4243cdbdd83ffcd5a7d42056e242a89 Mon Sep 17 00:00:00 2001
From: Xu Song
Date: Sat, 22 Feb 2025 14:50:32 +0800
Subject: [PATCH] Absorb w_uk into wo

---
 inference/model.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/inference/model.py b/inference/model.py
index 8f1ab81..adf5981 100644
--- a/inference/model.py
+++ b/inference/model.py
@@ -428,6 +428,7 @@ class MLA(nn.Module):
         self.kv_norm = RMSNorm(self.kv_lora_rank)
         self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
         self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)
+        self.wo_absorb = None
         self.softmax_scale = self.qk_head_dim ** -0.5
         if args.max_seq_len > args.original_seq_len:
             mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
@@ -487,10 +488,16 @@ class MLA(nn.Module):
         scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)
         if attn_impl == "naive":
             x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])
+            x = self.wo(x.flatten(2))
         else:
             x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
-            x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
-        x = self.wo(x.flatten(2))
+            if self.wo_absorb is None:
+                wo = self.wo.weight
+                wo = wo.transpose(0,1).view(self.n_heads, self.v_head_dim, self.dim)
+                self.wo_absorb = torch.einsum("hdc,hdi->hci", wkv_b[:, -self.v_head_dim:], wo)
+
+            x = torch.einsum("bshc,hci->bshi", x, self.wo_absorb)
+            x = torch.sum(x, dim=2)
         return x
 
 
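
As a quick sanity check (not part of the patch), the identity being exploited is that the value up-projection slice of wkv_b and the wo weight can be contracted once into a single per-head tensor, so the output can be computed directly from the compressed latent and summed over heads. Below is a minimal sketch with plain float64 tensors standing in for the sharded ColumnParallelLinear / RowParallelLinear weights; all sizes (b, s, h, c, d, dim) are hypothetical.

    import torch

    torch.manual_seed(0)

    # Hypothetical sizes: batch, seq_len, n_heads, kv_lora_rank, v_head_dim, dim.
    b, s, h, c, d, dim = 2, 3, 4, 8, 16, 32

    x = torch.randn(b, s, h, c, dtype=torch.double)          # scores @ compressed kv cache
    w_v = torch.randn(h, d, c, dtype=torch.double)           # wkv_b[:, -v_head_dim:], value up-projection
    wo_weight = torch.randn(dim, h * d, dtype=torch.double)  # self.wo.weight, [out_features, in_features], no bias

    # Original path: up-project to per-head values, flatten the heads, apply wo.
    v = torch.einsum("bshc,hdc->bshd", x, w_v)
    out_ref = v.flatten(2) @ wo_weight.t()

    # Absorbed path, as in the patch: fold wo into the value up-projection once,
    # then contract straight from the latent dim to the model dim and sum over heads.
    wo = wo_weight.t().reshape(h, d, dim)
    wo_absorb = torch.einsum("hdc,hdi->hci", w_v, wo)
    out_abs = torch.einsum("bshc,hci->bshi", x, wo_absorb).sum(dim=2)

    torch.testing.assert_close(out_ref, out_abs)  # same output up to floating-point error

One thing to watch: the absorbed branch uses self.wo.weight directly instead of calling self.wo, so any cross-rank reduction that RowParallelLinear performs in its forward would need to be applied to the result separately when running with tensor parallelism.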