Merge pull request #89 from CUHKSZzxy/optimize-deps

change xformers import location
This commit is contained in:
StevenLiuWen 2025-02-20 14:39:09 +08:00 committed by GitHub
commit 3c2cd21c9f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -13,7 +13,6 @@ from timm.layers import (
 )
 from timm.models._manipulate import named_apply, checkpoint_seq, adapt_input_conv
 from transformers.modeling_utils import is_flash_attn_2_available
-from xformers.ops import memory_efficient_attention
 from functools import partial
@@ -134,6 +133,8 @@ class Attention(nn.Module):
         self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0. else nn.Identity()

     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        from xformers.ops import memory_efficient_attention
         B, N, C = x.shape
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)