Mirror of https://github.com/deepseek-ai/DeepSeek-VL2.git (synced 2025-02-22 13:49:00 -05:00)
fix flash_attn import on old GPU
Signed-off-by: Isotr0py <2037008807@qq.com>
commit a60e8208bb
parent 66ec91081c
@@ -12,11 +12,15 @@ from timm.layers import (
     AttentionPoolLatent, PatchDropout, resample_abs_pos_embed, LayerType
 )
 from timm.models._manipulate import named_apply, checkpoint_seq, adapt_input_conv
-from flash_attn import flash_attn_qkvpacked_func
+from transformers.modeling_utils import is_flash_attn_2_available
 from xformers.ops import memory_efficient_attention
 from functools import partial
 
 
+if is_flash_attn_2_available():
+    from flash_attn import flash_attn_qkvpacked_func
+
+
 def _no_grad_trunc_normal_(tensor, mean, std, a, b):
     # Cut & paste from PyTorch official master until it's in a few official releases - RW
     # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
@@ -134,7 +138,7 @@ class Attention(nn.Module):
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
 
         if not self.qk_norm:
-            if self.head_dim % 32 == 0:
+            if self.head_dim % 32 == 0 and is_flash_attn_2_available():
                 # flash_attn requires head_dim to be a multiple of 32, so SigLIP-SO400M cannot use flash_attn
                 x = flash_attn_qkvpacked_func(qkv, dropout_p=self.attn_drop.p if self.training else 0.,
                                               deterministic=self.deterministic)
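
For context, the pattern this commit relies on is to import flash_attn only when transformers reports FlashAttention-2 as usable on the current machine, and to re-check at call time before dispatching to flash_attn_qkvpacked_func. Below is a minimal, self-contained sketch of that pattern; the helper name packed_attention and the scaled_dot_product_attention fallback are illustrative only (the surrounding file imports xformers' memory_efficient_attention for its non-flash path, which is not shown in this diff).

import torch
import torch.nn.functional as F
from transformers.modeling_utils import is_flash_attn_2_available

# Import flash_attn only when FlashAttention-2 is actually usable (installed and
# the GPU is supported); on older GPUs the unconditional import is what breaks.
if is_flash_attn_2_available():
    from flash_attn import flash_attn_qkvpacked_func


def packed_attention(qkv: torch.Tensor, dropout_p: float = 0.0) -> torch.Tensor:
    """qkv: (B, N, 3, num_heads, head_dim) -> (B, N, num_heads, head_dim)."""
    head_dim = qkv.shape[-1]
    # Mirror the guard added in Attention.forward: flash_attn needs head_dim to be
    # a multiple of 32 (which SigLIP-SO400M's head_dim is not), and FlashAttention-2
    # must be available at runtime.
    if is_flash_attn_2_available() and head_dim % 32 == 0:
        return flash_attn_qkvpacked_func(qkv, dropout_p=dropout_p)
    # Illustrative fallback via PyTorch SDPA; not the repo's actual else branch.
    q, k, v = qkv.unbind(dim=2)                       # each (B, N, H, D)
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # (B, H, N, D)
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
    return out.transpose(1, 2)                        # back to (B, N, H, D)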