diff --git a/deepseek_vl2/models/siglip_vit.py b/deepseek_vl2/models/siglip_vit.py
index 67f30e8..f06c25f 100644
--- a/deepseek_vl2/models/siglip_vit.py
+++ b/deepseek_vl2/models/siglip_vit.py
@@ -13,7 +13,6 @@ from timm.layers import (
 )
 from timm.models._manipulate import named_apply, checkpoint_seq, adapt_input_conv
 from transformers.modeling_utils import is_flash_attn_2_available
-from xformers.ops import memory_efficient_attention
 from functools import partial
@@ -134,6 +133,8 @@ class Attention(nn.Module):
         self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0. else nn.Identity()
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        from xformers.ops import memory_efficient_attention
+
         B, N, C = x.shape
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
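
The change above moves the xformers import from module scope into `Attention.forward`, so importing `siglip_vit.py` no longer hard-requires xformers; the dependency is only resolved when that attention path actually runs. Below is a minimal sketch of the same lazy-import pattern; the fallback to PyTorch's `scaled_dot_product_attention` and the helper name `memory_efficient_attn` are illustrative assumptions, not part of this diff.

```python
# Minimal sketch of the lazy-import pattern used in the diff above.
# The SDPA fallback is an assumption for illustration, not part of the change.
import torch
import torch.nn.functional as F


def memory_efficient_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """q, k, v are (B, N, num_heads, head_dim), the layout xformers expects."""
    try:
        # Deferred import: xformers is only needed when this path actually runs,
        # so the surrounding module can be imported without xformers installed.
        from xformers.ops import memory_efficient_attention
    except ImportError:
        # Hypothetical fallback: PyTorch SDPA expects (B, num_heads, N, head_dim),
        # so transpose in and out of that layout.
        return F.scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        ).transpose(1, 2)
    return memory_efficient_attention(q, k, v)
```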