mirror of https://github.com/deepseek-ai/DeepSeek-VL2.git (synced 2025-02-21 21:29:00 -05:00)
Merge pull request #89 from CUHKSZzxy/optimize-deps
change xformers import location
commit 3c2cd21c9f
@@ -13,7 +13,6 @@ from timm.layers import (
 )
 from timm.models._manipulate import named_apply, checkpoint_seq, adapt_input_conv
 from transformers.modeling_utils import is_flash_attn_2_available
-from xformers.ops import memory_efficient_attention
 from functools import partial
 
 
@@ -134,6 +133,8 @@ class Attention(nn.Module):
         self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0. else nn.Identity()
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        from xformers.ops import memory_efficient_attention
+
         B, N, C = x.shape
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
 
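In effect, the patch turns xformers into a lazy dependency: the package is no longer imported when the module loads, only when Attention.forward actually runs. Below is a minimal, self-contained sketch of the same pattern; the class is a toy reduction of the attention block shown in the diff, and the projection layer and shape handling are illustrative assumptions, not the actual DeepSeek-VL2 code.

import torch
import torch.nn as nn

class Attention(nn.Module):
    # Toy attention block illustrating the deferred-import pattern from this patch.
    def __init__(self, dim: int, num_heads: int = 8, proj_drop: float = 0.0):
        super().__init__()
        assert dim % num_heads == 0, "dim must divide evenly across heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)  # output projection (assumption for illustration)
        self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0. else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Deferred import: xformers is only required once forward() runs,
        # so merely importing the model file no longer needs the package installed.
        from xformers.ops import memory_efficient_attention

        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(2)                   # each (B, N, num_heads, head_dim)
        x = memory_efficient_attention(q, k, v)   # xformers expects (B, N, H, D) layout
        return self.proj_drop(self.proj(x.reshape(B, N, C)))

Repeated calls stay cheap because Python caches imported modules in sys.modules; after the first forward pass the import statement is effectively a dictionary lookup.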