This commit is contained in:
DeadLining 2025-02-26 13:10:02 +08:00 committed by GitHub
commit d463dcf019
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -882,8 +882,18 @@ class DeepseekV2Attention(nn.Module):
if past_key_value is not None: if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
compressed_kv = compressed_kv.unsqueeze(1) compressed_kv = compressed_kv.unsqueeze(1)
k_pe_device = k_pe.device
compressed_kv_device = compressed_kv.device
if len(past_key_value.key_cache) == 0:
compressed_kv = compressed_kv.to(k_pe_device)
else:
k_pe = k_pe.to(past_key_value.key_cache[0].device)
compressed_kv = compressed_kv.to(past_key_value.value_cache[0].device)
k_pe, compressed_kv = past_key_value.update(k_pe, compressed_kv, self.layer_idx, cache_kwargs) k_pe, compressed_kv = past_key_value.update(k_pe, compressed_kv, self.layer_idx, cache_kwargs)
compressed_kv = compressed_kv.squeeze(1) compressed_kv = compressed_kv.squeeze(1)
k_pe = k_pe.to(k_pe_device)
compressed_kv = compressed_kv.to(compressed_kv_device)
kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank) kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)
q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :] q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :]