Mirror of https://github.com/deepseek-ai/DeepSeek-VL2.git, last synced 2025-02-22 13:49:00 -05:00.
fix flash_attn import on old GPU
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
parent
66ec91081c
commit
a60e8208bb
@ -12,11 +12,15 @@ from timm.layers import (
|
||||
AttentionPoolLatent, PatchDropout, resample_abs_pos_embed, LayerType
|
||||
)
|
||||
from timm.models._manipulate import named_apply, checkpoint_seq, adapt_input_conv
|
||||
from flash_attn import flash_attn_qkvpacked_func
|
||||
from transformers.modeling_utils import is_flash_attn_2_available
|
||||
from xformers.ops import memory_efficient_attention
|
||||
from functools import partial
|
||||
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
from flash_attn import flash_attn_qkvpacked_func
|
||||
|
||||
|
||||
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
||||
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
||||
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
||||
@ -134,7 +138,7 @@ class Attention(nn.Module):
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
|
||||
|
||||
if not self.qk_norm:
|
||||
if self.head_dim % 32 == 0:
|
||||
if self.head_dim % 32 == 0 and is_flash_attn_2_available():
|
||||
# flash-attn requires head_dim to be a multiple of 32, so SigLIP-SO400M cannot use flash-attn
|
||||
x = flash_attn_qkvpacked_func(qkv, dropout_p=self.attn_drop.p if self.training else 0.,
|
||||
deterministic=self.deterministic)
|
||||
|
Loading…
Reference in New Issue
Block a user