fix flash_attn import on old GPU

Signed-off-by: Isotr0py <2037008807@qq.com>
Isotr0py 2025-01-10 01:55:19 +08:00
parent 66ec91081c
commit a60e8208bb


@@ -12,11 +12,15 @@ from timm.layers import (
     AttentionPoolLatent, PatchDropout, resample_abs_pos_embed, LayerType
 )
 from timm.models._manipulate import named_apply, checkpoint_seq, adapt_input_conv
-from flash_attn import flash_attn_qkvpacked_func
+from transformers.modeling_utils import is_flash_attn_2_available
 from xformers.ops import memory_efficient_attention
 from functools import partial
 
 
+if is_flash_attn_2_available():
+    from flash_attn import flash_attn_qkvpacked_func
+
+
 def _no_grad_trunc_normal_(tensor, mean, std, a, b):
     # Cut & paste from PyTorch official master until it's in a few official releases - RW
     # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
@@ -134,7 +138,7 @@ class Attention(nn.Module):
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
 
         if not self.qk_norm:
-            if self.head_dim % 32 == 0:
+            if self.head_dim % 32 == 0 and is_flash_attn_2_available():
                 # flash_attn requires head_dim to be a multiple of 32, so SigLIP-SO400M cannot use flash_attn
                 x = flash_attn_qkvpacked_func(qkv, dropout_p=self.attn_drop.p if self.training else 0.,
                                               deterministic=self.deterministic)
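
Taken together, the two hunks make flash-attn strictly optional: the module is only imported when transformers reports FlashAttention-2 support, and the flash-attn branch is only taken at runtime when that check passes and head_dim is a multiple of 32; otherwise execution stays on the existing non-flash path. Below is a minimal standalone sketch of that dispatch pattern, not the patched file itself: the helper name packed_attention is made up for illustration, and the fallback here uses PyTorch's scaled_dot_product_attention so the snippet runs without extra dependencies (the patched file instead keeps xformers' memory_efficient_attention around as its other backend).

import torch
import torch.nn.functional as F
from transformers.modeling_utils import is_flash_attn_2_available

if is_flash_attn_2_available():
    # Only importable when flash-attn 2 is installed and the GPU generation supports it.
    from flash_attn import flash_attn_qkvpacked_func


def packed_attention(qkv: torch.Tensor, dropout_p: float = 0.0) -> torch.Tensor:
    """qkv: (batch, seqlen, 3, num_heads, head_dim); returns (batch, seqlen, num_heads, head_dim)."""
    head_dim = qkv.shape[-1]
    if is_flash_attn_2_available() and head_dim % 32 == 0:
        # flash-attn only handles head_dim multiples of 32 and expects fp16/bf16 CUDA tensors.
        return flash_attn_qkvpacked_func(qkv, dropout_p=dropout_p)
    # Fallback for old GPUs / missing flash-attn: unpack and use PyTorch's built-in SDPA.
    q, k, v = qkv.unbind(dim=2)                       # each (B, S, H, D)
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # SDPA expects (B, H, S, D)
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
    return out.transpose(1, 2)                        # back to (B, S, H, D)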