Mirror of https://github.com/deepseek-ai/DeepSeek-VL2.git (synced 2025-02-22 13:49:00 -05:00)
support CPU inference
commit 9bccb34c8e
parent c74816ad22
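
The change below wraps the xformers import in a try/except so that, when xformers is not installed (for example on a CPU-only machine), the vision encoder falls back to a plain PyTorch implementation of memory_efficient_attention instead of failing at import time.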
@@ -13,8 +13,30 @@ from timm.layers import (
 )
 from timm.models._manipulate import named_apply, checkpoint_seq, adapt_input_conv
 from transformers.modeling_utils import is_flash_attn_2_available
-from xformers.ops import memory_efficient_attention
 from functools import partial
+try:
+    from xformers.ops import memory_efficient_attention
+except ImportError:
+    warnings.warn(
+        "xformers not installed, using slow PyTorch implementation of memory_efficient_attention",
+        stacklevel=2,
+    )
+
+    def memory_efficient_attention(query, key, value, p):
+        """ This code is taken from https://facebookresearch.github.io/xformers/components/ops.html """
+        attn_bias = None
+        scale = 1.0 / query.shape[-1] ** 0.5
+        query = query * scale
+        query = query.transpose(1, 2)
+        key = key.transpose(1, 2)
+        value = value.transpose(1, 2)
+        attn = query @ key.transpose(-2, -1)
+        if attn_bias is not None:
+            attn = attn + attn_bias
+        attn = attn.softmax(-1)
+        attn = F.dropout(attn, p)
+        attn = attn @ value
+        return attn.transpose(1, 2).contiguous()
 
 
 if is_flash_attn_2_available():
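
As a sanity check (not part of the commit), the fallback can be compared against PyTorch's built-in F.scaled_dot_product_attention, which computes the same quantity with the same default 1/sqrt(head_dim) scaling. The sketch below re-implements the fallback as a standalone function; the name fallback_attention and the tensor shapes are illustrative, while the [batch, seq_len, heads, head_dim] layout follows the xformers convention assumed by the commit.

import torch
import torch.nn.functional as F

def fallback_attention(query, key, value, p=0.0):
    # Stand-in for the commit's pure-PyTorch fallback to
    # xformers.ops.memory_efficient_attention.
    # Inputs use the xformers layout: [batch, seq_len, heads, head_dim].
    scale = 1.0 / query.shape[-1] ** 0.5
    query = query * scale
    query = query.transpose(1, 2)   # -> [batch, heads, seq_len, head_dim]
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)
    attn = query @ key.transpose(-2, -1)  # scaled attention logits
    attn = attn.softmax(-1)
    attn = F.dropout(attn, p)
    attn = attn @ value
    # Back to [batch, seq_len, heads, head_dim], matching xformers' output.
    return attn.transpose(1, 2).contiguous()

q, k, v = (torch.randn(2, 16, 8, 64) for _ in range(3))  # CPU tensors

out = fallback_attention(q, k, v, p=0.0)

# F.scaled_dot_product_attention expects [batch, heads, seq_len, head_dim],
# so transpose into and out of that layout for the reference result.
ref = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2)

print(torch.allclose(out, ref, atol=1e-6))  # True (with dropout disabled)

Keeping the fallback's input/output layout identical to xformers means call sites in the vision tower need no changes; only the import resolves differently when xformers is absent.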