diff --git a/.gitattributes b/.gitattributes
index 1d0afc6..f31a678 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -6,3 +6,4 @@
 *.jpeg binary
 *.gif binary
 *.pdf binary
+*.ttc binary
diff --git a/README.md b/README.md
index bc7e2f6..83c5a77 100644
--- a/README.md
+++ b/README.md
@@ -69,6 +69,7 @@ Zhiyu Wu*, Xiaokang Chen*, Zizheng Pan*, Xingchao Liu*, Wen Liu**, Damai Dai, Hu
 ![](./images/vl2_teaser.jpeg)
 ## 2. Release
+✅ 2024-12-25: Gradio Demo Example, Incremental Prefilling and VLMEvalKit Support.
 ✅ 2024-12-13: DeepSeek-VL2 family released, including DeepSeek-VL2-tiny, DeepSeek-VL2-small, DeepSeek-VL2.
 ## 3. Model Download
@@ -96,7 +97,9 @@ On the basis of `Python >= 3.8` environment, install the necessary dependencies
 pip install -e .
 ```
-### Simple Inference Example
+### Simple Inference Example with One Image
+
+**Note: You may need 80GB GPU memory to run this script with deepseek-vl2-small, and even more for deepseek-vl2.**
 
 ```python
 import torch
 from transformers import AutoModelForCausalLM
 
 from deepseek_vl2.models import DeepseekVLV2Processor, DeepseekVLV2ForCausalLM
 from deepseek_vl2.utils.io import load_pil_images
 
 # specify the path to the model
-model_path = "deepseek-ai/deepseek-vl2-small"
+model_path = "deepseek-ai/deepseek-vl2-tiny"
 vl_chat_processor: DeepseekVLV2Processor = DeepseekVLV2Processor.from_pretrained(model_path)
 tokenizer = vl_chat_processor.tokenizer
@@ -119,23 +122,78 @@
 conversation = [
     {
         "role": "<|User|>",
         "content": "<image>\n<|ref|>The giraffe at the back.<|/ref|>.",
-        "images": ["./images/visual_grounding.jpeg"],
+        "images": ["./images/visual_grounding_1.jpeg"],
     },
     {"role": "<|Assistant|>", "content": ""},
 ]
+# load images and prepare for inputs
+pil_images = load_pil_images(conversation)
+prepare_inputs = vl_chat_processor(
+    conversations=conversation,
+    images=pil_images,
+    force_batchify=True,
+    system_prompt=""
+).to(vl_gpt.device)
+
+# run image encoder to get the image embeddings
+inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
+
+# run the model to get the response
+outputs = vl_gpt.language.generate(
+    inputs_embeds=inputs_embeds,
+    attention_mask=prepare_inputs.attention_mask,
+    pad_token_id=tokenizer.eos_token_id,
+    bos_token_id=tokenizer.bos_token_id,
+    eos_token_id=tokenizer.eos_token_id,
+    max_new_tokens=512,
+    do_sample=False,
+    use_cache=True
+)
+
+answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=False)
+print(f"{prepare_inputs['sft_format'][0]}", answer)
+```
+
+And the output is something like:
+```
+<|User|>: <image>
+<|ref|>The giraffe at the back.<|/ref|>.
+
+<|Assistant|>: <|ref|>The giraffe at the back.<|/ref|><|det|>[[580, 270, 999, 900]]<|/det|><|end▁of▁sentence|>
+```
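+
+The `[[x1, y1, x2, y2]]` values inside `<|det|>...<|/det|>` are not pixel coordinates: judging by the sample output above and by the rendering done in `deepseek_vl2/serve/app_modules/utils.py`'s `parse_ref_bbox`, they are normalized to a 0-999 grid and must be rescaled to the image size before drawing. A minimal sketch of that rescaling (`det_to_pixel_boxes` is a hypothetical helper, not part of this repo):
+
+```python
+import re
+import PIL.Image
+
+def det_to_pixel_boxes(response: str, image: PIL.Image.Image):
+    """Rescale 0-999 normalized <|det|> boxes to pixel coordinates."""
+    w, h = image.size
+    boxes = []
+    for x1, y1, x2, y2 in re.findall(r"\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)\]", response):
+        boxes.append((int(x1) * w // 999, int(y1) * h // 999,
+                      int(x2) * w // 999, int(y2) * h // 999))
+    return boxes
+```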
+
+### Simple Inference Example with Multiple Images
+
+**Note: You may need 80GB GPU memory to run this script with deepseek-vl2-small, and even more for deepseek-vl2.**
+
+```python
+import torch
+from transformers import AutoModelForCausalLM
+
+from deepseek_vl2.models import DeepseekVLV2Processor, DeepseekVLV2ForCausalLM
+from deepseek_vl2.utils.io import load_pil_images
+
+
+# specify the path to the model
+model_path = "deepseek-ai/deepseek-vl2-tiny"
+vl_chat_processor: DeepseekVLV2Processor = DeepseekVLV2Processor.from_pretrained(model_path)
+tokenizer = vl_chat_processor.tokenizer
+
+vl_gpt: DeepseekVLV2ForCausalLM = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
+vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()
 
 # multiple images/interleaved image-text
-conversation_multi_images = [
+conversation = [
     {
         "role": "<|User|>",
         "content": "This is image_1: <image>\n"
                    "This is image_2: <image>\n"
-                   "This is image_3: <image>\n If I am a vegetarian, what can I cook with these ingredients?",
+                   "This is image_3: <image>\n Can you tell me what is in the images?",
         "images": [
-            "images/multi_image_1.png",
-            "images/multi_image_2.jpg",
-            "images/multi_image_3.jpg",
+            "images/multi_image_1.jpeg",
+            "images/multi_image_2.jpeg",
+            "images/multi_image_3.jpeg",
         ],
     },
     {"role": "<|Assistant|>", "content": ""}
 ]
@@ -169,12 +227,151 @@ answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=False)
 print(f"{prepare_inputs['sft_format'][0]}", answer)
 ```
-### Gradio Demo (TODO)
+And the output is something like:
+```
+<|User|>: This is image_1: <image>
+This is image_2: <image>
+This is image_3: <image>
+ Can you tell me what is in the images?
+
+<|Assistant|>: The images show three different types of vegetables. Image_1 features carrots, which are orange with green tops. Image_2 displays corn cobs, which are yellow with green husks. Image_3 contains raw pork ribs, which are pinkish-red with some marbling.<|end▁of▁sentence|>
+```
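+
+Each `<image>` placeholder in `content` is paired positionally with one entry in `images`; if the counts drift apart, the processor cannot align the image embeddings with the text (the Gradio demo below pads missing tags automatically, but this raw API does not). A small sanity check you can run before inference; `check_image_tags` is a hypothetical helper, not part of this repo:
+
+```python
+def check_image_tags(conversation, image_token="<image>"):
+    """Assert every message pairs its <image> tags 1:1 with its image paths."""
+    for message in conversation:
+        num_tags = message.get("content", "").count(image_token)
+        num_images = len(message.get("images", []))
+        if num_tags != num_images:
+            raise ValueError(
+                f"{message.get('role')}: {num_tags} {image_token} tag(s) "
+                f"for {num_images} image(s)"
+            )
+
+check_image_tags(conversation)  # fails fast, before any GPU work
+```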
+
+### Simple Inference Example with Incremental Prefilling
+
+**Note: We use incremental prefilling so that inference fits within 40GB of GPU memory for deepseek-vl2-small.**
+
+```python
+import torch
+from transformers import AutoModelForCausalLM
+
+from deepseek_vl2.models import DeepseekVLV2Processor, DeepseekVLV2ForCausalLM
+from deepseek_vl2.utils.io import load_pil_images
-### Demo
-This figure present some examples of DeepSeek-VL2.
-![](./images/github_demo.png)
+# specify the path to the model
+model_path = "deepseek-ai/deepseek-vl2-small"
+vl_chat_processor: DeepseekVLV2Processor = DeepseekVLV2Processor.from_pretrained(model_path)
+tokenizer = vl_chat_processor.tokenizer
+
+vl_gpt: DeepseekVLV2ForCausalLM = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
+vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()
+
+# multiple images/interleaved image-text
+conversation = [
+    {
+        "role": "<|User|>",
+        "content": "This is image_1: <image>\n"
+                   "This is image_2: <image>\n"
+                   "This is image_3: <image>\n Can you tell me what is in the images?",
+        "images": [
+            "images/multi_image_1.jpeg",
+            "images/multi_image_2.jpeg",
+            "images/multi_image_3.jpeg",
+        ],
+    },
+    {"role": "<|Assistant|>", "content": ""}
+]
+
+# load images and prepare for inputs
+pil_images = load_pil_images(conversation)
+prepare_inputs = vl_chat_processor(
+    conversations=conversation,
+    images=pil_images,
+    force_batchify=True,
+    system_prompt=""
+).to(vl_gpt.device)
+
+with torch.no_grad():
+    # run image encoder to get the image embeddings
+    inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
+
+    # incremental_prefilling when using 40G GPU for vl2-small
+    inputs_embeds, past_key_values = vl_gpt.incremental_prefilling(
+        input_ids=prepare_inputs.input_ids,
+        images=prepare_inputs.images,
+        images_seq_mask=prepare_inputs.images_seq_mask,
+        images_spatial_crop=prepare_inputs.images_spatial_crop,
+        attention_mask=prepare_inputs.attention_mask,
+        chunk_size=512  # prefilling size
+    )
+
+    # run the model to get the response
+    outputs = vl_gpt.generate(
+        inputs_embeds=inputs_embeds,
+        input_ids=prepare_inputs.input_ids,
+        images=prepare_inputs.images,
+        images_seq_mask=prepare_inputs.images_seq_mask,
+        images_spatial_crop=prepare_inputs.images_spatial_crop,
+        attention_mask=prepare_inputs.attention_mask,
+        past_key_values=past_key_values,
+
+        pad_token_id=tokenizer.eos_token_id,
+        bos_token_id=tokenizer.bos_token_id,
+        eos_token_id=tokenizer.eos_token_id,
+        max_new_tokens=512,
+
+        do_sample=False,
+        use_cache=True,
+    )
+
+    answer = tokenizer.decode(outputs[0][len(prepare_inputs.input_ids[0]):].cpu().tolist(), skip_special_tokens=False)
+
+print(f"{prepare_inputs['sft_format'][0]}", answer)
+```
+
+And the output is something like:
+```
+<|User|>: This is image_1: <image>
+This is image_2: <image>
+This is image_3: <image>
+ Can you tell me what is in the images?
+
+<|Assistant|>: The first image contains carrots. The second image contains corn. The third image contains meat.<|end▁of▁sentence|>
+```
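+
+Incremental prefilling trades speed for peak memory: rather than one forward pass over the whole multimodal prompt (whose activation memory grows with prompt length), the prompt is processed in `chunk_size` slices and only the KV cache is carried from slice to slice; `incremental_prefilling` additionally parks that cache on the CPU between chunks. Stripped of the offloading details, the loop reduces to roughly this sketch (simplified pseudocode, not the repo's exact implementation):
+
+```python
+import torch
+
+@torch.no_grad()
+def chunked_prefill(model, inputs_embeds, attention_mask, chunk_size=512):
+    """Prefill all but the last prompt position chunk by chunk; return the KV cache."""
+    seq_len = inputs_embeds.shape[1]
+    past_key_values = None
+    for start in range(0, seq_len - 1, chunk_size):  # the last token is left for generate()
+        end = min(start + chunk_size, seq_len - 1)
+        out = model(
+            inputs_embeds=inputs_embeds[:, start:end],
+            attention_mask=attention_mask[:, :end],  # covers cache + current chunk
+            past_key_values=past_key_values,
+            use_cache=True,
+        )
+        past_key_values = out.past_key_values  # only the cache crosses chunk boundaries
+    return past_key_values
+```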
+
+### Full Inference Example
+```shell
+# without incremental prefilling
+CUDA_VISIBLE_DEVICES=0 python inference.py --model_path "deepseek-ai/deepseek-vl2"
+
+# with incremental prefilling, when using 40G GPU for vl2-small
+CUDA_VISIBLE_DEVICES=0 python inference.py --model_path "deepseek-ai/deepseek-vl2-small" --chunk_size 512
+
+```
+
+
+### Gradio Demo
+
+* Install the necessary dependencies:
+```shell
+pip install -e .[gradio]
+```
+
+* Then run one of the following commands:
+
+```shell
+# vl2-tiny, 3.37B-MoE in total, activated 1B, can be run on a single GPU < 40GB
+CUDA_VISIBLE_DEVICES=2 python web_demo.py \
+--model_name "deepseek-ai/deepseek-vl2-tiny" \
+--port 37914
+
+
+# vl2-small, 16.1B-MoE in total, activated 2.4B
+# If running on an A100 40GB GPU, set `--chunk_size 512` to enable incremental prefilling; it saves memory but responds more slowly.
+# If running on a GPU with more than 40GB of memory, you can omit `--chunk_size 512` for faster responses.
+CUDA_VISIBLE_DEVICES=2 python web_demo.py \
+--model_name "deepseek-ai/deepseek-vl2-small" \
+--port 37914 \
+--chunk_size 512
+
+# vl2, 27.5B-MoE in total, activated 4.2B
+CUDA_VISIBLE_DEVICES=2 python web_demo.py \
+--model_name "deepseek-ai/deepseek-vl2" \
+--port 37914
+```
+
+* **Important**: This is a basic, unoptimized demo implementation, so responses may be slow. For production environments, consider optimized deployment solutions such as vllm, sglang, or lmdeploy; these provide faster response times and better cost efficiency.
 
 ## 5. License
@@ -184,13 +381,13 @@ This code repository is licensed under [MIT License](./LICENSE-CODE). The use of
 ```
 @misc{wu2024deepseekvl2mixtureofexpertsvisionlanguagemodels,
-      title={DeepSeek-VL2: Mixture-of-Experts Vision-Language Models for Advanced Multimodal Understanding}, 
+      title={DeepSeek-VL2: Mixture-of-Experts Vision-Language Models for Advanced Multimodal Understanding},
       author={Zhiyu Wu and Xiaokang Chen and Zizheng Pan and Xingchao Liu and Wen Liu and Damai Dai and Huazuo Gao and Yiyang Ma and Chengyue Wu and Bingxuan Wang and Zhenda Xie and Yu Wu and Kai Hu and Jiawei Wang and Yaofeng Sun and Yukun Li and Yishi Piao and Kang Guan and Aixin Liu and Xin Xie and Yuxiang You and Kai Dong and Xingkai Yu and Haowei Zhang and Liang Zhao and Yisong Wang and Chong Ruan},
       year={2024},
       eprint={2412.10302},
       archivePrefix={arXiv},
       primaryClass={cs.CV},
-      url={https://arxiv.org/abs/2412.10302}, 
+      url={https://arxiv.org/abs/2412.10302},
 }
 ```
diff --git a/deepseek_vl2/models/modeling_deepseek.py b/deepseek_vl2/models/modeling_deepseek.py
index 6df4fed..1a84a01 100644
--- a/deepseek_vl2/models/modeling_deepseek.py
+++ b/deepseek_vl2/models/modeling_deepseek.py
@@ -17,7 +17,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" PyTorch DeepSeek model."""
+""" PyTorch DeepSeek model, compatible with both DeepSeekV2 and DeepSeekV3."""
 import math
 import warnings
 from typing import List, Optional, Tuple, Union
@@ -27,16 +27,13 @@ import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
 import torch.distributed as dist
+from einops import repeat
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
-from transformers.modeling_attn_mask_utils import (
-    AttentionMaskConverter,
-    _prepare_4d_attention_mask,
-    _prepare_4d_causal_attention_mask,
-)
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.models.llama.modeling_llama import (
     LlamaAttention,
     LlamaFlashAttention2
 )
@@ -63,12 +60,10 @@ from transformers.utils.import_utils import is_torch_fx_available
 from .configuration_deepseek import DeepseekV2Config
 
-
 if is_flash_attn_2_available():
     from flash_attn import flash_attn_func, flash_attn_varlen_func
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
 
-
 # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
 # It means that the function will not be traced through and simply appear as a node in the graph.
if is_torch_fx_available(): @@ -77,7 +72,6 @@ if is_torch_fx_available(): _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) - logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "DeepseekV2Config" @@ -869,17 +863,10 @@ class DeepseekV2Attention(nn.Module): compressed_kv, k_pe = torch.split( compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 ) + compressed_kv = self.kv_a_layernorm(compressed_kv) k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) - kv = ( - self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) - .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - .transpose(1, 2) - ) - k_nope, value_states = torch.split( - kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 - ) - kv_seq_len = value_states.shape[-2] + kv_seq_len = k_pe.shape[-2] if past_key_value is not None: if self.layer_idx is None: raise ValueError( @@ -888,27 +875,23 @@ class DeepseekV2Attention(nn.Module): "with a layer index." ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos, sin = self.rotary_emb(q_pe, seq_len=kv_seq_len) q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) - query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) - query_states[:, :, :, : self.qk_nope_head_dim] = q_nope - query_states[:, :, :, self.qk_nope_head_dim :] = q_pe - - key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) - key_states[:, :, :, : self.qk_nope_head_dim] = k_nope - key_states[:, :, :, self.qk_nope_head_dim :] = k_pe if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, cache_kwargs - ) + compressed_kv = compressed_kv.unsqueeze(1) + k_pe, compressed_kv = past_key_value.update(k_pe, compressed_kv, self.layer_idx, cache_kwargs) + compressed_kv = compressed_kv.squeeze(1) - attn_weights = ( - torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale - ) + kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank) + q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :] + out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :] + q_nope = torch.matmul(q_nope, q_absorb) + attn_weights = (torch.matmul(q_pe, k_pe.mT) + + torch.matmul(q_nope, compressed_kv.unsqueeze(-3).mT)) * self.softmax_scale if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): raise ValueError( f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" @@ -925,11 +908,13 @@ class DeepseekV2Attention(nn.Module): # upcast attention to fp32 attn_weights = nn.functional.softmax( attn_weights, dim=-1, dtype=torch.float32 - ).to(query_states.dtype) + ).to(q_pe.dtype) attn_weights = nn.functional.dropout( attn_weights, p=self.attention_dropout, training=self.training ) - attn_output = torch.matmul(attn_weights, value_states) + attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv) + + attn_output = torch.matmul(attn_output, out_absorb.mT) if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim): raise ValueError( @@ -1034,6 +1019,7 @@ class DeepseekV2FlashAttention2(DeepseekV2Attention): if self.q_head_dim != self.v_head_dim: value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim]) + # TODO: support compressed_kv for kv_cache (instead of key_states, value_states) in 
flash_attention version if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models key_states, value_states = past_key_value.update( @@ -1494,6 +1480,7 @@ class DeepseekV2Model(DeepseekV2PreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = ( output_attentions @@ -1668,17 +1655,18 @@ class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel): output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC ) def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1730,6 +1718,7 @@ class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel): output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position ) hidden_states = outputs[0] @@ -1762,13 +1751,14 @@ class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel): ) def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - **kwargs, + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + **kwargs, ): + past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() @@ -1780,13 +1770,10 @@ class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel): # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as # input) - if ( - attention_mask is not None - and attention_mask.shape[1] > input_ids.shape[1] - ): - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard # input_ids based on the past_length. elif past_length < input_ids.shape[1]: @@ -1795,9 +1782,9 @@ class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel): # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 
if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length ): attention_mask = attention_mask[:, -max_cache_length:] @@ -1807,17 +1794,35 @@ class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel): position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + position_ids = position_ids[:, -input_ids.shape[1]:] + + if self.generation_config.cache_implementation == "static": + # generation with static cache + cache_position = kwargs.get("cache_position", None) + if cache_position is None: + past_length = 0 + else: + past_length = cache_position[-1] + 1 + input_ids = input_ids[:, past_length:] + position_ids = position_ids[:, past_length:] + + # TODO @gante we should only keep a `cache_position` in generate, and do +=1. + # same goes for position ids. Could also help with continued generation. + cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device) # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} else: - model_inputs = {"input_ids": input_ids} + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {"input_ids": input_ids.contiguous()} model_inputs.update( { - "position_ids": position_ids, + "position_ids": position_ids.contiguous(), + "cache_position": cache_position, "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, @@ -1871,17 +1876,17 @@ class DeepseekV2ForSequenceClassification(DeepseekV2PreTrainedModel): @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING) def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, ) -> Union[Tuple, SequenceClassifierOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -1921,7 +1926,7 @@ class DeepseekV2ForSequenceClassification(DeepseekV2PreTrainedModel): else: if input_ids is not None: sequence_lengths = ( - torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 
).to(logits.device) else: sequence_lengths = -1 @@ -1937,7 +1942,7 @@ class DeepseekV2ForSequenceClassification(DeepseekV2PreTrainedModel): if self.num_labels == 1: self.config.problem_type = "regression" elif self.num_labels > 1 and ( - labels.dtype == torch.long or labels.dtype == torch.int + labels.dtype == torch.long or labels.dtype == torch.int ): self.config.problem_type = "single_label_classification" else: diff --git a/deepseek_vl2/models/modeling_deepseek_vl_v2.py b/deepseek_vl2/models/modeling_deepseek_vl_v2.py index f2526c6..fa6e182 100644 --- a/deepseek_vl2/models/modeling_deepseek_vl_v2.py +++ b/deepseek_vl2/models/modeling_deepseek_vl_v2.py @@ -1,4 +1,8 @@ from attrdict import AttrDict +from dataclasses import dataclass +import logging +import gc + from einops import rearrange, repeat from typing import Optional, List, Tuple, Callable, Union @@ -6,19 +10,27 @@ import torch import torch.nn as nn import torch.nn.functional as F +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, +) +from transformers.modeling_outputs import ModelOutput from transformers.configuration_utils import PretrainedConfig from transformers import ( AutoConfig, AutoModelForCausalLM, - PreTrainedModel, GenerationConfig, LogitsProcessorList, StoppingCriteriaList, + PreTrainedModel ) -from transformers.generation.utils import GenerateOutput +from transformers.utils import logging from .siglip_vit import VisionTransformer from .configuration_deepseek import DeepseekV2Config from .modeling_deepseek import DeepseekV2ForCausalLM +logger = logging.get_logger(__name__) + + class MlpProjector(nn.Module): def __init__(self, cfg): @@ -181,6 +193,45 @@ class MlpProjectorConfig(PretrainedConfig): super().__init__(**kwargs) +@dataclass +class DeepSeekVLV2CausalLMOutputWithPast(ModelOutput): + """ + Base class for DeepSeek-VL2 causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. 
+ + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None + + class DeepseekVLV2Config(PretrainedConfig): model_type = "deepseek_vl_v2" vision_config: VisionEncoderConfig @@ -229,6 +280,8 @@ class DeepseekVLV2ForCausalLM(DeepseekVLV2PreTrainedModel): def __init__(self, config: DeepseekVLV2Config): super().__init__(config) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + # ----------- vision encoder ------------ vision_config = config.vision_config self.vision = VisionTransformer( @@ -283,8 +336,8 @@ class DeepseekVLV2ForCausalLM(DeepseekVLV2PreTrainedModel): def prepare_inputs_embeds( self, input_ids: torch.LongTensor, - images: torch.FloatTensor, - images_seq_mask: torch.LongTensor, + images: Optional[torch.FloatTensor] = None, + images_seq_mask: Optional[torch.LongTensor] = None, images_spatial_crop: Optional[torch.LongTensor] = None, **ignore_kwargs ): @@ -423,48 +476,222 @@ class DeepseekVLV2ForCausalLM(DeepseekVLV2PreTrainedModel): return input_embeds - def generate( + @torch.no_grad() + def incremental_prefilling( self, - inputs: Optional[torch.Tensor] = None, - generation_config: Optional[GenerationConfig] = None, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, - synced_gpus: Optional[bool] = None, - assistant_model: Optional["PreTrainedModel"] = None, - streamer: Optional["BaseStreamer"] = None, - negative_prompt_ids: Optional[torch.Tensor] = None, - negative_prompt_attention_mask: Optional[torch.Tensor] = None, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + + images: Optional[torch.FloatTensor] = None, + images_seq_mask: Optional[torch.LongTensor] = None, + images_spatial_crop: Optional[torch.LongTensor] = None, + chunk_size: int = 1024 + ): + if inputs_embeds is None: + inputs_embeds = self.prepare_inputs_embeds( + input_ids=input_ids, + images=images, + images_seq_mask=images_seq_mask, + images_spatial_crop=images_spatial_crop, + ) + + del images + del images_seq_mask + del images_spatial_crop + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + self._clear_cuda_cache() + + bzs, seq_len, _ = inputs_embeds.shape + past_key_values = None + + # remain the last token for the next forward + prefilling_len = seq_len - 1 + for i in range(0, prefilling_len, chunk_size): + chunk_start = i + chunk_end = min(i + chunk_size, prefilling_len) + chunk_inputs_embeds = inputs_embeds[:, chunk_start: chunk_end] + chunk_attention_mask = attention_mask[:, 0: chunk_end] + # print(f"start = {chunk_start}, end = {chunk_end}, prefilling_len = {prefilling_len}, seq_len = {seq_len}") + + # compute position_ids + if past_key_values is not None: + position_ids = torch.arange( + chunk_start, + chunk_end, + dtype=torch.long, + device=inputs_embeds.device + 
).unsqueeze(0) + past_key_values = self._move_past_key_values_to_gpu(past_key_values, inputs_embeds.device) + else: + position_ids = None + + # chunk-forward + with torch.no_grad(): + outputs = self.forward( + inputs_embeds=chunk_inputs_embeds, + attention_mask=chunk_attention_mask, + past_key_values=past_key_values, + position_ids=position_ids, + use_cache=True, + ) + # update past_key_values + past_key_values = outputs.past_key_values + past_key_values = self._move_past_key_values_to_cpu(past_key_values) + + del outputs, position_ids + self._clear_cuda_cache() + + prefilling_key_values = [] + for layer_past in past_key_values: + prefilling_key_values.append( + ( + layer_past[0][:, :, 0: prefilling_len, ...].to(inputs_embeds.device), + layer_past[1][:, :, 0: prefilling_len, ...].to(inputs_embeds.device), + ) + ) + + return inputs_embeds, prefilling_key_values + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + + images: Optional[torch.FloatTensor] = None, + images_seq_mask: Optional[torch.LongTensor] = None, + images_spatial_crop: Optional[torch.LongTensor] = None, + + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ): + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + if inputs_embeds is None: + inputs_embeds = self.prepare_inputs_embeds( + input_ids=input_ids, + images=images, + images_seq_mask=images_seq_mask, + images_spatial_crop=images_spatial_crop, + ) + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + # print(inputs_embeds.shape) + outputs = self.language.forward( + input_ids=None, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position + ) + + self._clear_cuda_cache() + + return outputs + + def _clear_cuda_cache(self): + """clear CUDA memory cache""" + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + + def _move_past_key_values_to_cpu(self, past_key_values): + # print(f"past_key_values -> cpu") + if past_key_values is None: + return None + return tuple(tuple(t.cpu() for t in layer) for layer in past_key_values) + + def _move_past_key_values_to_gpu(self, past_key_values, device="cuda:0"): + # print(f"past_key_values -> gpu") + if past_key_values is None: + return None + return tuple(tuple(t.to(device) for t in layer) for layer in past_key_values) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + + images: Optional[torch.FloatTensor] = None, + images_seq_mask: 
Optional[torch.LongTensor] = None, + images_spatial_crop: Optional[torch.LongTensor] = None, + + attention_mask=None, + cache_position=None, + + pixel_values=None, + image_sizes=None, + num_logits_to_keep=None, **kwargs, - ) -> Union[GenerateOutput, torch.LongTensor]: - r""" - Generates sequences for models with a language modeling head. The method currently supports greedy decoding, - beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling. Beam-search decoding - is controlled by the `num_beams` parameter and the `num_return_sequences` parameter. - - Parameters: - - `inputs` (optional) -- `torch.LongTensor` of shape `(batch, sequence_length)`: - The sequence used as a prompt for the generation. If `None`, generate for the model's prompt. - - `generation_config` (optional) -- `GenerationConfig`: - The generation config of the model. - - `logits_processor` (optional) -- `LogitsProcessorList`: - A list of instances of :class:`~transform - """ - - return self.language.generate( - inputs=inputs, - generation_config=generation_config, - logits_processor=logits_processor, - stopping_criteria=stopping_criteria, - prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, - synced_gpus=synced_gpus, - assistant_model=assistant_model, - streamer=streamer, - negative_prompt_ids=negative_prompt_ids, - negative_prompt_attention_mask=negative_prompt_attention_mask, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + model_inputs = self.language.prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + num_logits_to_keep=num_logits_to_keep, **kwargs, ) + # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore + # Otherwise we need pixel values to be passed to model + cache_position = model_inputs["cache_position"] + if cache_position[0] == 0: + model_inputs["images"] = images + model_inputs["images_seq_mask"] = images_seq_mask + model_inputs["images_spatial_crop"] = images_spatial_crop + + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx.to(past_state.device)) + for past_state in layer_past + ), + ) + return reordered_past + AutoConfig.register("vision", VisionEncoderConfig) AutoConfig.register("mlp_projector", MlpProjectorConfig) diff --git a/deepseek_vl2/models/processing_deepseek_vl_v2.py b/deepseek_vl2/models/processing_deepseek_vl_v2.py index 0a98e50..8970dbb 100644 --- a/deepseek_vl2/models/processing_deepseek_vl_v2.py +++ b/deepseek_vl2/models/processing_deepseek_vl_v2.py @@ -559,7 +559,7 @@ class DeepseekVLV2Processor(ProcessorMixin): for j in range(0, best_width, self.image_size): images_list.append( self.image_transform(local_view.crop((j, i, j + self.image_size, i + self.image_size)))) - + """record height / width crop num""" num_width_tiles, num_height_tiles = best_width // self.image_size, best_height // self.image_size images_spatial_crop.append([num_width_tiles, num_height_tiles]) diff --git a/deepseek_vl2/serve/app_modules/presets.py b/deepseek_vl2/serve/app_modules/presets.py index 6146423..b94c40d 100755 --- a/deepseek_vl2/serve/app_modules/presets.py +++ b/deepseek_vl2/serve/app_modules/presets.py @@ -21,7 +21,7 @@ import gradio as gr title = """

Chat with DeepSeek-VL2

""" -description_top = """""" +description_top = """Special Tokens: ``, Visual Grounding: `<|ref|>{query}<|/ref|>`, Grounding Conversation: `<|grounding|>{question}`""" description = """""" CONCURRENT_COUNT = 1 MAX_EVENTS = 10 diff --git a/deepseek_vl2/serve/app_modules/utils.py b/deepseek_vl2/serve/app_modules/utils.py index 7eb3f9e..9a9b98d 100755 --- a/deepseek_vl2/serve/app_modules/utils.py +++ b/deepseek_vl2/serve/app_modules/utils.py @@ -242,7 +242,9 @@ def pil_to_base64( alt: str = "user upload image", resize: bool = True, max_size: int = MAX_IMAGE_SIZE, - min_size: int = MIN_IMAGE_SIZE + min_size: int = MIN_IMAGE_SIZE, + format: str = "JPEG", + quality: int = 95 ) -> str: if resize: @@ -258,15 +260,16 @@ def pil_to_base64( image = image.resize((W, H)) buffered = io.BytesIO() - image.save(buffered, format="JPEG") + image.save(buffered, format=format, quality=quality) img_b64_str = base64.b64encode(buffered.getvalue()).decode() img_str = f'{alt}' return img_str -def parse_ref_bbox(response, image): +def parse_ref_bbox(response, image: Image.Image): try: + image = image.copy() image_h, image_w = image.size draw = ImageDraw.Draw(image) @@ -275,7 +278,7 @@ def parse_ref_bbox(response, image): assert len(ref) == len(bbox) if len(ref) == 0: - return + return None boxes, labels = [], [] for box, label in zip(bbox, ref): @@ -301,9 +304,30 @@ def parse_ref_bbox(response, image): text_x = box[0] text_y = box[1] - 20 text_color = box_color - font = ImageFont.truetype('./deepseek_vl2/serve/assets/simsun.ttc', size=20) + font = ImageFont.truetype("deepseek_vl2/serve/assets/simsun.ttc", size=20) draw.text((text_x, text_y), label, font=font, fill=text_color) + # print(f"boxes = {boxes}, labels = {labels}, re-render = {image}") return image except: - return + return None + + +def display_example(image_list): + images_html = "" + for i, img_path in enumerate(image_list): + image = Image.open(img_path) + buffered = io.BytesIO() + image.save(buffered, format="PNG", quality=100) + img_b64_str = base64.b64encode(buffered.getvalue()).decode() + img_str = f'{img_path}' + images_html += img_str + + result_html = f""" +
+
{images_html}
+
+ """ + + return result_html + diff --git a/deepseek_vl2/serve/assets/simsun.ttc b/deepseek_vl2/serve/assets/simsun.ttc index 7590885..e64e92e 100644 Binary files a/deepseek_vl2/serve/assets/simsun.ttc and b/deepseek_vl2/serve/assets/simsun.ttc differ diff --git a/deepseek_vl2/serve/examples/app.png b/deepseek_vl2/serve/examples/app.png deleted file mode 100644 index 5dcd4b0..0000000 Binary files a/deepseek_vl2/serve/examples/app.png and /dev/null differ diff --git a/deepseek_vl2/serve/examples/chart.png b/deepseek_vl2/serve/examples/chart.png deleted file mode 100644 index 64ad76a..0000000 Binary files a/deepseek_vl2/serve/examples/chart.png and /dev/null differ diff --git a/deepseek_vl2/serve/examples/mirror.png b/deepseek_vl2/serve/examples/mirror.png deleted file mode 100644 index 88f2a12..0000000 Binary files a/deepseek_vl2/serve/examples/mirror.png and /dev/null differ diff --git a/deepseek_vl2/serve/examples/pipeline.png b/deepseek_vl2/serve/examples/pipeline.png deleted file mode 100644 index 7acdc57..0000000 Binary files a/deepseek_vl2/serve/examples/pipeline.png and /dev/null differ diff --git a/deepseek_vl2/serve/examples/puzzle.png b/deepseek_vl2/serve/examples/puzzle.png deleted file mode 100644 index f67b8ac..0000000 Binary files a/deepseek_vl2/serve/examples/puzzle.png and /dev/null differ diff --git a/deepseek_vl2/serve/examples/rap.jpeg b/deepseek_vl2/serve/examples/rap.jpeg deleted file mode 100755 index 43f2325..0000000 Binary files a/deepseek_vl2/serve/examples/rap.jpeg and /dev/null differ diff --git a/deepseek_vl2/serve/inference.py b/deepseek_vl2/serve/inference.py index 2fe367d..9b445d6 100755 --- a/deepseek_vl2/serve/inference.py +++ b/deepseek_vl2/serve/inference.py @@ -47,24 +47,27 @@ def load_model(model_path, dtype=torch.bfloat16): def convert_conversation_to_prompts(conversation: Conversation): conv_prompts = [] - pil_images = [] + + last_image = None + messages = conversation.messages for i in range(0, len(messages), 2): if isinstance(messages[i][1], tuple): text, images = messages[i][1] + last_image = images[-1] else: text, images = messages[i][1], [] - pil_images.extend(images) prompt = { "role": messages[i][0], "content": text, + "images": images } response = {"role": messages[i + 1][0], "content": messages[i + 1][1]} conv_prompts.extend([prompt, response]) - return conv_prompts, pil_images + return conv_prompts, last_image class StoppingCriteriaSub(StoppingCriteria): @@ -86,8 +89,7 @@ class StoppingCriteriaSub(StoppingCriteria): @torch.inference_mode() def deepseek_generate( - conv_prompts: list, - pil_images: list, + conversations: list, vl_gpt: torch.nn.Module, vl_chat_processor: DeepseekVLV2Processor, tokenizer: transformers.PreTrainedTokenizer, @@ -95,11 +97,17 @@ def deepseek_generate( max_length: int = 256, temperature: float = 1.0, top_p: float = 1.0, - repetition_penalty=1.1, + repetition_penalty: float = 1.1, + chunk_size: int = -1 ): + pil_images = [] + for message in conversations: + if "images" not in message: + continue + pil_images.extend(message["images"]) prepare_inputs = vl_chat_processor.__call__( - conversations=conv_prompts, + conversations=conversations, images=pil_images, inference_mode=True, force_batchify=True, @@ -110,11 +118,12 @@ def deepseek_generate( vl_gpt, tokenizer, prepare_inputs, - max_length, - temperature, - repetition_penalty, - top_p, - stop_words, + max_gen_len=max_length, + temperature=temperature, + repetition_penalty=repetition_penalty, + top_p=top_p, + stop_words=stop_words, + chunk_size=chunk_size ) @@ -128,11 
+137,10 @@ def generate( repetition_penalty=1.1, top_p: float = 0.95, stop_words: List[str] = [], + chunk_size: int = -1 ): """Stream the text output from the multimodality model with prompt and image inputs.""" - inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs) - - streamer = TextIteratorStreamer(tokenizer) + streamer = TextIteratorStreamer(tokenizer, skip_prompt=True) stop_words_ids = [ torch.tensor(tokenizer.encode(stop_word)) for stop_word in stop_words @@ -141,9 +149,27 @@ def generate( [StoppingCriteriaSub(stops=stop_words_ids)] ) + if chunk_size != -1: + inputs_embeds, past_key_values = vl_gpt.incremental_prefilling( + input_ids=prepare_inputs.input_ids, + images=prepare_inputs.images, + images_seq_mask=prepare_inputs.images_seq_mask, + images_spatial_crop=prepare_inputs.images_spatial_crop, + attention_mask=prepare_inputs.attention_mask, + chunk_size=chunk_size + ) + else: + inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs) + past_key_values = None + generation_config = dict( inputs_embeds=inputs_embeds, + input_ids=prepare_inputs.input_ids, + images=prepare_inputs.images, + images_seq_mask=prepare_inputs.images_seq_mask, + images_spatial_crop=prepare_inputs.images_spatial_crop, attention_mask=prepare_inputs.attention_mask, + past_key_values=past_key_values, pad_token_id=tokenizer.eos_token_id, bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id, diff --git a/images/badge.svg b/images/badge.svg deleted file mode 100644 index 1551f56..0000000 --- a/images/badge.svg +++ /dev/null @@ -1 +0,0 @@ -DeepSeek: HomepageDeepSeekHomepage diff --git a/images/dog_a.png b/images/dog_a.png deleted file mode 100644 index 956caab..0000000 Binary files a/images/dog_a.png and /dev/null differ diff --git a/images/dog_b.png b/images/dog_b.png deleted file mode 100644 index 221f1d1..0000000 Binary files a/images/dog_b.png and /dev/null differ diff --git a/images/dog_c.png b/images/dog_c.png deleted file mode 100644 index 283a182..0000000 Binary files a/images/dog_c.png and /dev/null differ diff --git a/images/dog_d.png b/images/dog_d.png deleted file mode 100644 index d9ff5d6..0000000 Binary files a/images/dog_d.png and /dev/null differ diff --git a/images/github_demo.png b/images/github_demo.png deleted file mode 100644 index 8af8865..0000000 Binary files a/images/github_demo.png and /dev/null differ diff --git a/images/grounding_conversation_1.jpeg b/images/grounding_conversation_1.jpeg new file mode 100644 index 0000000..fcff1b7 Binary files /dev/null and b/images/grounding_conversation_1.jpeg differ diff --git a/images/icl_vg_2.jpeg b/images/icl_vg_2.jpeg new file mode 100644 index 0000000..0e30ae0 Binary files /dev/null and b/images/icl_vg_2.jpeg differ diff --git a/images/incontext_visual_grounding_1.jpeg b/images/incontext_visual_grounding_1.jpeg new file mode 100644 index 0000000..d451559 Binary files /dev/null and b/images/incontext_visual_grounding_1.jpeg differ diff --git a/images/monday.jpg b/images/monday.jpg new file mode 100644 index 0000000..01794dc Binary files /dev/null and b/images/monday.jpg differ diff --git a/images/multi_image_1.jpeg b/images/multi_image_1.jpeg new file mode 100644 index 0000000..e8e3f5a Binary files /dev/null and b/images/multi_image_1.jpeg differ diff --git a/images/multi_image_1.png b/images/multi_image_1.png deleted file mode 100644 index 1d619d3..0000000 Binary files a/images/multi_image_1.png and /dev/null differ diff --git a/images/multi_image_2.jpg b/images/multi_image_2.jpeg similarity index 100% rename from 
images/multi_image_2.jpg
rename to images/multi_image_2.jpeg
diff --git a/images/multi_image_3.jpg b/images/multi_image_3.jpeg
similarity index 100%
rename from images/multi_image_3.jpg
rename to images/multi_image_3.jpeg
diff --git a/images/qr.jpeg b/images/qr.jpeg
new file mode 100644
index 0000000..d0152d1
Binary files /dev/null and b/images/qr.jpeg differ
diff --git a/images/sample.jpg b/images/sample.jpg
new file mode 100644
index 0000000..961e349
Binary files /dev/null and b/images/sample.jpg differ
diff --git a/images/vg_2.jpeg b/images/vg_2.jpeg
new file mode 100644
index 0000000..5911a6d
Binary files /dev/null and b/images/vg_2.jpeg differ
diff --git a/images/visual_grounding.jpeg b/images/visual_grounding_1.jpeg
similarity index 100%
rename from images/visual_grounding.jpeg
rename to images/visual_grounding_1.jpeg
diff --git a/images/visual_grounding_2.jpg b/images/visual_grounding_2.jpg
new file mode 100644
index 0000000..49ddde4
Binary files /dev/null and b/images/visual_grounding_2.jpg differ
diff --git a/images/visual_grounding_3.png b/images/visual_grounding_3.png
new file mode 100644
index 0000000..bdd848b
Binary files /dev/null and b/images/visual_grounding_3.png differ
diff --git a/images/vqa_1.jpg b/images/vqa_1.jpg
new file mode 100644
index 0000000..608aaa0
Binary files /dev/null and b/images/vqa_1.jpg differ
diff --git a/inference.py b/inference.py
index 26abce6..4722f33 100644
--- a/inference.py
+++ b/inference.py
@@ -21,7 +21,6 @@ from argparse import ArgumentParser
 from typing import List, Dict
 import torch
 from transformers import AutoModelForCausalLM
-
 import PIL.Image
 
 from deepseek_vl2.models import DeepseekVLV2ForCausalLM, DeepseekVLV2Processor
@@ -81,14 +80,37 @@
     conversation = [
         {
             "role": "<|User|>",
-            "content": "<image>\n<|ref|>The giraffe at the back.<|/ref|>.",
-            "images": ["./images/visual_grounding.jpeg"],
+            "content": "<image>\n<image>\n<|grounding|>In the first image, an object within the red rectangle is marked. Locate the object of the same category in the second image.",
+            "images": [
+                "images/incontext_visual_grounding_1.jpeg",
+                "images/icl_vg_2.jpeg"
+            ],
         },
         {"role": "<|Assistant|>", "content": ""},
     ]
 
+    # conversation = [
+    #     {
+    #         "role": "<|User|>",
+    #         "content": "<image>\n<|ref|>The giraffe at the back.<|/ref|>.",
+    #         "images": ["./images/visual_grounding_1.jpeg"],
+    #     },
+    #     {"role": "<|Assistant|>", "content": ""},
+    # ]
+
     # load images and prepare for inputs
     pil_images = load_pil_images(conversation)
+    print(f"len(pil_images) = {len(pil_images)}")
+
+    # input_ids = batched_input_ids,
+    # attention_mask = batched_attention_mask,
+    # labels = batched_labels,
+    # images_tiles = batched_images,
+    # images_seq_mask = batched_images_seq_mask,
+    # images_spatial_crop = batched_images_spatial_crop,
+    # sft_format = batched_sft_format,
+    # seq_lens = seq_lens
+
     prepare_inputs = vl_chat_processor.__call__(
         conversations=conversation,
         images=pil_images,
@@ -96,34 +118,59 @@
         system_prompt=""
     ).to(vl_gpt.device, dtype=dtype)
 
+    # for key in prepare_inputs.keys():
+    #     value = prepare_inputs[key]
+    #     if isinstance(value, list):
+    #         print(key, len(value), type(value))
+    #     elif isinstance(value, torch.Tensor):
+    #         print(key, value.shape, type(value))
+
     with torch.no_grad():
         # run image encoder to get the image embeddings
-        inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
+        # inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
+
+        # incremental_prefilling when using 40G GPU for vl2-small
+        inputs_embeds, past_key_values = vl_gpt.incremental_prefilling(
+            input_ids=prepare_inputs.input_ids,
+            images=prepare_inputs.images,
+            images_seq_mask=prepare_inputs.images_seq_mask,
+            images_spatial_crop=prepare_inputs.images_spatial_crop,
+            attention_mask=prepare_inputs.attention_mask,
+            chunk_size=args.chunk_size
+        )
 
         # run the model to get the response
         outputs = vl_gpt.generate(
+            # inputs_embeds=inputs_embeds[:, -1:],
+            # input_ids=prepare_inputs.input_ids[:, -1:],
             inputs_embeds=inputs_embeds,
+            input_ids=prepare_inputs.input_ids,
+            images=prepare_inputs.images,
+            images_seq_mask=prepare_inputs.images_seq_mask,
+            images_spatial_crop=prepare_inputs.images_spatial_crop,
             attention_mask=prepare_inputs.attention_mask,
+            past_key_values=past_key_values,
+
             pad_token_id=tokenizer.eos_token_id,
             bos_token_id=tokenizer.bos_token_id,
             eos_token_id=tokenizer.eos_token_id,
-            max_new_tokens=1024,
+            max_new_tokens=512,
 
-            do_sample=False,
             # repetition_penalty=1.1,
 
-            # do_sample=True,
-            # temperature=1.0,
-            # top_p=0.9,
-            # repetition_penalty=1.1,
+            do_sample=True,
+            temperature=0.4,
+            top_p=0.9,
+            repetition_penalty=1.1,
 
             use_cache=True,
         )
 
-        answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=False)
+        answer = tokenizer.decode(outputs[0][len(prepare_inputs.input_ids[0]):].cpu().tolist(), skip_special_tokens=False)
 
     print(f"{prepare_inputs['sft_format'][0]}", answer)
-    vg_image = parse_ref_bbox(answer, image=pil_images[0])
+    vg_image = parse_ref_bbox(answer, image=pil_images[-1])
     if vg_image is not None:
         vg_image.save("./vg.jpg", format="JPEG", quality=85)
 
@@ -131,7 +178,8 @@
 if __name__ == "__main__":
     parser = ArgumentParser()
     parser.add_argument("--model_path", type=str, required=True,
-                        default="deepseek-ai/deepseek-vl2-27b-moe",
+                        default="deepseek-ai/deepseek-vl2",
                         help="model name or local path to the model")
+    parser.add_argument("--chunk_size", type=int, default=512,
+                        help="chunk size for the model for prefilling")
    args = parser.parse_args()
    main(args)
diff --git a/web_demo.py b/web_demo.py
new file mode 100755
index 0000000..894aece
--- /dev/null
+++ b/web_demo.py
@@ -0,0 +1,674 @@
+# Copyright (c) 2023-2024 DeepSeek.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of
+# this software and associated documentation files (the "Software"), to deal in
+# the Software without restriction, including without limitation the rights to
+# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+# the Software, and to permit persons to whom the Software is furnished to do so,
+# subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+# -*- coding:utf-8 -*-
+from argparse import ArgumentParser
+
+import io
+import sys
+import base64
+from PIL import Image
+
+import gradio as gr
+import torch
+
+from deepseek_vl2.serve.app_modules.gradio_utils import (
+    cancel_outputing,
+    delete_last_conversation,
+    reset_state,
+    reset_textbox,
+    wrap_gen_fn,
+)
+from deepseek_vl2.serve.app_modules.overwrites import reload_javascript
+from deepseek_vl2.serve.app_modules.presets import (
+    CONCURRENT_COUNT,
+    MAX_EVENTS,
+    description,
+    description_top,
+    title
+)
+from deepseek_vl2.serve.app_modules.utils import (
+    configure_logger,
+    is_variable_assigned,
+    strip_stop_words,
+    parse_ref_bbox,
+    pil_to_base64,
+    display_example
+)
+
+from deepseek_vl2.serve.inference import (
+    convert_conversation_to_prompts,
+    deepseek_generate,
+    load_model,
+)
+from deepseek_vl2.models.conversation import SeparatorStyle
+
+logger = configure_logger()
+
+MODELS = [
+    "DeepSeek-VL2-tiny",
+    "DeepSeek-VL2-small",
+    "DeepSeek-VL2",
+
+    "deepseek-ai/deepseek-vl2-tiny",
+    "deepseek-ai/deepseek-vl2-small",
+    "deepseek-ai/deepseek-vl2",
+]
+
+DEPLOY_MODELS = dict()
+IMAGE_TOKEN = "<image>"
+
+examples_list = [
+    # visual grounding - 1
+    [
+        ["images/visual_grounding_1.jpeg"],
+        "<|ref|>The giraffe at the back.<|/ref|>",
+    ],
+
+    # visual grounding - 2
+    [
+        ["images/visual_grounding_2.jpg"],
+        "找到<|ref|>淡定姐<|/ref|>",
+    ],
+
+    # visual grounding - 3
+    [
+        ["images/visual_grounding_3.png"],
+        "Find all the <|ref|>Watermelon slices<|/ref|>",
+    ],
+
+    # grounding conversation
+    [
+        ["images/grounding_conversation_1.jpeg"],
+        "<|grounding|>I want to throw out the trash now, what should I do?",
+    ],
+
+    # in-context visual grounding
+    [
+        [
+            "images/incontext_visual_grounding_1.jpeg",
+            "images/icl_vg_2.jpeg"
+        ],
+        "<|grounding|>In the first image, an object within the red rectangle is marked. Locate the object of the same category in the second image."
+    ],
+
+    # vqa
+    [
+        ["images/vqa_1.jpg"],
+        "Describe each stage of this image in detail",
+    ],
+
+    # multi-images
+    [
+        [
+            "images/multi_image_1.jpeg",
+            "images/mi_2.jpeg",
+            "images/mi_3.jpeg"
+        ],
+        "能帮我用这几个食材做一道菜吗?",
+    ]
+
+]
+
+
+def fetch_model(model_name: str, dtype=torch.bfloat16):
+    global args, DEPLOY_MODELS
+
+    if args.local_path:
+        model_path = args.local_path
+    else:
+        model_path = model_name
+
+    if model_name in DEPLOY_MODELS:
+        model_info = DEPLOY_MODELS[model_name]
+        print(f"{model_name} has been loaded.")
+    else:
+        print(f"{model_name} is loading...")
+        DEPLOY_MODELS[model_name] = load_model(model_path, dtype=dtype)
+        print(f"Load {model_name} successfully...")
+        model_info = DEPLOY_MODELS[model_name]
+
+    return model_info
+
+
+def generate_prompt_with_history(
+    text, images, history, vl_chat_processor, tokenizer, max_length=2048
+):
+    """
+    Generate a prompt with history for the deepseek application.
+
+    Args:
+        text (str): The text prompt.
+        images (list[PIL.Image.Image]): The image prompt.
+        history (list): List of previous conversation messages.
+        tokenizer: The tokenizer used for encoding the prompt.
+        max_length (int): The maximum length of the prompt.
+
+    Returns:
+        Conversation: a copy of the conversation ready for inference, or None if the
+            prompt could not be generated within the max_length limit.
+    """
+    global IMAGE_TOKEN
+
+    sft_format = "deepseek"
+    user_role_ind = 0
+    bot_role_ind = 1
+
+    # Initialize conversation
+    conversation = vl_chat_processor.new_chat_template()
+
+    if history:
+        conversation.messages = history
+
+    if images is not None and len(images) > 0:
+
+        num_image_tags = text.count(IMAGE_TOKEN)
+        num_images = len(images)
+
+        if num_images > num_image_tags:
+            pad_image_tags = num_images - num_image_tags
+            image_tokens = "\n".join([IMAGE_TOKEN] * pad_image_tags)
+
+            # append the <image> tags in a new line after the text prompt
+            text = image_tokens + "\n" + text
+        elif num_images < num_image_tags:
+            remove_image_tags = num_image_tags - num_images
+            text = text.replace(IMAGE_TOKEN, "", remove_image_tags)
+
+        # print(f"prompt = {text}, len(images) = {len(images)}")
+        text = (text, images)
+
+    conversation.append_message(conversation.roles[user_role_ind], text)
+    conversation.append_message(conversation.roles[bot_role_ind], "")
+
+    # Create a copy of the conversation to avoid history truncation in the UI
+    conversation_copy = conversation.copy()
+    logger.info("=" * 80)
+    logger.info(get_prompt(conversation))
+
+    rounds = len(conversation.messages) // 2
+
+    for _ in range(rounds):
+        current_prompt = get_prompt(conversation)
+        current_prompt = (
+            current_prompt.replace("<image>", "")
+            if sft_format == "deepseek"
+            else current_prompt
+        )
+
+        if torch.tensor(tokenizer.encode(current_prompt)).size(-1) <= max_length:
+            return conversation_copy
+
+        if len(conversation.messages) % 2 != 0:
+            gr.Error("The messages between user and assistant are not paired.")
+            return
+
+        try:
+            for _ in range(2):  # pop out two messages in a row
+                conversation.messages.pop(0)
+        except IndexError:
+            gr.Error("Input text processing failed, unable to respond in this round.")
+            return None
+
+    gr.Error("Prompt could not be generated within max_length limit.")
+    return None
+
+
+def to_gradio_chatbot(conv):
+    """Convert the conversation to gradio chatbot format."""
+    ret = []
+    for i, (role, msg) in enumerate(conv.messages[conv.offset:]):
+        if i % 2 == 0:
+            if type(msg) is tuple:
+                msg, images = msg
+
+                if isinstance(images, list):
+
+
+def to_gradio_chatbot(conv):
+    """Convert the conversation to gradio chatbot format."""
+    ret = []
+    for i, (role, msg) in enumerate(conv.messages[conv.offset:]):
+        if i % 2 == 0:
+            if type(msg) is tuple:
+                msg, images = msg
+
+                if isinstance(images, list):
+                    for j, image in enumerate(images):
+                        if isinstance(image, str):
+                            with open(image, "rb") as f:
+                                data = f.read()
+                            img_b64_str = base64.b64encode(data).decode()
+                            # embed the on-disk image as an inline base64 <img> tag
+                            image_str = (
+                                f'<img src="data:image/png;base64,{img_b64_str}" '
+                                f'alt="user upload image_{j}" style="max-width: 300px; height: auto;">'
+                            )
+                        else:
+                            image_str = pil_to_base64(image, f"user upload image_{j}", max_size=800, min_size=400)
+
+                        # replace one <image> tag in the message with the rendered image
+                        msg = msg.replace(IMAGE_TOKEN, image_str, 1)
+
+            ret.append([msg, None])
+        else:
+            ret[-1][-1] = msg
+    return ret
+
+
+def to_gradio_history(conv):
+    """Convert the conversation to gradio history state."""
+    return conv.messages[conv.offset:]
+
+
+def get_prompt(conv) -> str:
+    """Get the prompt for generation."""
+    system_prompt = conv.system_template.format(system_message=conv.system_message)
+    if conv.sep_style == SeparatorStyle.DeepSeek:
+        seps = [conv.sep, conv.sep2]
+        if system_prompt == "" or system_prompt is None:
+            ret = ""
+        else:
+            ret = system_prompt + seps[0]
+        for i, (role, message) in enumerate(conv.messages):
+            if message:
+                if type(message) is tuple:  # multimodal message
+                    message, _ = message
+                ret += role + ": " + message + seps[i % 2]
+            else:
+                ret += role + ":"
+        return ret
+    else:
+        return conv.get_prompt()
+
+
+def transfer_input(input_text, input_images):
+    print("transferring input text and input image")
+
+    return (
+        input_text,
+        input_images,
+        gr.update(value=""),
+        gr.update(value=None),
+        gr.Button(visible=True)
+    )
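+
+# predict() below is a generator: deepseek_generate() streams chunks of text,
+# and each partial response is yielded back to Gradio so the chatbot updates
+# incrementally. If the final response contains grounding boxes, parse_ref_bbox()
+# renders them on the last image and the result is appended to the chat.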
+
+
+@wrap_gen_fn
+def predict(
+    text,
+    images,
+    chatbot,
+    history,
+    top_p,
+    temperature,
+    repetition_penalty,
+    max_length_tokens,
+    max_context_length_tokens,
+    model_select_dropdown,
+):
+    """
+    Predict the response for the user's input with the selected model.
+
+    Parameters:
+        text (str): The input text from the user.
+        images (list): The input images from the user.
+        chatbot (list): The current chatbot display state.
+        history (list): The history of the chat.
+        top_p (float): The top-p parameter for the model.
+        temperature (float): The temperature parameter for the model.
+        repetition_penalty (float): The repetition penalty for the model.
+        max_length_tokens (int): The maximum number of generated tokens.
+        max_context_length_tokens (int): The maximum length of context tokens.
+        model_select_dropdown (str): The selected model from the dropdown.
+
+    Returns:
+        generator: A generator that yields the chatbot outputs, history, and status.
+    """
+    print("running the prediction function")
+    try:
+        tokenizer, vl_gpt, vl_chat_processor = fetch_model(model_select_dropdown)
+
+        if text == "":
+            yield chatbot, history, "Empty context."
+            return
+    except KeyError:
+        yield [[text, "No Model Found"]], [], "No Model Found"
+        return
+
+    if images is None:
+        images = []
+
+    # load images
+    pil_images = []
+    for img_or_file in images:
+        try:
+            # keep PIL images as-is, open uploaded files from disk
+            if isinstance(img_or_file, Image.Image):
+                pil_images.append(img_or_file)
+            else:
+                image = Image.open(img_or_file.name).convert("RGB")
+                pil_images.append(image)
+        except Exception as e:
+            print(f"Error loading image: {e}")
+
+    conversation = generate_prompt_with_history(
+        text,
+        pil_images,
+        history,
+        vl_chat_processor,
+        tokenizer,
+        max_length=max_context_length_tokens,
+    )
+    all_conv, last_image = convert_conversation_to_prompts(conversation)
+
+    stop_words = conversation.stop_str
+    gradio_chatbot_output = to_gradio_chatbot(conversation)
+
+    full_response = ""
+    with torch.no_grad():
+        for x in deepseek_generate(
+            conversations=all_conv,
+            vl_gpt=vl_gpt,
+            vl_chat_processor=vl_chat_processor,
+            tokenizer=tokenizer,
+            stop_words=stop_words,
+            max_length=max_length_tokens,
+            temperature=temperature,
+            repetition_penalty=repetition_penalty,
+            top_p=top_p,
+            chunk_size=args.chunk_size
+        ):
+            full_response += x
+            response = strip_stop_words(full_response, stop_words)
+            conversation.update_last_message(response)
+            gradio_chatbot_output[-1][1] = response
+
+            # sys.stdout.write(x)
+            # sys.stdout.flush()
+
+            yield gradio_chatbot_output, to_gradio_history(conversation), "Generating..."
+
+    if last_image is not None:
+        # TODO always render the last image's visual grounding image
+        vg_image = parse_ref_bbox(response, last_image)
+        if vg_image is not None:
+            vg_base64 = pil_to_base64(vg_image, "vg", max_size=800, min_size=400)
+            gradio_chatbot_output[-1][1] += vg_base64
+            yield gradio_chatbot_output, to_gradio_history(conversation), "Generating..."
+
+    print("flushed result to gradio")
+    torch.cuda.empty_cache()
+
+    if is_variable_assigned("x"):
+        print(f"{model_select_dropdown}:\n{text}\n{'-' * 80}\n{x}\n{'=' * 80}")
+        print(
+            f"temperature: {temperature}, "
+            f"top_p: {top_p}, "
+            f"repetition_penalty: {repetition_penalty}, "
+            f"max_length_tokens: {max_length_tokens}"
+        )
+
+    yield gradio_chatbot_output, to_gradio_history(conversation), "Generate: Success"
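+
+# retry() pops the last exchange from the chatbot and history states, recovers
+# the previous user text, and re-enters predict() with the same sampling settings.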
+
+
+# @wrap_gen_fn
+def retry(
+    text,
+    images,
+    chatbot,
+    history,
+    top_p,
+    temperature,
+    repetition_penalty,
+    max_length_tokens,
+    max_context_length_tokens,
+    model_select_dropdown,
+):
+    if len(history) == 0:
+        yield (chatbot, history, "Empty context")
+        return
+
+    chatbot.pop()
+    history.pop()
+    text = history.pop()[-1]
+    if type(text) is tuple:
+        text, image = text
+
+    yield from predict(
+        text,
+        images,
+        chatbot,
+        history,
+        top_p,
+        temperature,
+        repetition_penalty,
+        max_length_tokens,
+        max_context_length_tokens,
+        model_select_dropdown,
+    )
+
+
+def preview_images(files):
+    if files is None:
+        return []
+
+    image_paths = []
+    for file in files:
+        # use file.name to get the uploaded file's path
+        # image = Image.open(file.name)
+        image_paths.append(file.name)
+    return image_paths  # return all image paths for the gallery preview
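+
+# build_demo() lays out the Blocks UI: the chat column with send/stop/regenerate
+# controls, an upload-plus-gallery column, and a "Parameter Setting" tab whose
+# sliders feed directly into predict() via input_widgets.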
+
+
+def build_demo(args):
+    # fetch model
+    if not args.lazy_load:
+        fetch_model(args.model_name)
+
+    with open("deepseek_vl2/serve/assets/custom.css", "r", encoding="utf-8") as f:
+        customCSS = f.read()
+
+    with gr.Blocks(theme=gr.themes.Soft()) as demo:
+        history = gr.State([])
+        input_text = gr.State()
+        input_images = gr.State()
+
+        with gr.Row():
+            gr.HTML(title)
+            status_display = gr.Markdown("Success", elem_id="status_display")
+        gr.Markdown(description_top)
+
+        with gr.Row(equal_height=True):
+            with gr.Column(scale=4):
+                with gr.Row():
+                    chatbot = gr.Chatbot(
+                        elem_id="deepseek_chatbot",
+                        show_share_button=True,
+                        bubble_full_width=False,
+                        height=600,
+                    )
+                with gr.Row():
+                    with gr.Column(scale=4):
+                        text_box = gr.Textbox(
+                            show_label=False, placeholder="Enter text", container=False
+                        )
+                    with gr.Column(
+                        min_width=70,
+                    ):
+                        submitBtn = gr.Button("Send")
+                    with gr.Column(
+                        min_width=70,
+                    ):
+                        cancelBtn = gr.Button("Stop")
+                with gr.Row():
+                    emptyBtn = gr.Button(
+                        "🧹 New Conversation",
+                    )
+                    retryBtn = gr.Button("🔄 Regenerate")
+                    delLastBtn = gr.Button("🗑️ Remove Last Turn")
+
+            with gr.Column():
+                upload_images = gr.Files(file_types=["image"], show_label=True)
+                gallery = gr.Gallery(columns=[3], height="200px", show_label=True)
+
+                upload_images.change(preview_images, inputs=upload_images, outputs=gallery)
+
+                with gr.Tab(label="Parameter Setting") as parameter_row:
+                    top_p = gr.Slider(
+                        minimum=0.0,
+                        maximum=1.0,
+                        value=0.9,
+                        step=0.05,
+                        interactive=True,
+                        label="Top-p",
+                    )
+                    temperature = gr.Slider(
+                        minimum=0,
+                        maximum=1.0,
+                        value=0.1,
+                        step=0.1,
+                        interactive=True,
+                        label="Temperature",
+                    )
+                    repetition_penalty = gr.Slider(
+                        minimum=0.0,
+                        maximum=2.0,
+                        value=1.1,
+                        step=0.1,
+                        interactive=True,
+                        label="Repetition penalty",
+                    )
+                    max_length_tokens = gr.Slider(
+                        minimum=0,
+                        maximum=4096,
+                        value=2048,
+                        step=8,
+                        interactive=True,
+                        label="Max Generation Tokens",
+                    )
+                    max_context_length_tokens = gr.Slider(
+                        minimum=0,
+                        maximum=8192,
+                        value=4096,
+                        step=128,
+                        interactive=True,
+                        label="Max History Tokens",
+                    )
+                    model_select_dropdown = gr.Dropdown(
+                        label="Select Models",
+                        choices=[args.model_name],
+                        multiselect=False,
+                        value=args.model_name,
+                        interactive=True,
+                    )
+
+                    # show images, but not visible
+                    show_images = gr.HTML(visible=False)
+                    # show_images = gr.Image(type="pil", interactive=False, visible=False)
+
+        def format_examples(examples_list):
+            examples = []
+            for images, texts in examples_list:
+                examples.append([images, display_example(images), texts])
+
+            return examples
+
+        gr.Examples(
+            examples=format_examples(examples_list),
+            inputs=[upload_images, show_images, text_box],
+        )
+
+        gr.Markdown(description)
+
+        input_widgets = [
+            input_text,
+            input_images,
+            chatbot,
+            history,
+            top_p,
+            temperature,
+            repetition_penalty,
+            max_length_tokens,
+            max_context_length_tokens,
+            model_select_dropdown,
+        ]
+        output_widgets = [chatbot, history, status_display]
+
+        transfer_input_args = dict(
+            fn=transfer_input,
+            inputs=[text_box, upload_images],
+            outputs=[input_text, input_images, text_box, upload_images, submitBtn],
+            show_progress=True,
+        )
+
+        predict_args = dict(
+            fn=predict,
+            inputs=input_widgets,
+            outputs=output_widgets,
+            show_progress=True,
+        )
+
+        retry_args = dict(
+            fn=retry,
+            inputs=input_widgets,
+            outputs=output_widgets,
+            show_progress=True,
+        )
+
+        reset_args = dict(
+            fn=reset_textbox, inputs=[], outputs=[text_box, status_display]
+        )
+
+        predict_events = [
+            text_box.submit(**transfer_input_args).then(**predict_args),
+            submitBtn.click(**transfer_input_args).then(**predict_args),
+        ]
+
+        emptyBtn.click(reset_state, outputs=output_widgets, show_progress=True)
+        emptyBtn.click(**reset_args)
+        retryBtn.click(**retry_args)
+
+        delLastBtn.click(
+            delete_last_conversation,
+            [chatbot, history],
+            output_widgets,
+            show_progress=True,
+        )
+
+        cancelBtn.click(cancel_outputing, [], [status_display], cancels=predict_events)
+
+    return demo
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser()
+    parser.add_argument("--model_name", type=str, required=True, choices=MODELS, help="model name")
+    parser.add_argument("--local_path", type=str, default="", help="huggingface ckpt, optional")
+    parser.add_argument("--ip", type=str, default="0.0.0.0", help="ip address")
+    parser.add_argument("--port", type=int, default=37913, help="port number")
+    parser.add_argument("--root_path", type=str, default="", help="root path")
+    parser.add_argument("--lazy_load", action='store_true')
+    parser.add_argument("--chunk_size", type=int, default=-1,
+                        help="Chunk size for model prefilling. "
+                             "When using a 40G GPU for vl2-small, set chunk_size to enable incremental prefilling. "
+                             "The default of -1 disables incremental prefilling.")
+    args = parser.parse_args()
+
+    demo = build_demo(args)
+    demo.title = "DeepSeek-VL2 Chatbot"
+
+    reload_javascript()
+    demo.queue(concurrency_count=CONCURRENT_COUNT, max_size=MAX_EVENTS).launch(
+        # share=False,
+        share=True,
+        favicon_path="deepseek_vl2/serve/assets/favicon.ico",
+        inbrowser=False,
+        server_name=args.ip,
+        server_port=args.port,
+        root_path=args.root_path
+    )
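+
+# Example launch commands (values are illustrative; the flags are defined above):
+#   python web_demo.py --model_name deepseek-ai/deepseek-vl2-tiny
+# On a 40GB GPU with deepseek-vl2-small, set --chunk_size to enable
+# incremental prefilling, e.g.:
+#   python web_demo.py --model_name deepseek-ai/deepseek-vl2-small --chunk_size 512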