From 9bccb34c8eb89906b0c36ed10594709c688b958d Mon Sep 17 00:00:00 2001
From: Saar Tochiner
Date: Sat, 1 Feb 2025 22:30:25 +0200
Subject: [PATCH] support CPU inference

---
 deepseek_vl2/models/siglip_vit.py | 24 +++++++++++++++++++++++-
 1 file changed, 23 insertions(+), 1 deletion(-)

diff --git a/deepseek_vl2/models/siglip_vit.py b/deepseek_vl2/models/siglip_vit.py
index 67f30e8..1d61951 100644
--- a/deepseek_vl2/models/siglip_vit.py
+++ b/deepseek_vl2/models/siglip_vit.py
@@ -13,8 +13,30 @@ from timm.layers import (
 )
 from timm.models._manipulate import named_apply, checkpoint_seq, adapt_input_conv
 from transformers.modeling_utils import is_flash_attn_2_available
-from xformers.ops import memory_efficient_attention
 from functools import partial
+try:
+    from xformers.ops import memory_efficient_attention
+except ImportError:
+    warnings.warn(
+        "xformers not installed, using slow PyTorch implementation of memory_efficient_attention",
+        stacklevel=2,
+    )
+
+    def memory_efficient_attention(query, key, value, p):
+        """ This code is taken from https://facebookresearch.github.io/xformers/components/ops.html """
+        attn_bias = None
+        scale = 1.0 / query.shape[-1] ** 0.5
+        query = query * scale
+        query = query.transpose(1, 2)
+        key = key.transpose(1, 2)
+        value = value.transpose(1, 2)
+        attn = query @ key.transpose(-2, -1)
+        if attn_bias is not None:
+            attn = attn + attn_bias
+        attn = attn.softmax(-1)
+        attn = F.dropout(attn, p)
+        attn = attn @ value
+        return attn.transpose(1, 2).contiguous()
 
 if is_flash_attn_2_available():
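
For reference, a minimal sanity-check sketch of the fallback (not part of the patch). It compares the pure-PyTorch path above against torch.nn.functional.scaled_dot_product_attention, assuming PyTorch >= 2.0, xformers absent so the fallback definition is the one in scope, and the [batch, seq_len, num_heads, head_dim] layout that xformers expects; the import path and tensor shapes are illustrative only.

# Hypothetical check script, assumes xformers is NOT installed so the import
# below resolves to the fallback defined in this patch.
import torch
import torch.nn.functional as F
from deepseek_vl2.models.siglip_vit import memory_efficient_attention

# xformers layout: [batch, seq_len, num_heads, head_dim]
q = torch.randn(2, 16, 8, 64)
k = torch.randn(2, 16, 8, 64)
v = torch.randn(2, 16, 8, 64)

out_fallback = memory_efficient_attention(q, k, v, p=0.0)

# SDPA expects [batch, num_heads, seq_len, head_dim], so transpose in and out.
out_sdpa = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2)

assert torch.allclose(out_fallback, out_sdpa, atol=1e-5)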