diff --git a/deepseek_vl2/models/siglip_vit.py b/deepseek_vl2/models/siglip_vit.py
index 462d5a5..cc102a4 100644
--- a/deepseek_vl2/models/siglip_vit.py
+++ b/deepseek_vl2/models/siglip_vit.py
@@ -12,11 +12,15 @@ from timm.layers import (
     AttentionPoolLatent, PatchDropout, resample_abs_pos_embed, LayerType
 )
 from timm.models._manipulate import named_apply, checkpoint_seq, adapt_input_conv
-from flash_attn import flash_attn_qkvpacked_func
+from transformers.modeling_utils import is_flash_attn_2_available
 from xformers.ops import memory_efficient_attention
 from functools import partial
 
 
+if is_flash_attn_2_available():
+    from flash_attn import flash_attn_qkvpacked_func
+
+
 def _no_grad_trunc_normal_(tensor, mean, std, a, b):
     # Cut & paste from PyTorch official master until it's in a few official releases - RW
     # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
@@ -134,7 +138,7 @@ class Attention(nn.Module):
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
 
         if not self.qk_norm:
-            if self.head_dim % 32 == 0:
+            if self.head_dim % 32 == 0 and is_flash_attn_2_available():
                 # flash_attn requires head_dim to be a multiple of 32, so SigLIP-SO400M cannot use flash_attn
                 x = flash_attn_qkvpacked_func(qkv, dropout_p=self.attn_drop.p if self.training else 0.,
                                               deterministic=self.deterministic)
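
Note on the pattern above (not part of the patch): flash_attn availability is checked twice, once at import time to guard the optional dependency and once at call time to pick the kernel, so siglip_vit.py still imports cleanly on machines without flash_attn. Below is a minimal, self-contained sketch of that dispatch pattern. It is a hypothetical helper (packed_attention), not the repository's code: the fallback uses PyTorch's scaled_dot_product_attention instead of the xformers path the real file takes, and the CUDA/dtype checks are added here because flash_attn kernels only run on fp16/bf16 GPU tensors.

# Minimal sketch of the conditional flash_attn dispatch applied in the patch above.
# Assumes qkv is packed as (B, N, 3, num_heads, head_dim), the layout that
# flash_attn_qkvpacked_func expects.
import torch
import torch.nn.functional as F
from transformers.modeling_utils import is_flash_attn_2_available

if is_flash_attn_2_available():
    from flash_attn import flash_attn_qkvpacked_func


def packed_attention(qkv: torch.Tensor, dropout_p: float = 0.0) -> torch.Tensor:
    """qkv: (B, N, 3, H, D) -> attention output of shape (B, N, H, D)."""
    head_dim = qkv.shape[-1]
    use_flash = (
        is_flash_attn_2_available()
        and head_dim % 32 == 0                              # constraint noted in the diff's comment
        and qkv.is_cuda
        and qkv.dtype in (torch.float16, torch.bfloat16)    # flash_attn kernels are fp16/bf16-only
    )
    if use_flash:
        return flash_attn_qkvpacked_func(qkv, dropout_p=dropout_p)
    # Fallback: unpack Q/K/V and use PyTorch's built-in scaled dot-product attention.
    q, k, v = qkv.unbind(dim=2)                        # each (B, N, H, D)
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))   # (B, H, N, D)
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
    return out.transpose(1, 2)                         # back to (B, N, H, D)

With this guard, environments without flash_attn (or models such as SigLIP-SO400M, whose head_dim is not a multiple of 32 per the comment in the diff) simply take the non-flash branch instead of failing at import.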