commit 53211ffd62
Saar Tochner, 2025-02-26 17:31:25 +08:00 (committed by GitHub)

@@ -14,6 +14,29 @@ from timm.layers import (
from timm.models._manipulate import named_apply, checkpoint_seq, adapt_input_conv
from transformers.modeling_utils import is_flash_attn_2_available
from functools import partial
try:
    from xformers.ops import memory_efficient_attention
except ImportError:
    warnings.warn(
        "xformers not installed, using slow PyTorch implementation of memory_efficient_attention",
        stacklevel=2,
    )

    def memory_efficient_attention(query, key, value, p):
        """This code is taken from https://facebookresearch.github.io/xformers/components/ops.html"""
        # Inputs follow the xformers convention [batch, seq_len, num_heads, head_dim]; `p` is the
        # dropout probability. `warnings` and `torch.nn.functional as F` are assumed to be
        # imported earlier in this module.
        attn_bias = None
        scale = 1.0 / query.shape[-1] ** 0.5
        query = query * scale
        # Move heads in front of the sequence dimension: [B, H, M, K]
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)
        attn = query @ key.transpose(-2, -1)
        if attn_bias is not None:
            attn = attn + attn_bias
        attn = attn.softmax(-1)
        attn = F.dropout(attn, p)
        attn = attn @ value
        # Restore the original [B, M, H, K] layout
        return attn.transpose(1, 2).contiguous()
if is_flash_attn_2_available():