diff --git a/deepseek_vl2/models/modeling_deepseek.py b/deepseek_vl2/models/modeling_deepseek.py index 1a84a01..3105fe2 100644 --- a/deepseek_vl2/models/modeling_deepseek.py +++ b/deepseek_vl2/models/modeling_deepseek.py @@ -882,8 +882,18 @@ class DeepseekV2Attention(nn.Module): if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models compressed_kv = compressed_kv.unsqueeze(1) + + k_pe_device = k_pe.device + compressed_kv_device = compressed_kv.device + if len(past_key_value.key_cache) == 0: + compressed_kv = compressed_kv.to(k_pe_device) + else: + k_pe = k_pe.to(past_key_value.key_cache[0].device) + compressed_kv = compressed_kv.to(past_key_value.value_cache[0].device) k_pe, compressed_kv = past_key_value.update(k_pe, compressed_kv, self.layer_idx, cache_kwargs) compressed_kv = compressed_kv.squeeze(1) + k_pe = k_pe.to(k_pe_device) + compressed_kv = compressed_kv.to(compressed_kv_device) kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank) q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :]