Fix the model parallelism error in DeepSeek-VL2

This commit is contained in:
kongsz 2025-02-20 16:29:44 +08:00
parent 3c2cd21c9f
commit 8ed2d7ff8d

View File

@ -882,8 +882,18 @@ class DeepseekV2Attention(nn.Module):
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
compressed_kv = compressed_kv.unsqueeze(1)
k_pe_device = k_pe.device
compressed_kv_device = compressed_kv.device
if len(past_key_value.key_cache) == 0:
compressed_kv = compressed_kv.to(k_pe_device)
else:
k_pe = k_pe.to(past_key_value.key_cache[0].device)
compressed_kv = compressed_kv.to(past_key_value.value_cache[0].device)
k_pe, compressed_kv = past_key_value.update(k_pe, compressed_kv, self.layer_idx, cache_kwargs)
compressed_kv = compressed_kv.squeeze(1)
k_pe = k_pe.to(k_pe_device)
compressed_kv = compressed_kv.to(compressed_kv_device)
kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)
q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :]