Merge pull request #34 from Isotr0py/main

Fix flash_attn import in siglip_vit
StevenLiuWen 2025-01-16 11:13:08 +08:00 committed by GitHub
commit 9dd177671e


@@ -12,11 +12,15 @@ from timm.layers import (
     AttentionPoolLatent, PatchDropout, resample_abs_pos_embed, LayerType
 )
 from timm.models._manipulate import named_apply, checkpoint_seq, adapt_input_conv
-from flash_attn import flash_attn_qkvpacked_func
+from transformers.modeling_utils import is_flash_attn_2_available
 from xformers.ops import memory_efficient_attention
 from functools import partial
+if is_flash_attn_2_available():
+    from flash_attn import flash_attn_qkvpacked_func
 def _no_grad_trunc_normal_(tensor, mean, std, a, b):
     # Cut & paste from PyTorch official master until it's in a few official releases - RW
     # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
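
The import change above makes flash_attn an optional dependency: the unconditional top-level import is replaced by an availability check from transformers, so the module still loads in environments where the flash-attn package is not installed. A minimal sketch of the same guarded-import pattern, assuming only that transformers and (optionally) flash-attn are installed; the HAS_FLASH_ATTN flag is an illustrative name and not part of the patch:

from transformers.modeling_utils import is_flash_attn_2_available

HAS_FLASH_ATTN = False
if is_flash_attn_2_available():
    # Only imported when transformers detects a working flash-attn 2.x installation.
    from flash_attn import flash_attn_qkvpacked_func
    HAS_FLASH_ATTN = True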
@@ -134,7 +138,7 @@ class Attention(nn.Module):
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
         if not self.qk_norm:
-            if self.head_dim % 32 == 0:
+            if self.head_dim % 32 == 0 and is_flash_attn_2_available():
                 # flash_attn requires head_dim to be a multiple of 32, so SigLIP-SO400M cannot use flash_attn
                 x = flash_attn_qkvpacked_func(qkv, dropout_p=self.attn_drop.p if self.training else 0.,
                                               deterministic=self.deterministic)
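
At run time the same check now guards the flash-attn branch in Attention.forward, so a head_dim that is a multiple of 32 is no longer enough on its own; when flash-attn is unavailable, execution falls through to the other attention paths in the file. A hedged, self-contained sketch of that dispatch, assuming qkv is packed as (B, N, 3, num_heads, head_dim) as in the hunk above and using xformers' memory_efficient_attention (which this file already imports) as the fallback; the helper name attend is illustrative, not from the patch:

import torch
from transformers.modeling_utils import is_flash_attn_2_available
from xformers.ops import memory_efficient_attention

if is_flash_attn_2_available():
    from flash_attn import flash_attn_qkvpacked_func

def attend(qkv: torch.Tensor, dropout_p: float = 0.0, deterministic: bool = False) -> torch.Tensor:
    # qkv: (B, N, 3, num_heads, head_dim), as produced by the qkv projection above.
    head_dim = qkv.shape[-1]
    if head_dim % 32 == 0 and is_flash_attn_2_available():
        # flash-attn kernel: needs head_dim divisible by 32 and the package installed
        # (it also expects fp16/bf16 CUDA tensors).
        return flash_attn_qkvpacked_func(qkv, dropout_p=dropout_p, deterministic=deterministic)
    # Fallback: unpack to q, k, v of shape (B, N, num_heads, head_dim) and use xformers.
    q, k, v = qkv.unbind(dim=2)
    return memory_efficient_attention(q, k, v, p=dropout_p)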