Gradio Demo Example, Incremental Prefilling and VLMEvalKit Support
This commit is contained in:
StevenLiuWen 2024-12-26 22:37:57 +08:00
parent 8bde1c1ae1
commit faf18023f2
38 changed files with 1369 additions and 168 deletions

1
.gitattributes vendored
View File

@ -6,3 +6,4 @@
*.jpeg binary
*.gif binary
*.pdf binary
*.ttc binary

225
README.md
View File

@ -69,6 +69,7 @@ Zhiyu Wu*, Xiaokang Chen*, Zizheng Pan*, Xingchao Liu*, Wen Liu**, Damai Dai, Hu
![](./images/vl2_teaser.jpeg)
## 2. Release
<b>2024-12-25</b>: Gradio Demo Example, Incremental Prefilling and VLMEvalKit Support.
<b>2024-12-13</b>: DeepSeek-VL2 family released, including <code>DeepSeek-VL2-tiny</code>, <code>DeepSeek-VL2-small</code>, <code>DeepSeek-VL2</code>.
## 3. Model Download
@ -96,7 +97,9 @@ On the basis of `Python >= 3.8` environment, install the necessary dependencies
pip install -e .
```
### Simple Inference Example
### Simple Inference Example with One Image
**Note: You may need 80GB of GPU memory to run this script with deepseek-vl2-small, and even more for deepseek-vl2.**
```python
import torch
@ -107,7 +110,7 @@ from deepseek_vl2.utils.io import load_pil_images
# specify the path to the model
model_path = "deepseek-ai/deepseek-vl2-small"
model_path = "deepseek-ai/deepseek-vl2-tiny"
vl_chat_processor: DeepseekVLV2Processor = DeepseekVLV2Processor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
@ -119,23 +122,78 @@ conversation = [
{
"role": "<|User|>",
"content": "<image>\n<|ref|>The giraffe at the back.<|/ref|>.",
"images": ["./images/visual_grounding.jpeg"],
"images": ["./images/visual_grounding_1.jpeg"],
},
{"role": "<|Assistant|>", "content": ""},
]
# load images and prepare for inputs
pil_images = load_pil_images(conversation)
prepare_inputs = vl_chat_processor(
conversations=conversation,
images=pil_images,
force_batchify=True,
system_prompt=""
).to(vl_gpt.device)
# run image encoder to get the image embeddings
inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
# run the model to get the response
outputs = vl_gpt.language.generate(
inputs_embeds=inputs_embeds,
attention_mask=prepare_inputs.attention_mask,
pad_token_id=tokenizer.eos_token_id,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
max_new_tokens=512,
do_sample=False,
use_cache=True
)
answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=False)
print(f"{prepare_inputs['sft_format'][0]}", answer)
```
And the output is something like:
```
<|User|>: <image>
<|ref|>The giraffe at the back.<|/ref|>.
<|Assistant|>: <|ref|>The giraffe at the back.<|/ref|><|det|>[[580, 270, 999, 900]]<|/det|><end▁of▁sentence>
```
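The `<|det|>` block in the answer encodes the bounding box of the referred object. Below is a minimal sketch for turning that string into pixel-space boxes; it assumes the coordinates are normalized to a 0-999 range (the example output above tops out at 999, and the repo's `parse_ref_bbox` helper in `deepseek_vl2/serve/app_modules/utils.py` appears to follow the same convention), so verify against your checkout. The `extract_boxes` helper is illustrative and not part of the package.
```python
import ast
import re

from PIL import Image


def extract_boxes(response: str, image: Image.Image):
    """Parse <|ref|>...<|/ref|><|det|>[[x1, y1, x2, y2], ...]<|/det|> pairs and
    rescale the (assumed) 0-999 normalized coordinates to pixel coordinates."""
    refs = re.findall(r"<\|ref\|>(.*?)<\|/ref\|>", response)
    dets = re.findall(r"<\|det\|>(.*?)<\|/det\|>", response)
    w, h = image.size
    boxes = []
    for label, det in zip(refs, dets):
        for x1, y1, x2, y2 in ast.literal_eval(det):  # det is a nested list literal
            boxes.append((label, (x1 / 999 * w, y1 / 999 * h, x2 / 999 * w, y2 / 999 * h)))
    return boxes


# e.g. extract_boxes(answer, Image.open("./images/visual_grounding_1.jpeg"))
```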
### Simple Inference Example with Multiple Images
**Note: You may need 80GB of GPU memory to run this script with deepseek-vl2-small, and even more for deepseek-vl2.**
```python
import torch
from transformers import AutoModelForCausalLM
from deepseek_vl2.models import DeepseekVLV2Processor, DeepseekVLV2ForCausalLM
from deepseek_vl2.utils.io import load_pil_images
# specify the path to the model
model_path = "deepseek-ai/deepseek-vl2-tiny"
vl_chat_processor: DeepseekVLV2Processor = DeepseekVLV2Processor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
vl_gpt: DeepseekVLV2ForCausalLM = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()
# multiple images/interleaved image-text
conversation_multi_images = [
conversation = [
{
"role": "<|User|>",
"content": "This is image_1: <image>\n"
"This is image_2: <image>\n"
"This is image_3: <image>\n If I am a vegetarian, what can I cook with these ingredients?",
"This is image_3: <image>\n Can you tell me what are in the images?",
"images": [
"images/multi_image_1.png",
"images/multi_image_2.jpg",
"images/multi_image_3.jpg",
"images/multi_image_1.jpeg",
"images/multi_image_2.jpeg",
"images/multi_image_3.jpeg",
],
},
{"role": "<|Assistant|>", "content": ""}
@ -169,12 +227,151 @@ answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=False)
print(f"{prepare_inputs['sft_format'][0]}", answer)
```
### Gradio Demo (TODO)
And the output is something like:
```
<|User|>: This is image_1: <image>
This is image_2: <image>
This is image_3: <image>
Can you tell me what are in the images?
<|Assistant|>: The images show three different types of vegetables. Image_1 features carrots, which are orange with green tops. Image_2 displays corn cobs, which are yellow with green husks. Image_3 contains raw pork ribs, which are pinkish-red with some marbling.<end▁of▁sentence>
```
### Simple Inference Example with Incremental Prefilling
**Note: We use incremental prefilling to run inference with deepseek-vl2-small within 40GB of GPU memory.**
```python
import torch
from transformers import AutoModelForCausalLM
from deepseek_vl2.models import DeepseekVLV2Processor, DeepseekVLV2ForCausalLM
from deepseek_vl2.utils.io import load_pil_images
### Demo
This figure presents some examples of DeepSeek-VL2.
![](./images/github_demo.png)
# specify the path to the model
model_path = "deepseek-ai/deepseek-vl2-small"
vl_chat_processor: DeepseekVLV2Processor = DeepseekVLV2Processor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
vl_gpt: DeepseekVLV2ForCausalLM = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()
# multiple images/interleaved image-text
conversation = [
{
"role": "<|User|>",
"content": "This is image_1: <image>\n"
"This is image_2: <image>\n"
"This is image_3: <image>\n Can you tell me what are in the images?",
"images": [
"images/multi_image_1.jpeg",
"images/multi_image_2.jpeg",
"images/multi_image_3.jpeg",
],
},
{"role": "<|Assistant|>", "content": ""}
]
# load images and prepare for inputs
pil_images = load_pil_images(conversation)
prepare_inputs = vl_chat_processor(
conversations=conversation,
images=pil_images,
force_batchify=True,
system_prompt=""
).to(vl_gpt.device)
with torch.no_grad():
# run image encoder to get the image embeddings
inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
# incremental_prefilling when using 40G GPU for vl2-small
inputs_embeds, past_key_values = vl_gpt.incremental_prefilling(
input_ids=prepare_inputs.input_ids,
images=prepare_inputs.images,
images_seq_mask=prepare_inputs.images_seq_mask,
images_spatial_crop=prepare_inputs.images_spatial_crop,
attention_mask=prepare_inputs.attention_mask,
chunk_size=512 # prefilling size
)
# run the model to get the response
outputs = vl_gpt.generate(
inputs_embeds=inputs_embeds,
input_ids=prepare_inputs.input_ids,
images=prepare_inputs.images,
images_seq_mask=prepare_inputs.images_seq_mask,
images_spatial_crop=prepare_inputs.images_spatial_crop,
attention_mask=prepare_inputs.attention_mask,
past_key_values=past_key_values,
pad_token_id=tokenizer.eos_token_id,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
max_new_tokens=512,
do_sample=False,
use_cache=True,
)
answer = tokenizer.decode(outputs[0][len(prepare_inputs.input_ids[0]):].cpu().tolist(), skip_special_tokens=False)
print(f"{prepare_inputs['sft_format'][0]}", answer)
```
And the output is something like:
```
<|User|>: This is image_1: <image>
This is image_2: <image>
This is image_3: <image>
Can you tell me what are in the images?
<|Assistant|>: The first image contains carrots. The second image contains corn. The third image contains meat.<end▁of▁sentence>
```
### Full Inference Example
```shell
# without incremental prefilling
CUDA_VISIBLE_DEVICES=0 python inference.py --model_path "deepseek-ai/deepseek-vl2"
# with incremental prefilling, e.g. when using a 40GB GPU for vl2-small
CUDA_VISIBLE_DEVICES=0 python inference.py --model_path "deepseek-ai/deepseek-vl2-small" --chunk_size 512
```
### Gradio Demo
* Install the necessary dependencies:
```shell
pip install -e .[gradio]
```
* then run the following command:
```shell
# vl2-tiny, 3.37B-MoE in total, activated 1B, can be run on a single GPU < 40GB
CUDA_VISIBLE_DEVICES=2 python web_demo.py \
--model_name "deepseek-ai/deepseek-vl2-tiny" \
--port 37914
# vl2-small, 16.1B-MoE in total, activated 2.4B
# On an A100 40GB GPU, set `--chunk_size 512` to enable incremental prefilling and save memory; responses may be slower.
# On GPUs with more than 40GB of memory, omit `--chunk_size 512` for faster responses.
CUDA_VISIBLE_DEVICES=2 python web_demo.py \
--model_name "deepseek-ai/deepseek-vl2-small" \
--port 37914 \
--chunk_size 512
# vl2, 27.5B-MoE in total, activated 4.2B
CUDA_VISIBLE_DEVICES=2 python web_demo.py \
--model_name "deepseek-ai/deepseek-vl2" \
--port 37914
```
* **Important**: This is a basic and native demo implementation without any deployment optimizations, which may result in slower performance. For production environments, consider using optimized deployment solutions, such as vllm, sglang, lmdeploy, etc. These optimizations will help achieve faster response times and better cost efficiency.
## 5. License
@ -184,13 +381,13 @@ This code repository is licensed under [MIT License](./LICENSE-CODE). The use of
```
@misc{wu2024deepseekvl2mixtureofexpertsvisionlanguagemodels,
title={DeepSeek-VL2: Mixture-of-Experts Vision-Language Models for Advanced Multimodal Understanding},
title={DeepSeek-VL2: Mixture-of-Experts Vision-Language Models for Advanced Multimodal Understanding},
author={Zhiyu Wu and Xiaokang Chen and Zizheng Pan and Xingchao Liu and Wen Liu and Damai Dai and Huazuo Gao and Yiyang Ma and Chengyue Wu and Bingxuan Wang and Zhenda Xie and Yu Wu and Kai Hu and Jiawei Wang and Yaofeng Sun and Yukun Li and Yishi Piao and Kang Guan and Aixin Liu and Xin Xie and Yuxiang You and Kai Dong and Xingkai Yu and Haowei Zhang and Liang Zhao and Yisong Wang and Chong Ruan},
year={2024},
eprint={2412.10302},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2412.10302},
url={https://arxiv.org/abs/2412.10302},
}
```

View File

@ -17,7 +17,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch DeepSeek model."""
""" PyTorch DeepSeek model and compatible with both DeepSeekV2 and DeepSeekV3"""
import math
import warnings
from typing import List, Optional, Tuple, Union
@ -27,16 +27,13 @@ import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import torch.distributed as dist
from einops import repeat
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import (
AttentionMaskConverter,
_prepare_4d_attention_mask,
_prepare_4d_causal_attention_mask,
)
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaFlashAttention2
@ -63,12 +60,10 @@ from transformers.utils.import_utils import is_torch_fx_available
from .configuration_deepseek import DeepseekV2Config
if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
# It means that the function will not be traced through and simply appear as a node in the graph.
if is_torch_fx_available():
@ -77,7 +72,6 @@ if is_torch_fx_available():
_prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "DeepseekV2Config"
@ -869,17 +863,10 @@ class DeepseekV2Attention(nn.Module):
compressed_kv, k_pe = torch.split(
compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
compressed_kv = self.kv_a_layernorm(compressed_kv)
k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
kv = (
self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
.transpose(1, 2)
)
k_nope, value_states = torch.split(
kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
)
kv_seq_len = value_states.shape[-2]
kv_seq_len = k_pe.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
@ -888,27 +875,23 @@ class DeepseekV2Attention(nn.Module):
"with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
cos, sin = self.rotary_emb(q_pe, seq_len=kv_seq_len)
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
compressed_kv = compressed_kv.unsqueeze(1)
k_pe, compressed_kv = past_key_value.update(k_pe, compressed_kv, self.layer_idx, cache_kwargs)
compressed_kv = compressed_kv.squeeze(1)
attn_weights = (
torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale
)
kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)
q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :]
out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :]
q_nope = torch.matmul(q_nope, q_absorb)
attn_weights = (torch.matmul(q_pe, k_pe.mT) +
torch.matmul(q_nope, compressed_kv.unsqueeze(-3).mT)) * self.softmax_scale
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
@ -925,11 +908,13 @@ class DeepseekV2Attention(nn.Module):
# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(query_states.dtype)
).to(q_pe.dtype)
attn_weights = nn.functional.dropout(
attn_weights, p=self.attention_dropout, training=self.training
)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv)
attn_output = torch.matmul(attn_output, out_absorb.mT)
if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):
raise ValueError(
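The change above stops materializing per-head `key_states`/`value_states` via `kv_b_proj` and instead keeps the cache in its compressed form, folding `kv_b_proj` into the query (`q_absorb`) and into the attention output (`out_absorb`). A small self-contained sanity check of that algebra is sketched below; the dimensions are toy values rather than the real model's, and the RoPE term (`q_pe @ k_pe^T`), which is identical in both paths, is omitted.
```python
import torch

# Toy dimensions, not the real model's.
bsz, n_heads, q_len, kv_len = 1, 2, 3, 5
nope_dim, v_dim, lora_rank = 4, 4, 6

q_nope = torch.randn(bsz, n_heads, q_len, nope_dim)
compressed_kv = torch.randn(bsz, kv_len, lora_rank)        # what the new code caches
kv_b = torch.randn(n_heads, nope_dim + v_dim, lora_rank)   # kv_b_proj.weight, reshaped per head

q_absorb, out_absorb = kv_b[:, :nope_dim, :], kv_b[:, nope_dim:, :]

# Old path: expand the compressed cache into per-head k_nope / value_states.
kv = torch.einsum("blc,hdc->bhld", compressed_kv, kv_b)
k_nope, value_states = kv[..., :nope_dim], kv[..., nope_dim:]
ref_scores = torch.matmul(q_nope, k_nope.transpose(-1, -2))
ref_out = torch.matmul(torch.softmax(ref_scores, dim=-1), value_states)

# New path: keep the cache compressed, absorb kv_b_proj into q and the output.
q_abs = torch.matmul(q_nope, q_absorb)                                      # (b, h, q, rank)
scores = torch.matmul(q_abs, compressed_kv.unsqueeze(1).transpose(-1, -2))  # (b, h, q, kv)
out = torch.einsum("bhql,blc->bhqc", torch.softmax(scores, dim=-1), compressed_kv)
out = torch.matmul(out, out_absorb.transpose(-1, -2))

print(torch.allclose(ref_scores, scores, atol=1e-5))  # True
print(torch.allclose(ref_out, out, atol=1e-5))        # True
```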
@ -1034,6 +1019,7 @@ class DeepseekV2FlashAttention2(DeepseekV2Attention):
if self.q_head_dim != self.v_head_dim:
value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim])
# TODO: support compressed_kv for kv_cache (instead of key_states, value_states) in flash_attention version
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(
@ -1494,6 +1480,7 @@ class DeepseekV2Model(DeepseekV2PreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = (
output_attentions
@ -1668,17 +1655,18 @@ class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@ -1730,6 +1718,7 @@ class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position
)
hidden_states = outputs[0]
@ -1762,13 +1751,14 @@ class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):
)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
**kwargs,
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
**kwargs,
):
past_length = 0
if past_key_values is not None:
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
@ -1780,13 +1770,10 @@ class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
# input)
if (
attention_mask is not None
and attention_mask.shape[1] > input_ids.shape[1]
):
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < input_ids.shape[1]:
@ -1795,9 +1782,9 @@ class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
if (
max_cache_length is not None
and attention_mask is not None
and cache_length + input_ids.shape[1] > max_cache_length
max_cache_length is not None
and attention_mask is not None
and cache_length + input_ids.shape[1] > max_cache_length
):
attention_mask = attention_mask[:, -max_cache_length:]
@ -1807,17 +1794,35 @@ class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
position_ids = position_ids[:, -input_ids.shape[1]:]
if self.generation_config.cache_implementation == "static":
# generation with static cache
cache_position = kwargs.get("cache_position", None)
if cache_position is None:
past_length = 0
else:
past_length = cache_position[-1] + 1
input_ids = input_ids[:, past_length:]
position_ids = position_ids[:, past_length:]
# TODO @gante we should only keep a `cache_position` in generate, and do +=1.
# same goes for position ids. Could also help with continued generation.
cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
# recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
# TODO: use `next_tokens` directly instead.
model_inputs = {"input_ids": input_ids.contiguous()}
model_inputs.update(
{
"position_ids": position_ids,
"position_ids": position_ids.contiguous(),
"cache_position": cache_position,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
@ -1871,17 +1876,17 @@ class DeepseekV2ForSequenceClassification(DeepseekV2PreTrainedModel):
@add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -1921,7 +1926,7 @@ class DeepseekV2ForSequenceClassification(DeepseekV2PreTrainedModel):
else:
if input_ids is not None:
sequence_lengths = (
torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
).to(logits.device)
else:
sequence_lengths = -1
@ -1937,7 +1942,7 @@ class DeepseekV2ForSequenceClassification(DeepseekV2PreTrainedModel):
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (
labels.dtype == torch.long or labels.dtype == torch.int
labels.dtype == torch.long or labels.dtype == torch.int
):
self.config.problem_type = "single_label_classification"
else:

View File

@ -1,4 +1,8 @@
from attrdict import AttrDict
from dataclasses import dataclass
import logging
import gc
from einops import rearrange, repeat
from typing import Optional, List, Tuple, Callable, Union
@ -6,19 +10,27 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
)
from transformers.modeling_outputs import ModelOutput
from transformers.configuration_utils import PretrainedConfig
from transformers import (
AutoConfig,
AutoModelForCausalLM,
PreTrainedModel, GenerationConfig, LogitsProcessorList, StoppingCriteriaList,
PreTrainedModel
)
from transformers.generation.utils import GenerateOutput
from transformers.utils import logging
from .siglip_vit import VisionTransformer
from .configuration_deepseek import DeepseekV2Config
from .modeling_deepseek import DeepseekV2ForCausalLM
logger = logging.get_logger(__name__)
class MlpProjector(nn.Module):
def __init__(self, cfg):
@ -181,6 +193,45 @@ class MlpProjectorConfig(PretrainedConfig):
super().__init__(**kwargs)
@dataclass
class DeepSeekVLV2CausalLMOutputWithPast(ModelOutput):
"""
Base class for DeepSeek-VL2 causal language model (or autoregressive) outputs.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (for next-token prediction).
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
The rope index difference between sequence length and multimodal rope.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
past_key_values: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
rope_deltas: Optional[torch.LongTensor] = None
class DeepseekVLV2Config(PretrainedConfig):
model_type = "deepseek_vl_v2"
vision_config: VisionEncoderConfig
@ -229,6 +280,8 @@ class DeepseekVLV2ForCausalLM(DeepseekVLV2PreTrainedModel):
def __init__(self, config: DeepseekVLV2Config):
super().__init__(config)
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
# ----------- vision encoder ------------
vision_config = config.vision_config
self.vision = VisionTransformer(
@ -283,8 +336,8 @@ class DeepseekVLV2ForCausalLM(DeepseekVLV2PreTrainedModel):
def prepare_inputs_embeds(
self,
input_ids: torch.LongTensor,
images: torch.FloatTensor,
images_seq_mask: torch.LongTensor,
images: Optional[torch.FloatTensor] = None,
images_seq_mask: Optional[torch.LongTensor] = None,
images_spatial_crop: Optional[torch.LongTensor] = None,
**ignore_kwargs
):
@ -423,48 +476,222 @@ class DeepseekVLV2ForCausalLM(DeepseekVLV2PreTrainedModel):
return input_embeds
def generate(
@torch.no_grad()
def incremental_prefilling(
self,
inputs: Optional[torch.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
synced_gpus: Optional[bool] = None,
assistant_model: Optional["PreTrainedModel"] = None,
streamer: Optional["BaseStreamer"] = None,
negative_prompt_ids: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
images: Optional[torch.FloatTensor] = None,
images_seq_mask: Optional[torch.LongTensor] = None,
images_spatial_crop: Optional[torch.LongTensor] = None,
chunk_size: int = 1024
):
if inputs_embeds is None:
inputs_embeds = self.prepare_inputs_embeds(
input_ids=input_ids,
images=images,
images_seq_mask=images_seq_mask,
images_spatial_crop=images_spatial_crop,
)
del images
del images_seq_mask
del images_spatial_crop
if attention_mask is not None:
attention_mask = attention_mask.to(inputs_embeds.device)
self._clear_cuda_cache()
bzs, seq_len, _ = inputs_embeds.shape
past_key_values = None
# hold back the last token; generate() will consume it to produce the first new-token logits
prefilling_len = seq_len - 1
for i in range(0, prefilling_len, chunk_size):
chunk_start = i
chunk_end = min(i + chunk_size, prefilling_len)
chunk_inputs_embeds = inputs_embeds[:, chunk_start: chunk_end]
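# the mask must also cover all earlier positions, since each chunk attends to the cached key/values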
chunk_attention_mask = attention_mask[:, 0: chunk_end]
# print(f"start = {chunk_start}, end = {chunk_end}, prefilling_len = {prefilling_len}, seq_len = {seq_len}")
# compute position_ids
if past_key_values is not None:
position_ids = torch.arange(
chunk_start,
chunk_end,
dtype=torch.long,
device=inputs_embeds.device
).unsqueeze(0)
past_key_values = self._move_past_key_values_to_gpu(past_key_values, inputs_embeds.device)
else:
position_ids = None
# chunk-forward
with torch.no_grad():
outputs = self.forward(
inputs_embeds=chunk_inputs_embeds,
attention_mask=chunk_attention_mask,
past_key_values=past_key_values,
position_ids=position_ids,
use_cache=True,
)
# update past_key_values
past_key_values = outputs.past_key_values
past_key_values = self._move_past_key_values_to_cpu(past_key_values)
del outputs, position_ids
self._clear_cuda_cache()
prefilling_key_values = []
for layer_past in past_key_values:
prefilling_key_values.append(
(
layer_past[0][:, :, 0: prefilling_len, ...].to(inputs_embeds.device),
layer_past[1][:, :, 0: prefilling_len, ...].to(inputs_embeds.device),
)
)
return inputs_embeds, prefilling_key_values
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
images: Optional[torch.FloatTensor] = None,
images_seq_mask: Optional[torch.LongTensor] = None,
images_spatial_crop: Optional[torch.LongTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
):
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if inputs_embeds is None:
inputs_embeds = self.prepare_inputs_embeds(
input_ids=input_ids,
images=images,
images_seq_mask=images_seq_mask,
images_spatial_crop=images_spatial_crop,
)
if attention_mask is not None:
attention_mask = attention_mask.to(inputs_embeds.device)
# print(inputs_embeds.shape)
outputs = self.language.forward(
input_ids=None,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position
)
self._clear_cuda_cache()
return outputs
def _clear_cuda_cache(self):
"""clear CUDA memory cache"""
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
def _move_past_key_values_to_cpu(self, past_key_values):
# print(f"past_key_values -> cpu")
if past_key_values is None:
return None
return tuple(tuple(t.cpu() for t in layer) for layer in past_key_values)
def _move_past_key_values_to_gpu(self, past_key_values, device="cuda:0"):
# print(f"past_key_values -> gpu")
if past_key_values is None:
return None
return tuple(tuple(t.to(device) for t in layer) for layer in past_key_values)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
inputs_embeds=None,
images: Optional[torch.FloatTensor] = None,
images_seq_mask: Optional[torch.LongTensor] = None,
images_spatial_crop: Optional[torch.LongTensor] = None,
attention_mask=None,
cache_position=None,
pixel_values=None,
image_sizes=None,
num_logits_to_keep=None,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
r"""
Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling. Beam-search decoding
is controlled by the `num_beams` parameter and the `num_return_sequences` parameter.
Parameters:
- `inputs` (optional) -- `torch.LongTensor` of shape `(batch, sequence_length)`:
The sequence used as a prompt for the generation. If `None`, generate for the model's prompt.
- `generation_config` (optional) -- `GenerationConfig`:
The generation config of the model.
- `logits_processor` (optional) -- `LogitsProcessorList`:
A list of instances of :class:`~transform
"""
return self.language.generate(
inputs=inputs,
generation_config=generation_config,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
synced_gpus=synced_gpus,
assistant_model=assistant_model,
streamer=streamer,
negative_prompt_ids=negative_prompt_ids,
negative_prompt_attention_mask=negative_prompt_attention_mask,
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
model_inputs = self.language.prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
num_logits_to_keep=num_logits_to_keep,
**kwargs,
)
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
# Otherwise we need pixel values to be passed to model
cache_position = model_inputs["cache_position"]
if cache_position[0] == 0:
model_inputs["images"] = images
model_inputs["images_seq_mask"] = images_seq_mask
model_inputs["images_spatial_crop"] = images_spatial_crop
return model_inputs
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (
tuple(
past_state.index_select(0, beam_idx.to(past_state.device))
for past_state in layer_past
),
)
return reordered_past
AutoConfig.register("vision", VisionEncoderConfig)
AutoConfig.register("mlp_projector", MlpProjectorConfig)

View File

@ -559,7 +559,7 @@ class DeepseekVLV2Processor(ProcessorMixin):
for j in range(0, best_width, self.image_size):
images_list.append(
self.image_transform(local_view.crop((j, i, j + self.image_size, i + self.image_size))))
"""record height / width crop num"""
num_width_tiles, num_height_tiles = best_width // self.image_size, best_height // self.image_size
images_spatial_crop.append([num_width_tiles, num_height_tiles])

View File

@ -21,7 +21,7 @@
import gradio as gr
title = """<h1 align="left" style="min-width:200px; margin-top:0;">Chat with DeepSeek-VL2 </h1>"""
description_top = """"""
description_top = """Special Tokens: `<image>`, Visual Grounding: `<|ref|>{query}<|/ref|>`, Grounding Conversation: `<|grounding|>{question}`"""
description = """"""
CONCURRENT_COUNT = 1
MAX_EVENTS = 10
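For reference, the prompt patterns listed in `description_top` correspond to user message contents like the following (queries taken from the README and the demo's `examples_list`; the variable names are only illustrative):
```python
# plain image + question
plain_vqa = "<image>\nDescribe each stage of this image in detail"

# visual grounding: wrap the referred object in <|ref|> ... <|/ref|>
visual_grounding = "<image>\n<|ref|>The giraffe at the back.<|/ref|>"

# grounded conversation: prefix the question with <|grounding|>
grounded_conversation = "<image>\n<|grounding|>I want to throw out the trash now, what should I do?"
```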

View File

@ -242,7 +242,9 @@ def pil_to_base64(
alt: str = "user upload image",
resize: bool = True,
max_size: int = MAX_IMAGE_SIZE,
min_size: int = MIN_IMAGE_SIZE
min_size: int = MIN_IMAGE_SIZE,
format: str = "JPEG",
quality: int = 95
) -> str:
if resize:
@ -258,15 +260,16 @@ def pil_to_base64(
image = image.resize((W, H))
buffered = io.BytesIO()
image.save(buffered, format="JPEG")
image.save(buffered, format=format, quality=quality)
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="{alt}" />'
return img_str
def parse_ref_bbox(response, image):
def parse_ref_bbox(response, image: Image.Image):
try:
image = image.copy()
image_h, image_w = image.size
draw = ImageDraw.Draw(image)
@ -275,7 +278,7 @@ def parse_ref_bbox(response, image):
assert len(ref) == len(bbox)
if len(ref) == 0:
return
return None
boxes, labels = [], []
for box, label in zip(bbox, ref):
@ -301,9 +304,30 @@ def parse_ref_bbox(response, image):
text_x = box[0]
text_y = box[1] - 20
text_color = box_color
font = ImageFont.truetype('./deepseek_vl2/serve/assets/simsun.ttc', size=20)
font = ImageFont.truetype("deepseek_vl2/serve/assets/simsun.ttc", size=20)
draw.text((text_x, text_y), label, font=font, fill=text_color)
# print(f"boxes = {boxes}, labels = {labels}, re-render = {image}")
return image
except:
return
return None
def display_example(image_list):
images_html = ""
for i, img_path in enumerate(image_list):
image = Image.open(img_path)
buffered = io.BytesIO()
image.save(buffered, format="PNG", quality=100)
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="{img_path}" style="height:80px; margin-right: 10px;" />'
images_html += img_str
result_html = f"""
<div style="display: flex; align-items: center; margin-bottom: 10px;">
<div style="flex: 1; margin-right: 10px;">{images_html}</div>
</div>
"""
return result_html


View File

@ -47,24 +47,27 @@ def load_model(model_path, dtype=torch.bfloat16):
def convert_conversation_to_prompts(conversation: Conversation):
conv_prompts = []
pil_images = []
last_image = None
messages = conversation.messages
for i in range(0, len(messages), 2):
if isinstance(messages[i][1], tuple):
text, images = messages[i][1]
last_image = images[-1]
else:
text, images = messages[i][1], []
pil_images.extend(images)
prompt = {
"role": messages[i][0],
"content": text,
"images": images
}
response = {"role": messages[i + 1][0], "content": messages[i + 1][1]}
conv_prompts.extend([prompt, response])
return conv_prompts, pil_images
return conv_prompts, last_image
class StoppingCriteriaSub(StoppingCriteria):
@ -86,8 +89,7 @@ class StoppingCriteriaSub(StoppingCriteria):
@torch.inference_mode()
def deepseek_generate(
conv_prompts: list,
pil_images: list,
conversations: list,
vl_gpt: torch.nn.Module,
vl_chat_processor: DeepseekVLV2Processor,
tokenizer: transformers.PreTrainedTokenizer,
@ -95,11 +97,17 @@ def deepseek_generate(
max_length: int = 256,
temperature: float = 1.0,
top_p: float = 1.0,
repetition_penalty=1.1,
repetition_penalty: float = 1.1,
chunk_size: int = -1
):
pil_images = []
for message in conversations:
if "images" not in message:
continue
pil_images.extend(message["images"])
prepare_inputs = vl_chat_processor.__call__(
conversations=conv_prompts,
conversations=conversations,
images=pil_images,
inference_mode=True,
force_batchify=True,
@ -110,11 +118,12 @@ def deepseek_generate(
vl_gpt,
tokenizer,
prepare_inputs,
max_length,
temperature,
repetition_penalty,
top_p,
stop_words,
max_gen_len=max_length,
temperature=temperature,
repetition_penalty=repetition_penalty,
top_p=top_p,
stop_words=stop_words,
chunk_size=chunk_size
)
@ -128,11 +137,10 @@ def generate(
repetition_penalty=1.1,
top_p: float = 0.95,
stop_words: List[str] = [],
chunk_size: int = -1
):
"""Stream the text output from the multimodality model with prompt and image inputs."""
inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
streamer = TextIteratorStreamer(tokenizer)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
stop_words_ids = [
torch.tensor(tokenizer.encode(stop_word)) for stop_word in stop_words
@ -141,9 +149,27 @@ def generate(
[StoppingCriteriaSub(stops=stop_words_ids)]
)
if chunk_size != -1:
inputs_embeds, past_key_values = vl_gpt.incremental_prefilling(
input_ids=prepare_inputs.input_ids,
images=prepare_inputs.images,
images_seq_mask=prepare_inputs.images_seq_mask,
images_spatial_crop=prepare_inputs.images_spatial_crop,
attention_mask=prepare_inputs.attention_mask,
chunk_size=chunk_size
)
else:
inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
past_key_values = None
generation_config = dict(
inputs_embeds=inputs_embeds,
input_ids=prepare_inputs.input_ids,
images=prepare_inputs.images,
images_seq_mask=prepare_inputs.images_seq_mask,
images_spatial_crop=prepare_inputs.images_spatial_crop,
attention_mask=prepare_inputs.attention_mask,
past_key_values=past_key_values,
pad_token_id=tokenizer.eos_token_id,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,

File diff suppressed because one or more lines are too long

(Binary image files changed; diffs not shown. New files include images/icl_vg_2.jpeg, images/monday.jpg, images/multi_image_1.jpeg, images/qr.jpeg, images/sample.jpg, images/vg_2.jpeg, and images/vqa_1.jpg; several existing images were replaced or removed.)

View File

@ -21,7 +21,6 @@ from argparse import ArgumentParser
from typing import List, Dict
import torch
from transformers import AutoModelForCausalLM
import PIL.Image
from deepseek_vl2.models import DeepseekVLV2ForCausalLM, DeepseekVLV2Processor
@ -81,14 +80,37 @@ def main(args):
conversation = [
{
"role": "<|User|>",
"content": "<image>\n<|ref|>The giraffe at the back.<|/ref|>.",
"images": ["./images/visual_grounding.jpeg"],
"content": "<image>\n<image>\n<|grounding|>In the first image, an object within the red rectangle is marked. Locate the object of the same category in the second image.",
"images": [
"images/incontext_visual_grounding_1.jpeg",
"images/icl_vg_2.jpeg"
],
},
{"role": "<|Assistant|>", "content": ""},
]
# conversation = [
# {
# "role": "<|User|>",
# "content": "<image>\n<|ref|>The giraffe at the back.<|/ref|>.",
# "images": ["./images/visual_grounding_1.jpeg"],
# },
# {"role": "<|Assistant|>", "content": ""},
# ]
# load images and prepare for inputs
pil_images = load_pil_images(conversation)
print(f"len(pil_images) = {len(pil_images)}")
# input_ids = batched_input_ids,
# attention_mask = batched_attention_mask,
# labels = batched_labels,
# images_tiles = batched_images,
# images_seq_mask = batched_images_seq_mask,
# images_spatial_crop = batched_images_spatial_crop,
# sft_format = batched_sft_format,
# seq_lens = seq_lens
prepare_inputs = vl_chat_processor.__call__(
conversations=conversation,
images=pil_images,
@ -96,34 +118,59 @@ def main(args):
system_prompt=""
).to(vl_gpt.device, dtype=dtype)
# for key in prepare_inputs.keys():
# value = prepare_inputs[key]
# if isinstance(value, list):
# print(key, len(value), type(value))
# elif isinstance(value, torch.Tensor):
# print(key, value.shape, type(value))
with torch.no_grad():
# run image encoder to get the image embeddings
inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
# inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
# incremental_prefilling when using 40G GPU for vl2-small
inputs_embeds, past_key_values = vl_gpt.incremental_prefilling(
input_ids=prepare_inputs.input_ids,
images=prepare_inputs.images,
images_seq_mask=prepare_inputs.images_seq_mask,
images_spatial_crop=prepare_inputs.images_spatial_crop,
attention_mask=prepare_inputs.attention_mask,
chunk_size=args.chunk_size
)
# run the model to get the response
outputs = vl_gpt.generate(
# inputs_embeds=inputs_embeds[:, -1:],
# input_ids=prepare_inputs.input_ids[:, -1:],
inputs_embeds=inputs_embeds,
input_ids=prepare_inputs.input_ids,
images=prepare_inputs.images,
images_seq_mask=prepare_inputs.images_seq_mask,
images_spatial_crop=prepare_inputs.images_spatial_crop,
attention_mask=prepare_inputs.attention_mask,
past_key_values=past_key_values,
pad_token_id=tokenizer.eos_token_id,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
max_new_tokens=1024,
max_new_tokens=512,
do_sample=False,
# do_sample=False,
# repetition_penalty=1.1,
# do_sample=True,
# temperature=1.0,
# top_p=0.9,
# repetition_penalty=1.1,
do_sample=True,
temperature=0.4,
top_p=0.9,
repetition_penalty=1.1,
use_cache=True,
)
answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=False)
answer = tokenizer.decode(outputs[0][len(prepare_inputs.input_ids[0]):].cpu().tolist(), skip_special_tokens=False)
print(f"{prepare_inputs['sft_format'][0]}", answer)
vg_image = parse_ref_bbox(answer, image=pil_images[0])
vg_image = parse_ref_bbox(answer, image=pil_images[-1])
if vg_image is not None:
vg_image.save("./vg.jpg", format="JPEG", quality=85)
@ -131,7 +178,8 @@ def main(args):
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--model_path", type=str, required=True,
default="deepseek-ai/deepseek-vl2-27b-moe",
default="deepseek-ai/deepseek-vl2",
help="model name or local path to the model")
parser.add_argument("--chunk_size", type=int, default=512, help="chunk size for the model for prefiiling")
args = parser.parse_args()
main(args)

674
web_demo.py Executable file
View File

@ -0,0 +1,674 @@
# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# -*- coding:utf-8 -*-
from argparse import ArgumentParser
import io
import sys
import base64
from PIL import Image
import gradio as gr
import torch
from deepseek_vl2.serve.app_modules.gradio_utils import (
cancel_outputing,
delete_last_conversation,
reset_state,
reset_textbox,
wrap_gen_fn,
)
from deepseek_vl2.serve.app_modules.overwrites import reload_javascript
from deepseek_vl2.serve.app_modules.presets import (
CONCURRENT_COUNT,
MAX_EVENTS,
description,
description_top,
title
)
from deepseek_vl2.serve.app_modules.utils import (
configure_logger,
is_variable_assigned,
strip_stop_words,
parse_ref_bbox,
pil_to_base64,
display_example
)
from deepseek_vl2.serve.inference import (
convert_conversation_to_prompts,
deepseek_generate,
load_model,
)
from deepseek_vl2.models.conversation import SeparatorStyle
logger = configure_logger()
MODELS = [
"DeepSeek-VL2-tiny",
"DeepSeek-VL2-small",
"DeepSeek-VL2",
"deepseek-ai/deepseek-vl2-tiny",
"deepseek-ai/deepseek-vl2-small",
"deepseek-ai/deepseek-vl2",
]
DEPLOY_MODELS = dict()
IMAGE_TOKEN = "<image>"
examples_list = [
# visual grounding - 1
[
["images/visual_grounding_1.jpeg"],
"<|ref|>The giraffe at the back.<|/ref|>",
],
# visual grounding - 2
[
["images/visual_grounding_2.jpg"],
"找到<|ref|>淡定姐<|/ref|>",
],
# visual grounding - 3
[
["images/visual_grounding_3.png"],
"Find all the <|ref|>Watermelon slices<|/ref|>",
],
# grounding conversation
[
["images/grounding_conversation_1.jpeg"],
"<|grounding|>I want to throw out the trash now, what should I do?",
],
# in-context visual grounding
[
[
"images/incontext_visual_grounding_1.jpeg",
"images/icl_vg_2.jpeg"
],
"<|grounding|>In the first image, an object within the red rectangle is marked. Locate the object of the same category in the second image."
],
# vqa
[
["images/vqa_1.jpg"],
"Describe each stage of this image in detail",
],
# multi-images
[
[
"images/multi_image_1.jpeg",
"images/mi_2.jpeg",
"images/mi_3.jpeg"
],
"能帮我用这几个食材做一道菜吗?",
]
]
def fetch_model(model_name: str, dtype=torch.bfloat16):
global args, DEPLOY_MODELS
if args.local_path:
model_path = args.local_path
else:
model_path = model_name
if model_name in DEPLOY_MODELS:
model_info = DEPLOY_MODELS[model_name]
print(f"{model_name} has been loaded.")
else:
print(f"{model_name} is loading...")
DEPLOY_MODELS[model_name] = load_model(model_path, dtype=dtype)
print(f"Load {model_name} successfully...")
model_info = DEPLOY_MODELS[model_name]
return model_info
def generate_prompt_with_history(
text, images, history, vl_chat_processor, tokenizer, max_length=2048
):
"""
Generate a prompt with history for the deepseek application.
Args:
text (str): The text prompt.
images (list[PIL.Image.Image]): The image prompt.
history (list): List of previous conversation messages.
tokenizer: The tokenizer used for encoding the prompt.
max_length (int): The maximum length of the prompt.
Returns:
tuple: A tuple containing the generated prompt, image list, conversation, and conversation copy. If the prompt could not be generated within the max_length limit, returns None.
"""
global IMAGE_TOKEN
sft_format = "deepseek"
user_role_ind = 0
bot_role_ind = 1
# Initialize conversation
conversation = vl_chat_processor.new_chat_template()
if history:
conversation.messages = history
if images is not None and len(images) > 0:
num_image_tags = text.count(IMAGE_TOKEN)
num_images = len(images)
if num_images > num_image_tags:
pad_image_tags = num_images - num_image_tags
image_tokens = "\n".join([IMAGE_TOKEN] * pad_image_tags)
# append the <image> in a new line after the text prompt
text = image_tokens + "\n" + text
elif num_images < num_image_tags:
remove_image_tags = num_image_tags - num_images
text = text.replace(IMAGE_TOKEN, "", remove_image_tags)
# print(f"prompt = {text}, len(images) = {len(images)}")
text = (text, images)
conversation.append_message(conversation.roles[user_role_ind], text)
conversation.append_message(conversation.roles[bot_role_ind], "")
# Create a copy of the conversation to avoid history truncation in the UI
conversation_copy = conversation.copy()
logger.info("=" * 80)
logger.info(get_prompt(conversation))
rounds = len(conversation.messages) // 2
for _ in range(rounds):
current_prompt = get_prompt(conversation)
current_prompt = (
current_prompt.replace("</s>", "")
if sft_format == "deepseek"
else current_prompt
)
if torch.tensor(tokenizer.encode(current_prompt)).size(-1) <= max_length:
return conversation_copy
if len(conversation.messages) % 2 != 0:
gr.Error("The messages between user and assistant are not paired.")
return
try:
for _ in range(2): # pop out two messages in a row
conversation.messages.pop(0)
except IndexError:
gr.Error("Input text processing failed, unable to respond in this round.")
return None
gr.Error("Prompt could not be generated within max_length limit.")
return None
def to_gradio_chatbot(conv):
"""Convert the conversation to gradio chatbot format."""
ret = []
for i, (role, msg) in enumerate(conv.messages[conv.offset:]):
if i % 2 == 0:
if type(msg) is tuple:
msg, images = msg
if isinstance(images, list):
for j, image in enumerate(images):
if isinstance(image, str):
with open(image, "rb") as f:
data = f.read()
img_b64_str = base64.b64encode(data).decode()
image_str = (f'<img src="data:image/png;base64,{img_b64_str}" '
f'alt="user upload image" style="max-width: 300px; height: auto;" />')
else:
image_str = pil_to_base64(image, f"user upload image_{j}", max_size=800, min_size=400)
# replace the <image> tag in the message
msg = msg.replace(IMAGE_TOKEN, image_str, 1)
else:
pass
ret.append([msg, None])
else:
ret[-1][-1] = msg
return ret
def to_gradio_history(conv):
"""Convert the conversation to gradio history state."""
return conv.messages[conv.offset:]
def get_prompt(conv) -> str:
"""Get the prompt for generation."""
system_prompt = conv.system_template.format(system_message=conv.system_message)
if conv.sep_style == SeparatorStyle.DeepSeek:
seps = [conv.sep, conv.sep2]
if system_prompt == "" or system_prompt is None:
ret = ""
else:
ret = system_prompt + seps[0]
for i, (role, message) in enumerate(conv.messages):
if message:
if type(message) is tuple: # multimodal message
message, _ = message
ret += role + ": " + message + seps[i % 2]
else:
ret += role + ":"
return ret
else:
return conv.get_prompt()
def transfer_input(input_text, input_images):
print("transferring input text and input image")
return (
input_text,
input_images,
gr.update(value=""),
gr.update(value=None),
gr.Button(visible=True)
)
@wrap_gen_fn
def predict(
text,
images,
chatbot,
history,
top_p,
temperature,
repetition_penalty,
max_length_tokens,
max_context_length_tokens,
model_select_dropdown,
):
"""
Function to predict the response based on the user's input and selected model.
Parameters:
user_text (str): The input text from the user.
user_image (str): The input image from the user.
chatbot (str): The chatbot's name.
history (str): The history of the chat.
top_p (float): The top-p parameter for the model.
temperature (float): The temperature parameter for the model.
max_length_tokens (int): The maximum length of tokens for the model.
max_context_length_tokens (int): The maximum length of context tokens for the model.
model_select_dropdown (str): The selected model from the dropdown.
Returns:
generator: A generator that yields the chatbot outputs, history, and status.
"""
print("running the prediction function")
try:
tokenizer, vl_gpt, vl_chat_processor = fetch_model(model_select_dropdown)
if text == "":
yield chatbot, history, "Empty context."
return
except KeyError:
yield [[text, "No Model Found"]], [], "No Model Found"
return
if images is None:
images = []
# load images
pil_images = []
for img_or_file in images:
try:
# load as pil image
if isinstance(img_or_file, Image.Image):
pil_images.append(img_or_file)
else:
image = Image.open(img_or_file.name).convert("RGB")
pil_images.append(image)
except Exception as e:
print(f"Error loading image: {e}")
conversation = generate_prompt_with_history(
text,
pil_images,
history,
vl_chat_processor,
tokenizer,
max_length=max_context_length_tokens,
)
all_conv, last_image = convert_conversation_to_prompts(conversation)
stop_words = conversation.stop_str
gradio_chatbot_output = to_gradio_chatbot(conversation)
full_response = ""
with torch.no_grad():
for x in deepseek_generate(
conversations=all_conv,
vl_gpt=vl_gpt,
vl_chat_processor=vl_chat_processor,
tokenizer=tokenizer,
stop_words=stop_words,
max_length=max_length_tokens,
temperature=temperature,
repetition_penalty=repetition_penalty,
top_p=top_p,
chunk_size=args.chunk_size
):
full_response += x
response = strip_stop_words(full_response, stop_words)
conversation.update_last_message(response)
gradio_chatbot_output[-1][1] = response
# sys.stdout.write(x)
# sys.stdout.flush()
yield gradio_chatbot_output, to_gradio_history(conversation), "Generating..."
if last_image is not None:
# TODO always render the last image's visual grounding image
vg_image = parse_ref_bbox(response, last_image)
if vg_image is not None:
vg_base64 = pil_to_base64(vg_image, f"vg", max_size=800, min_size=400)
gradio_chatbot_output[-1][1] += vg_base64
yield gradio_chatbot_output, to_gradio_history(conversation), "Generating..."
print("flushed result to gradio")
torch.cuda.empty_cache()
if is_variable_assigned("x"):
print(f"{model_select_dropdown}:\n{text}\n{'-' * 80}\n{x}\n{'=' * 80}")
print(
f"temperature: {temperature}, "
f"top_p: {top_p}, "
f"repetition_penalty: {repetition_penalty}, "
f"max_length_tokens: {max_length_tokens}"
)
yield gradio_chatbot_output, to_gradio_history(conversation), "Generate: Success"
# @wrap_gen_fn
def retry(
text,
images,
chatbot,
history,
top_p,
temperature,
repetition_penalty,
max_length_tokens,
max_context_length_tokens,
model_select_dropdown,
):
if len(history) == 0:
yield (chatbot, history, "Empty context")
return
chatbot.pop()
history.pop()
text = history.pop()[-1]
if type(text) is tuple:
text, image = text
yield from predict(
text,
images,
chatbot,
history,
top_p,
temperature,
repetition_penalty,
max_length_tokens,
max_context_length_tokens,
model_select_dropdown,
args.chunk_size
)
def preview_images(files):
if files is None:
return []
image_paths = []
for file in files:
# use file.name to get the file path
# image = Image.open(file.name)
image_paths.append(file.name)
return image_paths  # return all image paths for preview
def build_demo(args):
# fetch model
if not args.lazy_load:
fetch_model(args.model_name)
with open("deepseek_vl2/serve/assets/custom.css", "r", encoding="utf-8") as f:
customCSS = f.read()
with gr.Blocks(theme=gr.themes.Soft()) as demo:
history = gr.State([])
input_text = gr.State()
input_images = gr.State()
with gr.Row():
gr.HTML(title)
status_display = gr.Markdown("Success", elem_id="status_display")
gr.Markdown(description_top)
with gr.Row(equal_height=True):
with gr.Column(scale=4):
with gr.Row():
chatbot = gr.Chatbot(
elem_id="deepseek_chatbot",
show_share_button=True,
bubble_full_width=False,
height=600,
)
with gr.Row():
with gr.Column(scale=4):
text_box = gr.Textbox(
show_label=False, placeholder="Enter text", container=False
)
with gr.Column(
min_width=70,
):
submitBtn = gr.Button("Send")
with gr.Column(
min_width=70,
):
cancelBtn = gr.Button("Stop")
with gr.Row():
emptyBtn = gr.Button(
"🧹 New Conversation",
)
retryBtn = gr.Button("🔄 Regenerate")
delLastBtn = gr.Button("🗑️ Remove Last Turn")
with gr.Column():
upload_images = gr.Files(file_types=["image"], show_label=True)
gallery = gr.Gallery(columns=[3], height="200px", show_label=True)
upload_images.change(preview_images, inputs=upload_images, outputs=gallery)
with gr.Tab(label="Parameter Setting") as parameter_row:
top_p = gr.Slider(
minimum=-0,
maximum=1.0,
value=0.9,
step=0.05,
interactive=True,
label="Top-p",
)
temperature = gr.Slider(
minimum=0,
maximum=1.0,
value=0.1,
step=0.1,
interactive=True,
label="Temperature",
)
repetition_penalty = gr.Slider(
minimum=0.0,
maximum=2.0,
value=1.1,
step=0.1,
interactive=True,
label="Repetition penalty",
)
max_length_tokens = gr.Slider(
minimum=0,
maximum=4096,
value=2048,
step=8,
interactive=True,
label="Max Generation Tokens",
)
max_context_length_tokens = gr.Slider(
minimum=0,
maximum=8192,
value=4096,
step=128,
interactive=True,
label="Max History Tokens",
)
model_select_dropdown = gr.Dropdown(
label="Select Models",
choices=[args.model_name],
multiselect=False,
value=args.model_name,
interactive=True,
)
# show images, but not visible
show_images = gr.HTML(visible=False)
# show_images = gr.Image(type="pil", interactive=False, visible=False)
def format_examples(examples_list):
examples = []
for images, texts in examples_list:
examples.append([images, display_example(images), texts])
return examples
gr.Examples(
examples=format_examples(examples_list),
inputs=[upload_images, show_images, text_box],
)
gr.Markdown(description)
input_widgets = [
input_text,
input_images,
chatbot,
history,
top_p,
temperature,
repetition_penalty,
max_length_tokens,
max_context_length_tokens,
model_select_dropdown,
]
output_widgets = [chatbot, history, status_display]
transfer_input_args = dict(
fn=transfer_input,
inputs=[text_box, upload_images],
outputs=[input_text, input_images, text_box, upload_images, submitBtn],
show_progress=True,
)
predict_args = dict(
fn=predict,
inputs=input_widgets,
outputs=output_widgets,
show_progress=True,
)
retry_args = dict(
fn=retry,
inputs=input_widgets,
outputs=output_widgets,
show_progress=True,
)
reset_args = dict(
fn=reset_textbox, inputs=[], outputs=[text_box, status_display]
)
predict_events = [
text_box.submit(**transfer_input_args).then(**predict_args),
submitBtn.click(**transfer_input_args).then(**predict_args),
]
emptyBtn.click(reset_state, outputs=output_widgets, show_progress=True)
emptyBtn.click(**reset_args)
retryBtn.click(**retry_args)
delLastBtn.click(
delete_last_conversation,
[chatbot, history],
output_widgets,
show_progress=True,
)
cancelBtn.click(cancel_outputing, [], [status_display], cancels=predict_events)
return demo
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--model_name", type=str, required=True, choices=MODELS, help="model name")
parser.add_argument("--local_path", type=str, default="", help="huggingface ckpt, optional")
parser.add_argument("--ip", type=str, default="0.0.0.0", help="ip address")
parser.add_argument("--port", type=int, default=37913, help="port number")
parser.add_argument("--root_path", type=str, default="", help="root path")
parser.add_argument("--lazy_load", action='store_true')
parser.add_argument("--chunk_size", type=int, default=-1,
help="chunk size for the model for prefiiling. "
"When using 40G gpu for vl2-small, set a chunk_size for incremental_prefilling."
"Otherwise, default value is -1, which means we do not use incremental_prefilling.")
args = parser.parse_args()
demo = build_demo(args)
demo.title = "DeepSeek-VL2 Chatbot"
reload_javascript()
demo.queue(concurrency_count=CONCURRENT_COUNT, max_size=MAX_EVENTS).launch(
# share=False,
share=True,
favicon_path="deepseek_vl2/serve/assets/favicon.ico",
inbrowser=False,
server_name=args.ip,
server_port=args.port,
root_path=args.root_path
)