mirror of
https://github.com/deepseek-ai/DeepSeek-V3.git
synced 2025-02-22 21:58:58 -05:00
Absorb w_uk into wo
This commit is contained in:
parent
f09f5fa321
commit
57582c60f4
@ -428,6 +428,7 @@ class MLA(nn.Module):
|
|||||||
self.kv_norm = RMSNorm(self.kv_lora_rank)
|
self.kv_norm = RMSNorm(self.kv_lora_rank)
|
||||||
self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
|
self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
|
||||||
self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)
|
self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)
|
||||||
|
self.wo_absorb = None
|
||||||
self.softmax_scale = self.qk_head_dim ** -0.5
|
self.softmax_scale = self.qk_head_dim ** -0.5
|
||||||
if args.max_seq_len > args.original_seq_len:
|
if args.max_seq_len > args.original_seq_len:
|
||||||
mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
|
mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
|
||||||
@ -487,10 +488,16 @@ class MLA(nn.Module):
|
|||||||
scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)
|
scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)
|
||||||
if attn_impl == "naive":
|
if attn_impl == "naive":
|
||||||
x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])
|
x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])
|
||||||
|
x = self.wo(x.flatten(2))
|
||||||
else:
|
else:
|
||||||
x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
|
x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
|
||||||
x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
|
if self.wo_absorb is None:
|
||||||
x = self.wo(x.flatten(2))
|
wo = self.wo.weight
|
||||||
|
wo = wo.transpose(0,1).view(self.n_heads, self.v_head_dim, self.dim)
|
||||||
|
self.wo_absorb = torch.einsum("hdc,hdi->hci", wkv_b[:, -self.v_head_dim:], wo)
|
||||||
|
|
||||||
|
x = torch.einsum("bshc,hci->bshi", x, self.wo_absorb)
|
||||||
|
x = torch.sum(x, dim=2)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user