support CPU inference

Saar Tochiner 2025-02-01 22:30:25 +02:00
parent c74816ad22
commit 9bccb34c8e


@@ -13,8 +13,30 @@ from timm.layers import (
)
from timm.models._manipulate import named_apply, checkpoint_seq, adapt_input_conv
from transformers.modeling_utils import is_flash_attn_2_available
from xformers.ops import memory_efficient_attention
from functools import partial
try:
    from xformers.ops import memory_efficient_attention
except ImportError:
    warnings.warn(
        "xformers not installed, using slow PyTorch implementation of memory_efficient_attention",
        stacklevel=2,
    )

    def memory_efficient_attention(query, key, value, p):
        """ This code is taken from https://facebookresearch.github.io/xformers/components/ops.html """
        attn_bias = None
        scale = 1.0 / query.shape[-1] ** 0.5
        query = query * scale
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)
        attn = query @ key.transpose(-2, -1)
        if attn_bias is not None:
            attn = attn + attn_bias
        attn = attn.softmax(-1)
        attn = F.dropout(attn, p)
        attn = attn @ value
        return attn.transpose(1, 2).contiguous()
if is_flash_attn_2_available():
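
Below is a minimal sketch (not part of the commit) of how the pure-PyTorch fallback can be checked on CPU against torch.nn.functional.scaled_dot_product_attention. It assumes the xformers tensor layout of (batch, seq_len, num_heads, head_dim), repeats the fallback definition so the snippet runs on its own, and uses illustrative shapes and tolerances.

import torch
import torch.nn.functional as F


def memory_efficient_attention(query, key, value, p):
    """Pure-PyTorch fallback as in the commit above (attn_bias handling omitted)."""
    scale = 1.0 / query.shape[-1] ** 0.5
    query = query * scale
    # xformers layout (B, M, H, K) -> (B, H, M, K) for the matmuls
    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)
    attn = (query @ key.transpose(-2, -1)).softmax(-1)
    attn = F.dropout(attn, p)
    out = attn @ value
    return out.transpose(1, 2).contiguous()  # back to (B, M, H, K)


# Illustrative shapes: batch=2, seq_len=16, heads=4, head_dim=32
q, k, v = (torch.randn(2, 16, 4, 32) for _ in range(3))

out = memory_efficient_attention(q, k, v, p=0.0)
ref = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2)
print(torch.allclose(out, ref, atol=1e-5))  # expected to print True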