Absorb w_uv into wo

This commit is contained in:
Xu Song 2025-02-22 14:50:32 +08:00 committed by GitHub
parent f09f5fa321
commit 57582c60f4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -428,6 +428,7 @@ class MLA(nn.Module):
self.kv_norm = RMSNorm(self.kv_lora_rank) self.kv_norm = RMSNorm(self.kv_lora_rank)
self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim)) self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim) self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)
self.wo_absorb = None
self.softmax_scale = self.qk_head_dim ** -0.5 self.softmax_scale = self.qk_head_dim ** -0.5
if args.max_seq_len > args.original_seq_len: if args.max_seq_len > args.original_seq_len:
mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0 mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
@ -487,10 +488,16 @@ class MLA(nn.Module):
scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x) scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)
if attn_impl == "naive": if attn_impl == "naive":
x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos]) x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])
x = self.wo(x.flatten(2))
else: else:
x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos]) x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:]) if self.wo_absorb is None:
x = self.wo(x.flatten(2)) wo = self.wo.weight
wo = wo.transpose(0,1).view(self.n_heads, self.v_head_dim, self.dim)
self.wo_absorb = torch.einsum("hdc,hdi->hci", wkv_b[:, -self.v_head_dim:], wo)
x = torch.einsum("bshc,hci->bshi", x, self.wo_absorb)
x = torch.sum(x, dim=2)
return x return x