Mirror of https://github.com/deepseek-ai/DeepSeek-VL.git (synced 2025-04-19 10:09:09 -04:00)
Version 1.0.0

This commit is contained in:
parent e24fd228ab
commit b3e7107168

README.md: 103 lines changed
@@ -72,16 +72,21 @@ Introducing DeepSeek VL,

 ## 2. Model Downloads

-We release the DeepSeek LLM 7B/67B, including both base and chat models, to the public. To support a broader and more diverse range of research within both academic and commercial communities, we are providing access to the intermediate checkpoints of the base model from its training process. Please **note** that the use of this model is subject to the terms outlined in [License section](#8-license). Commercial usage is permitted under these terms.
+We release the DeepSeek-VL family, including the 1.3B-base, 1.3B-chat, 7B-base and 7B-chat models, to the public,
+to support a broader and more diverse range of research within both academic and commercial communities.
+Please note that the use of this model is subject to the terms outlined in the [License section](#8-license). Commercial usage is
+permitted under these terms.

 ### Huggingface

-| Model               | Sequence Length | Download |
-|:-------------------:|:---------------:|:--------:|
-| DeepSeek VL 7B Base | 4096            | 🤗 [HuggingFace](https://huggingface.co/deepseek-ai/deepseek-llm-7b-base) |
-| DeepSeek VL 7B Chat | 4096            | 🤗 [HuggingFace](https://huggingface.co/deepseek-ai/deepseek-llm-7b-chat) |
+| Model                 | Sequence Length | Download |
+|-----------------------|-----------------|----------|
+| DeepSeek-VL-1.3B-base | 4096            | [🤗 Hugging Face](https://huggingface.co/deepseek-ai/deepseek-vl-1.3b-base) |
+| DeepSeek-VL-1.3B-chat | 4096            | [🤗 Hugging Face](https://huggingface.co/deepseek-ai/deepseek-vl-1.3b-chat) |
+| DeepSeek-VL-7B-base   | 4096            | [🤗 Hugging Face](https://huggingface.co/deepseek-ai/deepseek-vl-7b-base) |
+| DeepSeek-VL-7B-chat   | 4096            | [🤗 Hugging Face](https://huggingface.co/deepseek-ai/deepseek-vl-7b-chat) |

-## 3. Evaluation Results
+# 3. Evaluation Results

 ### Base Model
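For reference alongside the table above: a minimal, hedged sketch of fetching one of the listed checkpoints locally, assuming the `huggingface_hub` package is available (not something this commit itself adds).

```python
# Hedged sketch: download a checkpoint from the table above with huggingface_hub.
# Assumption: `pip install huggingface_hub` has been run; any repo id from the table works.
from huggingface_hub import snapshot_download

local_dir = snapshot_download(
    repo_id="deepseek-ai/deepseek-vl-7b-chat",
    local_dir="./deepseek-vl-7b-chat",
)
print(local_dir)
```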
@@ -139,53 +144,73 @@ We release the training loss curve and several benchmark metrics curves, as detailed below.

 On the basis of a `Python >= 3.8` environment, install the necessary dependencies by running the following command:

 ```shell
-pip install -r requirements.txt
+pip install -r requirements.txt -e .
 ```

 ### Inference with Huggingface's Transformers

 You can directly employ [Huggingface's Transformers](https://github.com/huggingface/transformers) for model inference.

-**Text Completion**
+**Simple Inference Example**

 ```python
 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
-
-model_name = "deepseek-ai/deepseek-llm-67b-base"
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
-model.generation_config = GenerationConfig.from_pretrained(model_name)
-model.generation_config.pad_token_id = model.generation_config.eos_token_id
-
-text = "An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is"
-inputs = tokenizer(text, return_tensors="pt")
-outputs = model.generate(**inputs.to(model.device), max_new_tokens=100)
-
-result = tokenizer.decode(outputs[0], skip_special_tokens=True)
-print(result)
+from transformers import AutoModelForCausalLM
+
+from deepseek_vlm.models import VLChatProcessor, MultiModalityCausalLM
+from deepseek_vlm.utils.io import load_pil_images
+
+# specify the path to the model
+model_path = "deepseek-ai/deepseek-vl-7b-chat"
+vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
+tokenizer = vl_chat_processor.tokenizer
+
+vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
+vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()
+
+conversation = [
+    {
+        "role": "User",
+        "content": "<image_placeholder>Describe each stage of this image.",
+        "images": ["./images/training_pipelines.png"]
+    },
+    {
+        "role": "Assistant",
+        "content": ""
+    }
+]
+
+# load images and prepare for inputs
+pil_images = load_pil_images(conversation)
+prepare_inputs = vl_chat_processor(
+    conversations=conversation,
+    images=pil_images,
+    force_batchify=True
+).to(vl_gpt.device)
+
+# run image encoder to get the image embeddings
+inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
+
+# run the model to get the response
+outputs = vl_gpt.language_model.generate(
+    inputs_embeds=inputs_embeds,
+    attention_mask=prepare_inputs.attention_mask,
+    pad_token_id=tokenizer.eos_token_id,
+    bos_token_id=tokenizer.bos_token_id,
+    eos_token_id=tokenizer.eos_token_id,
+    max_new_tokens=512,
+    do_sample=False,
+    use_cache=True
+)
+
+answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
+print(f"{prepare_inputs['sft_format'][0]}", answer)
 ```

-**Chat Completion**
-
-```python
-import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
-
-model_name = "deepseek-ai/deepseek-llm-67b-chat"
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
-model.generation_config = GenerationConfig.from_pretrained(model_name)
-model.generation_config.pad_token_id = model.generation_config.eos_token_id
-
-messages = [
-    {"role": "user", "content": "Who are you?"}
-]
-input_tensor = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
-outputs = model.generate(input_tensor.to(model.device), max_new_tokens=100)
-
-result = tokenizer.decode(outputs[0][input_tensor.shape[1]:], skip_special_tokens=True)
-print(result)
-```
-
-Avoiding the use of the provided function `apply_chat_template`, you can also interact with our model following the sample template. Note that `messages` should be replaced by your input.
+**CLI Chat**
+
+```bash
+python cli_chat.py --model_path deepseek-ai/deepseek-vl-7b-chat
+```
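The CLI script introduced by this commit (`cli_chat.py`, below) builds its generation config from `--temperature`: sampling when it is positive, greedy decoding otherwise. A standalone, hedged sketch of that switch, assuming the script's default values:

```python
# Sketch of the generation-config switch used by cli_chat.py below.
# Assumption: the default values mirror the script's argparse defaults; this helper
# itself is illustrative and not part of the commit.
def build_generation_config(temperature: float = 0.2, top_p: float = 0.95,
                            repetition_penalty: float = 1.1, max_gen_len: int = 512,
                            eos_token_id: int = 2, bos_token_id: int = 1) -> dict:
    config = dict(
        pad_token_id=eos_token_id,
        bos_token_id=bos_token_id,
        eos_token_id=eos_token_id,
        max_new_tokens=max_gen_len,
        use_cache=True,
    )
    if temperature > 0:
        # sampled decoding
        config.update(do_sample=True, top_p=top_p, temperature=temperature,
                      repetition_penalty=repetition_penalty)
    else:
        # greedy decoding
        config.update(do_sample=False)
    return config


print(build_generation_config(temperature=0.0))  # greedy decoding
```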
cli_chat.py (new file, 194 lines)
@@ -0,0 +1,194 @@

```python
# -*- coding: utf-8 -*-

import argparse
import os
import sys
from PIL import Image
import readline  # noqa: F401  (imported for line-editing support in input())
from threading import Thread
import torch
from transformers import TextIteratorStreamer
from deepseek_vlm.utils.io import load_pretrained_model


def load_image(image_file):
    image = Image.open(image_file).convert("RGB")
    return image


def get_help_message(image_token):
    help_msg = (
        f"\t\t DeepSeek-VL-Chat is a chatbot that can answer questions based on the given image. Enjoy it! \n"
        f"Usage: \n"
        f"    1. type `exit` to quit. \n"
        f"    2. type `{image_token}` to indicate there is an image. You can enter multiple images, "
        f"e.g '{image_token} is a dot, {image_token} is a cat, and what is it in {image_token}?'. "
        f"When you type `{image_token}`, the chatbot will ask you to input image file path. \n"
        f"    3. type `help` to get the help messages. \n"
        f"    4. type `new` to start a new conversation. \n"
        f"    Here is an example, you can type: '<image_placeholder>Describe the image.'\n"
    )

    return help_msg


@torch.inference_mode()
def response(args, conv, pil_images, tokenizer, vl_chat_processor, vl_gpt, generation_config):

    prompt = conv.get_prompt()
    prepare_inputs = vl_chat_processor.__call__(
        prompt=prompt,
        images=pil_images,
        force_batchify=True
    ).to(vl_gpt.device)

    # run image encoder to get the image embeddings
    inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)

    streamer = TextIteratorStreamer(
        tokenizer=tokenizer,
        skip_prompt=True,
        skip_special_tokens=True
    )
    generation_config["inputs_embeds"] = inputs_embeds
    generation_config["attention_mask"] = prepare_inputs.attention_mask
    generation_config["streamer"] = streamer

    thread = Thread(target=vl_gpt.language_model.generate, kwargs=generation_config)
    thread.start()

    yield from streamer


def get_user_input(hint: str):
    user_input = ""
    while user_input == "":
        try:
            user_input = input(f"{hint}")
        except KeyboardInterrupt:
            print()
            continue
        except EOFError:
            user_input = "exit"

    return user_input


def chat(args, tokenizer, vl_chat_processor, vl_gpt, generation_config):
    image_token = vl_chat_processor.image_token
    help_msg = get_help_message(image_token)

    while True:

        print(help_msg)

        pil_images = []
        conv = vl_chat_processor.new_chat_template()
        roles = conv.roles

        while True:

            # get user input
            user_input = get_user_input(f"{roles[0]} [{image_token} indicates an image]: ")

            if user_input == "exit":
                print("Chat program exited.")
                sys.exit(0)

            elif user_input == "help":
                print(help_msg)

            elif user_input == "new":
                os.system("clear")
                pil_images = []
                conv = vl_chat_processor.new_chat_template()
                torch.cuda.empty_cache()
                print("New conversation started.")

            else:
                conv.append_message(conv.roles[0], user_input)
                conv.append_message(conv.roles[1], None)

                # check if the user input is an image token
                num_images = user_input.count(image_token)
                cur_img_idx = 0

                while cur_img_idx < num_images:
                    try:
                        image_file = input(f"({cur_img_idx + 1}/{num_images}) Input the image file path: ")

                    except KeyboardInterrupt:
                        print()
                        continue

                    except EOFError:
                        image_file = None

                    if image_file and os.path.exists(image_file):
                        pil_image = load_image(image_file)
                        pil_images.append(pil_image)
                        cur_img_idx += 1

                    elif image_file == "exit":
                        print("Chat program exited.")
                        sys.exit(0)

                    else:
                        print(f"File error, `{image_file}` does not exist. Please input the correct file path.")

                # get the answer by the model's prediction
                answer = ""
                answer_iter = response(args, conv, pil_images, tokenizer, vl_chat_processor, vl_gpt, generation_config)
                sys.stdout.write(f"{conv.roles[1]}: ")
                for char in answer_iter:
                    answer += char
                    sys.stdout.write(char)
                    sys.stdout.flush()

                sys.stdout.write("\n")
                sys.stdout.flush()
                conv.messages[-1][-1] = answer


def main(args):

    # setup
    tokenizer, vl_chat_processor, vl_gpt = load_pretrained_model(args.model_path)
    generation_config = dict(
        pad_token_id=vl_chat_processor.tokenizer.eos_token_id,
        bos_token_id=vl_chat_processor.tokenizer.bos_token_id,
        eos_token_id=vl_chat_processor.tokenizer.eos_token_id,
        max_new_tokens=args.max_gen_len,
        use_cache=True,
    )
    if args.temperature > 0:
        generation_config.update({
            "do_sample": True,
            "top_p": args.top_p,
            "temperature": args.temperature,
            "repetition_penalty": args.repetition_penalty,
        })
    else:
        generation_config.update({"do_sample": False})

    chat(args, tokenizer, vl_chat_processor, vl_gpt, generation_config)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default="deepseek-ai/deepseek-vl-7b-chat",
                        help="the huggingface model name or the local path of the downloaded huggingface model.")
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--top_p", type=float, default=0.95)
    parser.add_argument("--repetition_penalty", type=float, default=1.1)
    parser.add_argument("--max_gen_len", type=int, default=512)
    args = parser.parse_args()
    main(args)


"""

CUDA_VISIBLE_DEVICES=4 python cli.py --model_path "/home/liuwen/3fs_shared/ckpts/deepseek-vl-7b-chat"

CUDA_VISIBLE_DEVICES=2 python cli.py --model_path "/home/liuwen/3fs_shared/ckpts/siglip_1B_HF"

"""
```
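The `response()` generator above streams tokens by running `generate` on a worker thread and iterating a `TextIteratorStreamer`. A self-contained sketch of the same pattern with a small text-only model; the `"gpt2"` checkpoint is only a stand-in assumption, not something this commit uses:

```python
# Minimal streaming sketch mirroring the Thread + TextIteratorStreamer pattern in
# cli_chat.py's response(). Assumption: "gpt2" is used purely as a small placeholder model.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("DeepSeek-VL is", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

thread = Thread(
    target=model.generate,
    kwargs=dict(**inputs, streamer=streamer, max_new_tokens=32, do_sample=False),
)
thread.start()

for piece in streamer:          # pieces arrive as generate() produces them
    print(piece, end="", flush=True)
print()
thread.join()
```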
deepseek_vlm/__init__.py (new file, empty)

deepseek_vlm/models/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@

```python
from .image_processing_vlm import VLMImageProcessor
from .processing_vlm import VLChatProcessor
from .modeling_vlm import MultiModalityCausalLM
```
deepseek_vlm/models/clip_encoder.py (new file, 213 lines)
@@ -0,0 +1,213 @@

```python
from typing import Tuple, Union, List, Dict, Optional, Literal

import torch
import torch.nn as nn
import torchvision.transforms
from einops import rearrange

from deepseek_vlm.models.siglip_vit import create_siglip_vit
from deepseek_vlm.models.sam import create_sam_vit


class CLIPVisionTower(nn.Module):

    def __init__(self,
                 model_name: str = "siglip_large_patch16_384",
                 image_size: Union[Tuple[int, int], int] = 336,
                 select_feature: str = "patch",
                 select_layer: int = -2,
                 select_layers: list = None,
                 ckpt_path: str = "",
                 pixel_mean: Optional[List[float]] = None,
                 pixel_std: Optional[List[float]] = None,
                 **kwargs):
        super().__init__()

        self.model_name = model_name
        self.select_feature = select_feature
        self.select_layer = select_layer
        self.select_layers = select_layers

        vision_tower_params = {
            "model_name": model_name,
            "image_size": image_size,
            "ckpt_path": ckpt_path,
            "select_layer": select_layer
        }
        vision_tower_params.update(kwargs)
        self.vision_tower, self.forward_kwargs = self.build_vision_tower(vision_tower_params)

        if pixel_mean is not None and pixel_std is not None:
            image_norm = torchvision.transforms.Normalize(mean=pixel_mean, std=pixel_std)
        else:
            image_norm = None

        self.image_norm = image_norm

    def build_vision_tower(self, vision_tower_params):

        if self.model_name.startswith("siglip"):
            self.select_feature = "same"
            vision_tower = create_siglip_vit(**vision_tower_params)
            forward_kwargs = dict()

        elif self.model_name.startswith("sam"):
            vision_tower = create_sam_vit(**vision_tower_params)
            forward_kwargs = dict()

        else:  # huggingface
            from transformers import CLIPVisionModel
            vision_tower = CLIPVisionModel.from_pretrained(**vision_tower_params)
            forward_kwargs = dict(output_hidden_states=True)

        return vision_tower, forward_kwargs

    def feature_select(self, image_forward_outs):

        if isinstance(image_forward_outs, torch.Tensor):
            # the output is already the self.select_layer's features
            image_features = image_forward_outs
        else:
            image_features = image_forward_outs.hidden_states[self.select_layer]

        if self.select_feature == "patch":
            # if the output has a cls_token, drop it
            image_features = image_features[:, 1:]
        elif self.select_feature == "cls_patch":
            image_features = image_features
        elif self.select_feature == "same":
            image_features = image_features

        else:
            raise ValueError(f"Unexpected select feature: {self.select_feature}")
        return image_features

    def forward(self, images):
        """

        Args:
            images (torch.Tensor): [b, 3, H, W]

        Returns:
            image_features (torch.Tensor): [b, n_patch, d]
        """

        if self.image_norm is not None:
            images = self.image_norm(images)

        image_forward_outs = self.vision_tower(images, **self.forward_kwargs)
        image_features = self.feature_select(image_forward_outs)
        return image_features


class HybridVisionTower(nn.Module):

    def __init__(self,
                 high_res_cfg: Dict,
                 low_res_cfg: Dict,
                 freeze_high: bool = False,
                 freeze_low: bool = False,
                 concat_type: Literal["feature", "sequence", "add", "tuple"] = "tuple",
                 **ignore_kwargs):

        super().__init__()

        self.vision_tower_high = CLIPVisionTower(**high_res_cfg)
        self.vision_tower_low = CLIPVisionTower(**low_res_cfg)
        self.low_res_size = low_res_cfg["image_size"]
        self.concat_type = concat_type

        self.high_layer_norm = nn.LayerNorm(high_res_cfg.get("output_dim", 1024))
        self.low_layer_norm = nn.LayerNorm(low_res_cfg.get("output_dim", 1024))

        if freeze_high:
            for p_name, p in self.vision_tower_high.named_parameters():
                p.requires_grad = False
            self.vision_tower_high = self.vision_tower_high.eval()
        else:
            # train downsamples and neck only
            for p_name, p in self.vision_tower_high.named_parameters():
                if "downsamples" in p_name or "neck" in p_name:
                    p.requires_grad = True
                else:
                    p.requires_grad = False

        if freeze_low:
            for p in self.vision_tower_low.parameters():
                p.requires_grad = False
            self.vision_tower_low = self.vision_tower_low.eval()

        self.resize = torchvision.transforms.Resize(self.low_res_size, antialias=True)

    def forward(self, images: torch.Tensor):
        """

        Args:
            images (torch.Tensor): [bs, 3, H, W]

        Returns:
            res (torch.Tensor): [bs, t, c]
        """

        # [bs, c, h, w]
        high_images = images

        # [bs, c, h_low, w_low]
        low_images = self.resize(images)

        # separately run the two vision towers
        # run the high_res vision tower
        high_res = self.vision_tower_high(high_images)
        # [bs, c, h, w] -> [bs, h*w, c]
        high_res = rearrange(high_res, "b c h w -> b (h w) c")
        # run the low_res vision tower
        low_res = self.vision_tower_low(low_images)

        if self.concat_type == "feature":
            images_features = torch.cat([high_res, low_res], dim=-1)
        elif self.concat_type == "sequence":
            images_features = torch.cat([high_res, low_res], dim=1)
        elif self.concat_type == "add":
            images_features = high_res + low_res
        elif self.concat_type == "tuple":
            images_features = (high_res, low_res)

        else:
            raise ValueError("Currently only support `feature`, `sequence`, `add` and `tuple` concat type.")

        return images_features


if __name__ == "__main__":
    image_size = 1024
    x = torch.zeros(2, 3, image_size, image_size).bfloat16().cuda()

    high_res_cfg = dict(
        model_name="sam_b_downsample",
        select_feature="same",
        image_size=image_size,
        pixel_mean=(0.48145466, 0.4578275, 0.40821073),
        pixel_std=(0.26862954, 0.26130258, 0.27577711),
        select_layer=-1,
        ckpt_path=""
    )

    low_res_cfg = dict(
        model_name="siglip_large_patch16_384",
        select_feature="same",
        image_size=384,
        pixel_mean=(0.5, 0.5, 0.5),
        pixel_std=(0.5, 0.5, 0.5),
        select_layer=-1,
        ckpt_path=""
    )

    net = HybridVisionTower(
        high_res_cfg=high_res_cfg,
        low_res_cfg=low_res_cfg,
        freeze_high=True,
        freeze_low=True,
        concat_type="tuple"
    ).bfloat16().cuda()
    high_x, low_x = net(x)
    print(x.shape, high_x.shape, low_x.shape)
```
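A hedged, shape-only illustration of the four `concat_type` options handled in `HybridVisionTower.forward` above, using random tensors in place of the tower outputs (sizes are assumptions for illustration only):

```python
# Shape-only illustration of HybridVisionTower's concat_type options, using random
# tensors in place of the two towers' outputs (dimensions are made up).
import torch

high_res = torch.randn(2, 576, 1024)   # [bs, tokens, dim] standing in for the high-res tower
low_res = torch.randn(2, 576, 1024)    # [bs, tokens, dim] standing in for the low-res tower

feature = torch.cat([high_res, low_res], dim=-1)   # [2, 576, 2048]: concat along channels
sequence = torch.cat([high_res, low_res], dim=1)   # [2, 1152, 1024]: concat along tokens
added = high_res + low_res                         # [2, 576, 1024]: element-wise sum
as_tuple = (high_res, low_res)                     # kept separate for the hybrid projector

print(feature.shape, sequence.shape, added.shape, as_tuple[0].shape)
```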
deepseek_vlm/models/image_processing_vlm.py (new file, 163 lines)
@@ -0,0 +1,163 @@

```python
from PIL import Image
import numpy as np
import torch
import torchvision
import torchvision.transforms.functional
from typing import List, Union, Tuple
from transformers import PretrainedConfig, AutoImageProcessor
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_utils import to_numpy_array
from transformers.utils import logging

logger = logging.get_logger(__name__)

ImageType = Union[np.ndarray, torch.Tensor, Image.Image]
IMAGENET_MEAN = (0.48145466, 0.4578275, 0.40821073)
IMAGENET_STD = (0.26862954, 0.26130258, 0.27577711)
IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)


def expand2square(pil_img, background_color):
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result


class VLMImageProcessorConfig(PretrainedConfig):

    model_type = "deepseek_vlm"
    image_size: int
    min_size: int
    image_mean: Union[Tuple[float, float, float], List[float]]
    image_std: Union[Tuple[float, float, float], List[float]]
    rescale_factor: float
    do_normalize: bool

    def __init__(
        self,
        image_size: int,
        min_size: int = 14,
        image_mean: Union[Tuple[float, float, float], List[float]] = (0.48145466, 0.4578275, 0.40821073),
        image_std: Union[Tuple[float, float, float], List[float]] = (0.26862954, 0.26130258, 0.27577711),
        rescale_factor: float = 1.0 / 255.0,
        do_normalize: bool = True, **kwargs
    ):

        self.image_size = image_size
        self.min_size = min_size
        self.image_mean = image_mean
        self.image_std = image_std
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize

        super().__init__(**kwargs)


class VLMImageProcessor(BaseImageProcessor):
    model_input_names = ["pixel_values"]

    def __init__(
        self,
        image_size: int,
        min_size: int = 14,
        image_mean: Union[Tuple[float, float, float], List[float]] = (0.48145466, 0.4578275, 0.40821073),
        image_std: Union[Tuple[float, float, float], List[float]] = (0.26862954, 0.26130258, 0.27577711),
        rescale_factor: float = 1.0 / 255.0,
        do_normalize: bool = True, **kwargs
    ):
        super().__init__(**kwargs)

        self.image_size = image_size
        self.rescale_factor = rescale_factor
        self.image_mean = image_mean
        self.image_std = image_std
        self.min_size = min_size
        self.do_normalize = do_normalize

        if image_mean is None:
            self.background_color = (127, 127, 127)
        else:
            self.background_color = tuple([int(x * 255) for x in image_mean])

    def resize(self, pil_img: Image) -> np.ndarray:
        """

        Args:
            pil_img (PIL.Image): [H, W, 3] in PIL.Image in RGB

        Returns:
            x (np.ndarray): [3, self.image_size, self.image_size]
        """

        width, height = pil_img.size
        max_size = max(width, height)

        size = [
            max(int(height / max_size * self.image_size), self.min_size),
            max(int(width / max_size * self.image_size), self.min_size)
        ]

        if width <= 0 or height <= 0 or size[0] <= 0 or size[1] <= 0:
            print(f"orig size = {pil_img.size}, new size = {size}")
            raise ValueError("Invalid size!")

        pil_img = torchvision.transforms.functional.resize(
            pil_img, size, interpolation=torchvision.transforms.functional.InterpolationMode.BICUBIC, antialias=True
        )

        pil_img = expand2square(pil_img, self.background_color)
        x = to_numpy_array(pil_img)

        # [H, W, 3] -> [3, H, W]
        x = np.transpose(x, (2, 0, 1))

        return x

    def preprocess(self, images, return_tensors: str = "pt", **kwargs) -> BatchFeature:

        # resize and pad to [self.image_size, self.image_size]
        # then convert from [H, W, 3] to [3, H, W]
        images: List[np.ndarray] = [self.resize(image) for image in images]

        # rescale from [0, 255] -> [0, 1]
        images = [
            self.rescale(image=image, scale=self.rescale_factor, input_data_format="channels_first")
            for image in images
        ]

        # normalize
        if self.do_normalize:
            images = [
                self.normalize(image=image, mean=self.image_mean, std=self.image_std,
                               input_data_format="channels_first")
                for image in images
            ]

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)

    @property
    def default_shape(self):
        return [3, self.image_size, self.image_size]


# AutoConfig.register("deepseek_vlm", VLMImageProcessorConfig)
AutoImageProcessor.register(VLMImageProcessorConfig, VLMImageProcessor)


if __name__ == "__main__":
    image_processor = VLMImageProcessor(
        image_size=1024,
        image_mean=IMAGENET_INCEPTION_MEAN,
        image_std=IMAGENET_INCEPTION_STD,
        do_normalize=True
    )
```
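`VLMImageProcessor.resize` above scales the longer side to `image_size` and then pads to a square with `expand2square`, filling with the mean color. A small, hedged demo of just the padding step, assuming the package is importable (e.g. after `pip install -e .`):

```python
# Demo of the square-padding behaviour used by VLMImageProcessor.resize.
# Assumption: deepseek_vlm is importable (for example after `pip install -e .`).
from PIL import Image

from deepseek_vlm.models.image_processing_vlm import expand2square

wide = Image.new("RGB", (200, 100), (255, 0, 0))   # a 200x100 red image
padded = expand2square(wide, (127, 127, 127))      # pad top/bottom with grey

print(wide.size, "->", padded.size)                # (200, 100) -> (200, 200)
```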
deepseek_vlm/models/modeling_vlm.py (new file, 150 lines)
@@ -0,0 +1,150 @@

```python
from attrdict import AttrDict
from einops import rearrange
import torch
from transformers.configuration_utils import PretrainedConfig
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    PreTrainedModel,
    LlamaConfig,
    LlamaForCausalLM
)

from deepseek_vlm.models.projector import MlpProjector
from deepseek_vlm.models.clip_encoder import CLIPVisionTower, HybridVisionTower


def model_name_to_cls(cls_name):

    if "MlpProjector" in cls_name:
        cls = MlpProjector

    elif "CLIPVisionTower" in cls_name:
        cls = CLIPVisionTower

    elif "HybridVisionTower" in cls_name:
        cls = HybridVisionTower

    else:
        raise ValueError(f"class_name {cls_name} is invalid.")

    return cls


class VisionConfig(PretrainedConfig):
    model_type = "vision"
    cls: str = ""
    params: AttrDict = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.cls = kwargs.get("cls", "")
        if not isinstance(self.cls, str):
            self.cls = self.cls.__name__

        self.params = AttrDict(kwargs.get("params", {}))


class AlignerConfig(PretrainedConfig):
    model_type = "aligner"
    cls: str = ""
    params: AttrDict = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.cls = kwargs.get("cls", "")
        if not isinstance(self.cls, str):
            self.cls = self.cls.__name__

        self.params = AttrDict(kwargs.get("params", {}))


class MultiModalityConfig(PretrainedConfig):
    model_type = "multi_modality"
    vision_config: VisionConfig
    aligner_config: AlignerConfig
    language_config: LlamaConfig

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        vision_config = kwargs.get("vision_config", {})
        self.vision_config = VisionConfig(**vision_config)

        aligner_config = kwargs.get("aligner_config", {})
        self.aligner_config = AlignerConfig(**aligner_config)

        language_config = kwargs.get("language_config", {})
        if isinstance(language_config, LlamaConfig):
            self.language_config = language_config
        else:
            self.language_config = LlamaConfig(**language_config)


class MultiModalityPreTrainedModel(PreTrainedModel):
    config_class = MultiModalityConfig
    base_model_prefix = "multi_modality"
    _no_split_modules = []
    _skip_keys_device_placement = "past_key_values"


class MultiModalityCausalLM(MultiModalityPreTrainedModel):

    def __init__(self, config: MultiModalityConfig):
        super().__init__(config)

        vision_config = config.vision_config
        vision_cls = model_name_to_cls(vision_config.cls)
        self.vision_model = vision_cls(**vision_config.params)

        aligner_config = config.aligner_config
        aligner_cls = model_name_to_cls(aligner_config.cls)
        self.aligner = aligner_cls(aligner_config.params)

        language_config = config.language_config
        self.language_model = LlamaForCausalLM(language_config)

    def prepare_inputs_embeds(self,
                              input_ids: torch.LongTensor,
                              pixel_values: torch.FloatTensor,
                              images_seq_mask: torch.LongTensor,
                              images_emb_mask: torch.LongTensor, **kwargs):
        """

        Args:
            input_ids (torch.LongTensor): [b, T]
            pixel_values (torch.FloatTensor): [b, n_images, 3, h, w]
            images_seq_mask (torch.BoolTensor): [b, T]
            images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens]

            assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask)

        Returns:
            input_embeds (torch.Tensor): [b, T, D]
        """

        bs, n = pixel_values.shape[0:2]
        images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
        # [b x n, T2, D]
        images_embeds = self.aligner(self.vision_model(images))

        # [b x n, T2, D] -> [b, n x T2, D]
        images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
        # [b, n, T2] -> [b, n x T2]
        images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")

        # [b, T, D]
        input_ids[input_ids < 0] = 0  # ignore the image embeddings
        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)

        # replace with the image embeddings
        inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]

        return inputs_embeds


AutoConfig.register("vision", VisionConfig)
AutoConfig.register("aligner", AlignerConfig)
AutoConfig.register("multi_modality", MultiModalityConfig)
AutoModelForCausalLM.register(MultiModalityConfig, MultiModalityCausalLM)
```
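`prepare_inputs_embeds` above splices the aligned image embeddings into the text embedding sequence wherever `images_seq_mask` is set. A toy, hedged illustration of that masked assignment with made-up shapes:

```python
# Toy illustration of the masked splice in prepare_inputs_embeds (all shapes are made up).
import torch

b, T, D = 1, 6, 4            # batch, sequence length, embedding dim
n_image_tokens = 2

inputs_embeds = torch.zeros(b, T, D)              # text embeddings (zeros for clarity)
images_embeds = torch.ones(b, n_image_tokens, D)  # embeddings produced by the aligner

images_seq_mask = torch.tensor([[False, True, True, False, False, False]])  # image slots in the sequence
images_emb_mask = torch.ones(b, n_image_tokens, dtype=torch.bool)           # valid image embeddings

inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]
print(inputs_embeds[0, :, 0])   # tensor([0., 1., 1., 0., 0., 0.])
```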
deepseek_vlm/models/processing_vlm.py (new file, 351 lines)
@@ -0,0 +1,351 @@

```python
from dataclasses import dataclass

import numpy as np
from PIL.Image import Image
from typing import List, Dict, Union
import torch

from transformers import AutoTokenizer, AutoImageProcessor
from transformers.processing_utils import ProcessorMixin
from transformers import LlamaTokenizerFast

from deepseek_vlm.models.image_processing_vlm import VLMImageProcessor
from deepseek_vlm.utils.conversation import get_conv_template


class DictOutput(object):

    def keys(self):
        return self.__dict__.keys()

    def __getitem__(self, item):
        return self.__dict__[item]

    def __setitem__(self, key, value):
        self.__dict__[key] = value


@dataclass
class VLChatProcessorOutput(DictOutput):
    sft_format: str
    input_ids: torch.Tensor
    pixel_values: torch.Tensor
    num_image_tokens: torch.IntTensor

    def __len__(self):
        return len(self.input_ids)


@dataclass
class BatchedVLChatProcessorOutput(DictOutput):
    sft_format: List[str]
    input_ids: torch.Tensor
    pixel_values: torch.Tensor
    attention_mask: torch.Tensor
    images_seq_mask: torch.BoolTensor
    images_emb_mask: torch.BoolTensor

    def to(self, device, dtype=torch.bfloat16):
        self.input_ids = self.input_ids.to(device)
        self.attention_mask = self.attention_mask.to(device)
        self.images_seq_mask = self.images_seq_mask.to(device)
        self.images_emb_mask = self.images_emb_mask.to(device)
        self.pixel_values = self.pixel_values.to(device=device, dtype=dtype)
        return self


class VLChatProcessor(ProcessorMixin):

    image_processor_class = "AutoImageProcessor"
    tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")

    attributes = ["image_processor", "tokenizer"]

    system_prompt = ("You are a helpful language and vision assistant. "
                     "You are able to understand the visual content that the user provides, "
                     "and assist the user with a variety of tasks using natural language.")

    def __init__(
        self,
        image_processor: VLMImageProcessor,
        tokenizer: LlamaTokenizerFast,
        image_tag: str = "<image_placeholder>",
        num_image_tokens: int = 576,
        add_special_token: bool = False,
        sft_format: str = "deepseek",
        mask_prompt: bool = True,
        ignore_id: int = -100, **kwargs
    ):

        self.image_processor = image_processor
        self.tokenizer = tokenizer

        image_id = self.tokenizer.vocab.get(image_tag)
        if image_id is None:
            special_tokens = [image_tag]
            special_tokens_dict = {"additional_special_tokens": special_tokens}
            self.tokenizer.add_special_tokens(special_tokens_dict)
            print(f"Add image tag = {image_tag} to the tokenizer")

        self.image_tag = image_tag
        self.num_image_tokens = num_image_tokens
        self.add_special_token = add_special_token
        self.sft_format = sft_format
        self.mask_prompt = mask_prompt
        self.ignore_id = ignore_id

        super().__init__(image_processor, tokenizer, image_tag, num_image_tokens, add_special_token,
                         sft_format, mask_prompt, ignore_id, **kwargs)

    def new_chat_template(self):
        conv = get_conv_template(self.sft_format)
        conv.set_system_message(self.system_prompt)
        return conv

    def apply_sft_template_for_multi_turn_prompts(
        self,
        conversations: List[Dict[str, str]],
        sft_format: str = "deepseek",
        system_prompt: str = ""
    ):
        """
        Applies the SFT template to conversation.

        An example of conversation:
        conversation = [
            {
                "role": "User",
                "content": "<image_placeholder> is Figure 1.\n<image_placeholder> is Figure 2.\nWhich image is brighter?",
                "images": [
                    "./multi-images/attribute_comparison_1.png",
                    "./multi-images/attribute_comparison_2.png"
                ]
            },
            {
                "role": "Assistant",
                "content": ""
            }
        ]

        Args:
            conversations (List[Dict]): A conversation with a List of Dict[str, str] text.
            sft_format (str, optional): The format of the SFT template to use. Defaults to "deepseek".
            system_prompt (str, optional): The system prompt to use in the SFT template. Defaults to "".

        Returns:
            sft_prompt (str): The formatted text.
        """

        conv = get_conv_template(sft_format)
        conv.set_system_message(system_prompt)
        for message in conversations:
            conv.append_message(message["role"], message["content"].strip())
        sft_prompt = conv.get_prompt().strip()

        return sft_prompt

    @property
    def image_token(self):
        return self.image_tag

    @property
    def image_id(self):
        image_id = self.tokenizer.vocab.get(self.image_tag)
        return image_id

    @property
    def pad_id(self):
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = self.tokenizer.eos_token_id

        return pad_id

    def add_image_token(
        self,
        image_indices: List[int],
        input_ids: torch.LongTensor,
    ):
        """

        Args:
            image_indices (List[int]): [index_0, index_1, ..., index_j]
            input_ids (torch.LongTensor): [N]

        Returns:
            input_ids (torch.LongTensor): [N + image tokens]
            num_image_tokens (torch.IntTensor): [n_images]
        """

        input_slices = []

        start = 0
        for index in image_indices:
            if self.add_special_token:
                end = index + 1
            else:
                end = index

            # original text tokens
            input_slices.append(input_ids[start: end])

            # add image tokens, and set the mask as False
            input_slices.append(self.image_id * torch.ones((self.num_image_tokens,), dtype=torch.long))
            start = index + 1

        # the left part
        input_slices.append(input_ids[start:])

        # concat all slices
        input_ids = torch.cat(input_slices, dim=0)
        num_image_tokens = torch.IntTensor([self.num_image_tokens] * len(image_indices))

        return input_ids, num_image_tokens

    def process_one(
        self,
        prompt: str = None,
        conversations: List[Dict[str, str]] = None,
        images: List[Image] = None,
        **kwargs
    ):
        """

        Args:
            prompt (str): the formatted prompt;
            conversations (List[Dict]): conversations with a list of messages;
            images (List[ImageType]): the list of images;
            **kwargs:

        Returns:
            outputs (BaseProcessorOutput): the output of the processor,
                - input_ids (torch.LongTensor): [N + image tokens]
                - target_ids (torch.LongTensor): [N + image tokens]
                - images (torch.FloatTensor): [n_images, 3, H, W]
                - image_id (int): the id of the image token
                - num_image_tokens (List[int]): the number of image tokens
        """

        assert prompt is None or conversations is None, "prompt and conversations cannot be used at the same time."

        if prompt is None:
            # apply sft format
            sft_format = self.apply_sft_template_for_multi_turn_prompts(
                conversations=conversations,
                sft_format=self.sft_format,
                system_prompt=self.system_prompt
            )
        else:
            sft_format = prompt

        # tokenize
        input_ids = self.tokenizer.encode(sft_format)
        input_ids = torch.LongTensor(input_ids)

        # add image tokens to the input_ids
        image_token_mask: torch.BoolTensor = input_ids == self.image_id
        image_indices = image_token_mask.nonzero()
        input_ids, num_image_tokens = self.add_image_token(
            image_indices=image_indices,
            input_ids=input_ids,
        )

        # load images
        images_outputs = self.image_processor(images, return_tensors="pt")

        prepare = VLChatProcessorOutput(
            sft_format=sft_format,
            input_ids=input_ids,
            pixel_values=images_outputs.pixel_values,
            num_image_tokens=num_image_tokens
        )

        return prepare

    def __call__(
        self,
        *,
        prompt: str = None,
        conversations: List[Dict[str, str]] = None,
        images: List[Image] = None,
        force_batchify: bool = True,
        **kwargs
    ):
        """

        Args:
            prompt (str): the formatted prompt;
            conversations (List[Dict]): conversations with a list of messages;
            images (List[ImageType]): the list of images;
            force_batchify (bool): force batchify the inputs;
            **kwargs:

        Returns:
            outputs (BaseProcessorOutput): the output of the processor,
                - input_ids (torch.LongTensor): [N + image tokens]
                - images (torch.FloatTensor): [n_images, 3, H, W]
                - image_id (int): the id of the image token
                - num_image_tokens (List[int]): the number of image tokens
        """

        prepare = self.process_one(prompt=prompt, conversations=conversations, images=images)

        if force_batchify:
            prepare = self.batchify([prepare])

        return prepare

    def batchify(self, prepare_list: List[VLChatProcessorOutput]) -> BatchedVLChatProcessorOutput:
        """
        Preprocesses the inputs for multimodal inference.

        Args:
            prepare_list (List[VLChatProcessorOutput]): A list of VLChatProcessorOutput.

        Returns:
            BatchedVLChatProcessorOutput: A dictionary of the inputs to use for multimodal inference.
        """

        batch_size = len(prepare_list)
        sft_format = []
        n_images = []
        seq_lens = []
        for prepare in prepare_list:
            n_images.append(len(prepare.num_image_tokens))
            seq_lens.append(len(prepare))

        input_token_max_len = max(seq_lens)
        max_n_images = max(1, max(n_images))

        batched_input_ids = torch.full((batch_size, input_token_max_len), self.pad_id).long()  # FIXME
        batched_attention_mask = torch.zeros((batch_size, input_token_max_len)).long()
        batched_pixel_values = torch.zeros((batch_size, max_n_images, *self.image_processor.default_shape)).float()
        batched_images_seq_mask = torch.zeros((batch_size, input_token_max_len)).bool()
        batched_images_emb_mask = torch.zeros((batch_size, max_n_images, self.num_image_tokens)).bool()

        for i, prepare in enumerate(prepare_list):
            input_ids = prepare.input_ids
            seq_len = len(prepare)
            n_image = len(prepare.num_image_tokens)
            # left-padding
            batched_attention_mask[i, -seq_len:] = 1
            batched_input_ids[i, -seq_len:] = torch.LongTensor(input_ids)
            batched_images_seq_mask[i, -seq_len:] = input_ids == self.image_id

            if n_image > 0:
                batched_pixel_values[i, :n_image] = prepare.pixel_values
                for j, n_image_tokens in enumerate(prepare.num_image_tokens):
                    batched_images_emb_mask[i, j, :n_image_tokens] = True

            sft_format.append(prepare.sft_format)

        batched_prepares = BatchedVLChatProcessorOutput(
            input_ids=batched_input_ids,
            attention_mask=batched_attention_mask,
            pixel_values=batched_pixel_values,
            images_seq_mask=batched_images_seq_mask,
            images_emb_mask=batched_images_emb_mask,
            sft_format=sft_format
        )

        return batched_prepares
```
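`batchify` above left-pads every sample to the longest sequence in the batch and fills the attention mask accordingly. A toy illustration of the left-padding pattern with made-up token ids:

```python
# Toy illustration of the left-padding used in VLChatProcessor.batchify
# (the pad id and token ids here are made up for the example).
import torch

pad_id = 0
sequences = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]

max_len = max(len(s) for s in sequences)
input_ids = torch.full((len(sequences), max_len), pad_id, dtype=torch.long)
attention_mask = torch.zeros((len(sequences), max_len), dtype=torch.long)

for i, seq in enumerate(sequences):
    input_ids[i, -len(seq):] = seq        # tokens are right-aligned (left-padded)
    attention_mask[i, -len(seq):] = 1

print(input_ids)        # tensor([[5, 6, 7], [0, 8, 9]])
print(attention_mask)   # tensor([[1, 1, 1], [0, 1, 1]])
```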
deepseek_vlm/models/projector.py (new file, 80 lines)
@@ -0,0 +1,80 @@

```python
from attrdict import AttrDict
import torch
import torch.nn as nn
from typing import Union, Tuple


class MlpProjector(nn.Module):

    def __init__(self, cfg):

        super().__init__()

        self.cfg = cfg

        if cfg.projector_type == "identity":
            modules = nn.Identity()

        elif cfg.projector_type == "linear":
            modules = nn.Linear(cfg.input_dim, cfg.n_embed)

        elif cfg.projector_type == "mlp_gelu":
            mlp_depth = cfg.get("depth", 1)
            modules = [nn.Linear(cfg.input_dim, cfg.n_embed)]
            for _ in range(1, mlp_depth):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
            modules = nn.Sequential(*modules)

        elif cfg.projector_type == "low_high_hybrid_split_mlp_gelu":
            mlp_depth = cfg.get("depth", 1)
            self.high_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
            self.low_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)

            modules = []
            for _ in range(1, mlp_depth):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
            modules = nn.Sequential(*modules)

        else:
            raise ValueError(f"Unknown projector type: {cfg.projector_type}")

        self.layers = modules

    def forward(self, x_or_tuple: Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]):
        """

        Args:
            x_or_tuple (Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]): if it is a tuple of torch.Tensor,
                then it comes from the hybrid vision encoder, and x = (high_res_x, low_res_x);
                otherwise it is the feature from the single vision encoder.

        Returns:
            x (torch.Tensor): [b, s, c]
        """

        if isinstance(x_or_tuple, tuple):
            # self.cfg.projector_type == "low_high_hybrid_split_mlp_gelu"
            high_x, low_x = x_or_tuple
            high_x = self.high_up_proj(high_x)
            low_x = self.low_up_proj(low_x)
            x = torch.concat([high_x, low_x], dim=-1)
        else:
            x = x_or_tuple

        return self.layers(x)


if __name__ == "__main__":
    cfg = AttrDict(
        input_dim=1024,
        n_embed=2048,
        depth=2,
        projector_type="low_high_hybrid_split_mlp_gelu"
    )
    inputs = (torch.rand(4, 576, 1024), torch.rand(4, 576, 1024))

    m = MlpProjector(cfg)
    out = m(inputs)
    print(out.shape)
```
deepseek_vlm/models/sam.py (new file, 562 lines; diff listing truncated)
@@ -0,0 +1,562 @@

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import copy
from dataclasses import dataclass
from functools import partial
from typing import List, Optional, Tuple, Type, Union

import torch
import torch.nn as nn
import torch.nn.functional as F


class MLPBlock(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        mlp_dim: int,
        act: Type[nn.Module] = nn.GELU,
    ) -> None:
        super().__init__()
        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
        self.act = act()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.lin2(self.act(self.lin1(x)))


# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
class LayerNorm2d(nn.Module):
    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x


# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
class ImageEncoderViT(nn.Module):
    def __init__(
        self,
        img_size: int = 1024,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        out_chans: int = 256,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_abs_pos: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        global_attn_indexes: Tuple[int, ...] = (),
        downsample_channels: Tuple[int, ...] = (512, 1024),
    ) -> None:
        """
        Args:
            img_size (int): Input image size.
            patch_size (int): Patch size.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
            depth (int): Depth of ViT.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_abs_pos (bool): If True, use absolute positional embeddings.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks.
            global_attn_indexes (list): Indexes for blocks using global attention.
            downsample_channels (list): Channels for downsampling layers.
        """
        super().__init__()
        self.img_size = img_size

        self.patch_embed = PatchEmbed(
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        self.pos_embed: Optional[nn.Parameter] = None
        if use_abs_pos:
            # Initialize absolute positional embedding with pretrain image size.
            self.pos_embed = nn.Parameter(
                torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
            )

        self.blocks = nn.ModuleList()
        for i in range(depth):
            block = Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                norm_layer=norm_layer,
                act_layer=act_layer,
                use_rel_pos=use_rel_pos,
                rel_pos_zero_init=rel_pos_zero_init,
                window_size=window_size if i not in global_attn_indexes else 0,
                input_size=(img_size // patch_size, img_size // patch_size),
            )
            self.blocks.append(block)

        self.neck = nn.Sequential(
            nn.Conv2d(
                embed_dim,
                out_chans,
                kernel_size=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
            nn.Conv2d(
                out_chans,
                out_chans,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
        )

        in_channels = out_chans
        downsamples = []
        for i in range(len(downsample_channels)):
            out_channels = downsample_channels[i]
            downsamples.append(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    bias=False,
                )
            )
            in_channels = out_channels
        self.downsamples = nn.Sequential(*downsamples)

        self.sam_hd = True
        if self.sam_hd:
            self.hd_alpha_downsamples = nn.Parameter(torch.zeros(1))
            # self.neck_hd = nn.Linear(embed_dim, embed_dim)
            self.neck_hd = copy.deepcopy(self.neck)
            # self.downsamples_hd = copy.deepcopy(self.downsamples)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + self.pos_embed

        global_features = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if self.sam_hd and blk.window_size == 0:
                global_features.append(x)

        x = self.neck(x.permute(0, 3, 1, 2))
        x_dtype = x.dtype
        x = F.interpolate(x.float(), size=(96, 96), mode='bilinear', align_corners=False).to(x_dtype)
        x = self.downsamples(x)

        if self.sam_hd:
            first_global_feature = self.neck_hd(global_features[0].permute(0, 3, 1, 2))
            x_dtype = first_global_feature.dtype
            first_global_feature = F.interpolate(first_global_feature.float(), size=(96, 96), mode='bilinear', align_corners=False)
            first_global_feature = self.downsamples(first_global_feature.to(x_dtype))
            x = x + first_global_feature * self.hd_alpha_downsamples

        return x


class Block(nn.Module):
    """Transformer blocks with support of window attention and residual propagation blocks"""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks. If it equals 0, then
                use global attention.
            input_size (tuple(int, int) or None): Input resolution for calculating the relative
                positional parameter size.
        """
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            use_rel_pos=use_rel_pos,
            rel_pos_zero_init=rel_pos_zero_init,
            input_size=input_size if window_size == 0 else (window_size, window_size),
        )

        self.norm2 = norm_layer(dim)
        self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)

        self.window_size = window_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x
        x = self.norm1(x)
        # Window partition
        if self.window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, self.window_size)

        x = self.attn(x)
        # Reverse window partition
        if self.window_size > 0:
            x = window_unpartition(x, self.window_size, pad_hw, (H, W))

        x = shortcut + x
        x = x + self.mlp(self.norm2(x))

        return x


class Attention(nn.Module):
    """Multi-head Attention block with relative position embeddings."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            input_size (tuple(int, int) or None): Input resolution for calculating the relative
                positional parameter size.
        """
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        self.use_rel_pos = use_rel_pos
        if self.use_rel_pos:
            assert (
                input_size is not None
            ), "Input size must be provided if using relative positional encoding."
            # initialize relative positional embeddings
            self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
            self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, H, W, _ = x.shape
        # qkv with shape (3, B, nHead, H * W, C)
        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        # q, k, v with shape (B * nHead, H * W, C)
        q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)

        def do_attention(q, k, v):
            attn = (q * self.scale) @ k.transpose(-2, -1)
            if self.use_rel_pos:
                attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))

            attn = attn.softmax(dim=-1)
            x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)

            return x

        # from haiscale.utils import on_demand_checkpoint
        # x = on_demand_checkpoint(do_attention, q, k, v)
        x = do_attention(q, k, v)
        x = self.proj(x)

        return x


def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """
    Partition into non-overlapping windows with padding if needed.
    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): window size.

    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition
    """
    B, H, W, C = x.shape

    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    if pad_h > 0 or pad_w > 0:
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    Hp, Wp = H + pad_h, W + pad_w

    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows, (Hp, Wp)


def window_unpartition(
    windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> torch.Tensor:
    """
    Window unpartition into original sequences and removing padding.
    Args:
        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
        window_size (int): window size.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.

    Returns:
        x: unpartitioned sequences with [B, H, W, C].
    """
    Hp, Wp = pad_hw
    H, W = hw
    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
    x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)

    if Hp > H or Wp > W:
        x = x[:, :H, :W, :].contiguous()
    return x


def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
    """
    Get relative positional embeddings according to the relative positions of
    query and key sizes.
    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (Tensor): relative position embeddings (L, C).
```
|
||||
|
||||
Returns:
|
||||
Extracted positional embeddings according to relative positions.
|
||||
"""
|
||||
max_rel_dist = int(2 * max(q_size, k_size) - 1)
|
||||
# Interpolate rel pos if needed.
|
||||
if rel_pos.shape[0] != max_rel_dist:
|
||||
# Interpolate rel pos.
|
||||
rel_pos_resized = F.interpolate(
|
||||
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
|
||||
size=max_rel_dist,
|
||||
mode="linear",
|
||||
)
|
||||
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
|
||||
else:
|
||||
rel_pos_resized = rel_pos
|
||||
|
||||
# Scale the coords with short length if shapes for q and k are different.
|
||||
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
|
||||
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
|
||||
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
|
||||
|
||||
return rel_pos_resized[relative_coords.long()]
|
||||
|
||||
|
||||
def add_decomposed_rel_pos(
|
||||
attn: torch.Tensor,
|
||||
q: torch.Tensor,
|
||||
rel_pos_h: torch.Tensor,
|
||||
rel_pos_w: torch.Tensor,
|
||||
q_size: Tuple[int, int],
|
||||
k_size: Tuple[int, int],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
|
||||
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
|
||||
Args:
|
||||
attn (Tensor): attention map.
|
||||
q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
|
||||
rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
|
||||
rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
|
||||
q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
|
||||
k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
|
||||
|
||||
Returns:
|
||||
attn (Tensor): attention map with added relative positional embeddings.
|
||||
"""
|
||||
q_h, q_w = q_size
|
||||
k_h, k_w = k_size
|
||||
Rh = get_rel_pos(q_h, k_h, rel_pos_h)
|
||||
Rw = get_rel_pos(q_w, k_w, rel_pos_w)
|
||||
|
||||
B, _, dim = q.shape
|
||||
r_q = q.reshape(B, q_h, q_w, dim)
|
||||
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
|
||||
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
|
||||
|
||||
attn = (
|
||||
attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
|
||||
).view(B, q_h * q_w, k_h * k_w)
|
||||
|
||||
return attn
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""
|
||||
Image to Patch Embedding.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kernel_size: Tuple[int, int] = (16, 16),
|
||||
stride: Tuple[int, int] = (16, 16),
|
||||
padding: Tuple[int, int] = (0, 0),
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
kernel_size (Tuple): kernel size of the projection layer.
|
||||
stride (Tuple): stride of the projection layer.
|
||||
padding (Tuple): padding size of the projection layer.
|
||||
in_chans (int): Number of input image channels.
|
||||
embed_dim (int): Patch embedding dimension.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.proj = nn.Conv2d(
|
||||
in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.proj(x)
|
||||
# B C H W -> B H W C
|
||||
x = x.permute(0, 2, 3, 1)
|
||||
return x
|
||||
|
||||
|
||||
@dataclass
|
||||
class SAMViTCfg:
|
||||
image_size: Union[Tuple[int, int], int] = 1024
|
||||
width: int = 1024
|
||||
layers: int = 23
|
||||
heads: int = 16
|
||||
patch_size: int = 16
|
||||
window_size: int = 14
|
||||
prompt_embed_dim: int = 256
|
||||
global_attn_indexes: Union[List[int], Tuple[int]] = (5, 11, 17, 23)
|
||||
downsample_channels: Union[List[int], Tuple[int]] = (512, 1024)
|
||||
|
||||
|
||||
SAM_MODEL_CONFIG = {
|
||||
"sam_vit_b": {
|
||||
"width": 768,
|
||||
"layers": 12,
|
||||
"heads": 12,
|
||||
"global_attn_indexes": [2, 5, 8, 11],
|
||||
"downsample_channels": ()
|
||||
},
|
||||
|
||||
"sam_b_downsample": {
|
||||
"width": 768,
|
||||
"layers": 12,
|
||||
"heads": 12,
|
||||
"global_attn_indexes": [2, 5, 8, 11],
|
||||
"downsample_channels": (512, 1024)
|
||||
},
|
||||
|
||||
"sam_vit_l": {
|
||||
"width": 1024,
|
||||
"layers": 24,
|
||||
"heads": 16,
|
||||
"global_attn_indexes": [5, 11, 17, 23],
|
||||
"downsample_channels": ()
|
||||
},
|
||||
"sam_vit_h": {
|
||||
"width": 1280,
|
||||
"layers": 32,
|
||||
"heads": 16,
|
||||
"global_attn_indexes": [7, 15, 23, 31],
|
||||
"downsample_channels": ()
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def create_sam_vit(
|
||||
model_name: str = "sam_b_downsample",
|
||||
image_size: int = 1024,
|
||||
ckpt_path: str = "",
|
||||
**kwargs
|
||||
):
|
||||
assert model_name in SAM_MODEL_CONFIG.keys(), f"model name: {model_name} should be in {SAM_MODEL_CONFIG.keys()}"
|
||||
|
||||
sam_cfg = SAMViTCfg(**SAM_MODEL_CONFIG[model_name])
|
||||
image_encoder = ImageEncoderViT(
|
||||
depth=sam_cfg.layers,
|
||||
embed_dim=sam_cfg.width,
|
||||
img_size=image_size,
|
||||
mlp_ratio=4,
|
||||
norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
|
||||
num_heads=sam_cfg.heads,
|
||||
patch_size=sam_cfg.patch_size,
|
||||
qkv_bias=True,
|
||||
use_rel_pos=True,
|
||||
global_attn_indexes=sam_cfg.global_attn_indexes,
|
||||
window_size=14,
|
||||
out_chans=sam_cfg.prompt_embed_dim,
|
||||
downsample_channels=sam_cfg.downsample_channels
|
||||
)
|
||||
|
||||
if ckpt_path:
|
||||
state_dict = torch.load(ckpt_path)
|
||||
image_encoder.load_state_dict(state_dict, strict=False)
|
||||
print(f"SAM-ViT restores from {ckpt_path}")
|
||||
|
||||
return image_encoder
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
x = torch.zeros(2, 3, 1024, 1024).bfloat16()
|
||||
# x.permute(0, 3, 1, 2)
|
||||
net = create_sam_vit().bfloat16()
|
||||
out = net(x)
|
||||
print(x.shape, out.shape)
|
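The `__main__` block above already smoke-tests the encoder in bfloat16. As a complementary sketch that is not part of the commit, the snippet below does the same through an explicit import and float32 weights; the module path `deepseek_vlm.models.sam` is an assumption, since this hunk does not show the file name.

```python
# Sketch only: build the default "sam_b_downsample" encoder with random weights
# (no ckpt_path) and run a dummy image through it to inspect the output shape.
import torch

from deepseek_vlm.models.sam import create_sam_vit  # assumed import path

encoder = create_sam_vit(model_name="sam_b_downsample", image_size=1024).eval()

with torch.no_grad():
    dummy = torch.zeros(1, 3, 1024, 1024)  # NCHW pixel input, as in the __main__ test above
    features = encoder(dummy)

print(dummy.shape, features.shape)
```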
605  deepseek_vlm/models/siglip_vit.py  Normal file
@@ -0,0 +1,605 @@
# https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Final, Optional, Callable, Union, Tuple, List, Set, Dict, Type, Literal, Sequence
import math
from functools import partial
import warnings
from timm.layers import (
    PatchEmbed, Mlp, DropPath,
    AttentionPoolLatent, PatchDropout, resample_abs_pos_embed, LayerType
)
from timm.models._manipulate import named_apply, checkpoint_seq


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (torch.Tensor, float, float, float, float) -> torch.Tensor
    r""" The original timm.models.layers.weight_init.trunc_normal_ cannot handle bfloat16 yet, so here we first
    convert the tensor to float32, apply the trunc_normal_() in float32, and then convert it back to its original dtype.
    Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn
    from the normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.
    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """

    with torch.no_grad():
        dtype = tensor.dtype
        tensor_fp32 = tensor.float()
        tensor_fp32 = _no_grad_trunc_normal_(tensor_fp32, mean, std, a, b)
        tensor_dtype = tensor_fp32.to(dtype=dtype)
        tensor.copy_(tensor_dtype)


def init_weights(self):
    if self.pos_embed is not None:
        trunc_normal_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
    trunc_normal_(self.latent, std=self.latent_dim ** -0.5)


def init_weights_vit_timm(module: nn.Module, name: str = '') -> None:
    """ ViT weight initialization, original timm impl (for reproducibility) """
    if isinstance(module, nn.Linear):
        trunc_normal_(module.weight, std=.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif hasattr(module, 'init_weights'):
        module.init_weights()


class Attention(nn.Module):
    fused_attn: Final[bool]

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: nn.Module = nn.LayerNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        # self.fused_attn = use_fused_attn()
        self.fused_attn = True

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0. else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class LayerScale(nn.Module):
    def __init__(
            self,
            dim: int,
            init_values: float = 1e-5,
            inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.mul_(self.gamma) if self.inplace else x * self.gamma


class Block(nn.Module):
    def __init__(
            self,
            dim: int,
            num_heads: int,
            mlp_ratio: float = 4.,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            proj_drop: float = 0.,
            attn_drop: float = 0.,
            init_values: Optional[float] = None,
            drop_path: float = 0.,
            act_layer: nn.Module = nn.GELU,
            norm_layer: nn.Module = nn.LayerNorm,
            mlp_layer: nn.Module = Mlp,
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = mlp_layer(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=proj_drop,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x


class VisionTransformer(nn.Module):
    """ Vision Transformer

    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
        - https://arxiv.org/abs/2010.11929
    """
    dynamic_img_size: Final[bool]

    def __init__(
            self,
            img_size: Union[int, Tuple[int, int]] = 224,
            patch_size: Union[int, Tuple[int, int]] = 16,
            in_chans: int = 3,
            num_classes: int = 1000,
            global_pool: Literal['', 'avg', 'token', 'map'] = 'token',
            embed_dim: int = 768,
            depth: int = 12,
            num_heads: int = 12,
            mlp_ratio: float = 4.,
            qkv_bias: bool = True,
            qk_norm: bool = False,
            init_values: Optional[float] = None,
            class_token: bool = True,
            no_embed_class: bool = False,
            reg_tokens: int = 0,
            pre_norm: bool = False,
            fc_norm: Optional[bool] = None,
            dynamic_img_size: bool = False,
            dynamic_img_pad: bool = False,
            drop_rate: float = 0.,
            pos_drop_rate: float = 0.,
            patch_drop_rate: float = 0.,
            proj_drop_rate: float = 0.,
            attn_drop_rate: float = 0.,
            drop_path_rate: float = 0.,
            weight_init: Literal['skip', 'jax', 'jax_nlhb', 'moco', ''] = '',
            embed_layer: Callable = PatchEmbed,
            norm_layer: Optional[LayerType] = None,
            act_layer: Optional[LayerType] = None,
            block_fn: Type[nn.Module] = Block,
            mlp_layer: Type[nn.Module] = Mlp,
            ignore_head: bool = False,
    ) -> None:
        """
        Args:
            img_size: Input image size.
            patch_size: Patch size.
            in_chans: Number of image input channels.
            num_classes: Number of classes for classification head.
            global_pool: Type of global pooling for final sequence (default: 'token').
            embed_dim: Transformer embedding dimension.
            depth: Depth of transformer.
            num_heads: Number of attention heads.
            mlp_ratio: Ratio of mlp hidden dim to embedding dim.
            qkv_bias: Enable bias for qkv projections if True.
            init_values: Layer-scale init values (layer-scale enabled if not None).
            class_token: Use class token.
            no_embed_class: Don't include position embeddings for class (or reg) tokens.
            reg_tokens: Number of register tokens.
            fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
            drop_rate: Head dropout rate.
            pos_drop_rate: Position embedding dropout rate.
            attn_drop_rate: Attention dropout rate.
            drop_path_rate: Stochastic depth rate.
            weight_init: Weight initialization scheme.
            embed_layer: Patch embedding layer.
            norm_layer: Normalization layer.
            act_layer: MLP activation layer.
            block_fn: Transformer block layer.
        """
        super().__init__()
        assert global_pool in ('', 'avg', 'token', 'map')
        assert class_token or global_pool != 'token'
        use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
        # norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
        # act_layer = get_act_layer(act_layer) or nn.GELU
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
        act_layer = nn.GELU

        self.num_classes = num_classes
        self.global_pool = global_pool
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_prefix_tokens = 1 if class_token else 0
        self.num_prefix_tokens += reg_tokens
        self.num_reg_tokens = reg_tokens
        self.has_class_token = class_token
        self.no_embed_class = no_embed_class  # don't embed prefix positions (includes reg)
        self.dynamic_img_size = dynamic_img_size
        self.grad_checkpointing = False
        self.ignore_head = ignore_head

        embed_args = {}
        if dynamic_img_size:
            # flatten deferred until after pos embed
            embed_args.update(dict(strict_img_size=False, output_fmt='NHWC'))
        self.patch_embed = embed_layer(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            bias=not pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
            dynamic_img_pad=dynamic_img_pad,
            **embed_args,
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
        self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
        embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
        self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
        self.pos_drop = nn.Dropout(p=pos_drop_rate)
        if patch_drop_rate > 0:
            self.patch_drop = PatchDropout(
                patch_drop_rate,
                num_prefix_tokens=self.num_prefix_tokens,
            )
        else:
            self.patch_drop = nn.Identity()
        self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.Sequential(*[
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_norm=qk_norm,
                init_values=init_values,
                proj_drop=proj_drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                mlp_layer=mlp_layer,
            )
            for i in range(depth)])
        self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()

        # Classifier Head
        if global_pool == 'map':
            AttentionPoolLatent.init_weights = init_weights
            self.attn_pool = AttentionPoolLatent(
                self.embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                norm_layer=norm_layer,
            )
        else:
            self.attn_pool = None
        self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
        self.head_drop = nn.Dropout(drop_rate)
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        if weight_init != 'skip':
            self.init_weights(weight_init)

    def init_weights(self, mode: Literal['jax', 'jax_nlhb', 'moco', ''] = '') -> None:
        assert mode in ('jax', 'jax_nlhb', 'moco', '')
        head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
        trunc_normal_(self.pos_embed, std=.02)
        if self.cls_token is not None:
            nn.init.normal_(self.cls_token, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    @torch.jit.ignore
    def no_weight_decay(self) -> Set:
        return {'pos_embed', 'cls_token', 'dist_token'}

    @torch.jit.ignore
    def group_matcher(self, coarse: bool = False) -> Dict:
        return dict(
            stem=r'^cls_token|pos_embed|patch_embed',  # stem and embed
            blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable: bool = True) -> None:
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self) -> nn.Module:
        return self.head

    def reset_classifier(self, num_classes: int, global_pool=None) -> None:
        self.num_classes = num_classes
        if global_pool is not None:
            assert global_pool in ('', 'avg', 'token', 'map')
            if global_pool == 'map' and self.attn_pool is None:
                assert False, "Cannot currently add attention pooling in reset_classifier()."
            elif global_pool != 'map' and self.attn_pool is not None:
                self.attn_pool = None  # remove attention pooling
            self.global_pool = global_pool
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
        if self.dynamic_img_size:
            B, H, W, C = x.shape
            pos_embed = resample_abs_pos_embed(
                self.pos_embed,
                (H, W),
                num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
            )
            x = x.view(B, -1, C)
        else:
            pos_embed = self.pos_embed

        to_cat = []
        if self.cls_token is not None:
            to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
        if self.reg_token is not None:
            to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))

        if self.no_embed_class:
            # deit-3, updated JAX (big vision)
            # position embedding does not overlap with class token, add then concat
            x = x + pos_embed
            if to_cat:
                x = torch.cat(to_cat + [x], dim=1)
        else:
            # original timm, JAX, and deit vit impl
            # pos_embed has entry for class token, concat then add
            if to_cat:
                x = torch.cat(to_cat + [x], dim=1)
            x = x + pos_embed

        return self.pos_drop(x)

    def _intermediate_layers(
            self,
            x: torch.Tensor,
            n: Union[int, Sequence] = 1,
    ) -> List[torch.Tensor]:
        outputs, num_blocks = [], len(self.blocks)
        take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n)

        # forward pass
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.patch_drop(x)
        x = self.norm_pre(x)
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in take_indices:
                outputs.append(x)

        return outputs

    def get_intermediate_layers(
            self,
            x: torch.Tensor,
            n: Union[int, Sequence] = 1,
            reshape: bool = False,
            return_prefix_tokens: bool = False,
            norm: bool = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
        """ Intermediate layer accessor (NOTE: This is a WIP experiment).
        Inspired by DINO / DINOv2 interface
        """
        # take last n blocks if n is an int, if n is a sequence, select by matching indices
        outputs = self._intermediate_layers(x, n)
        if norm:
            outputs = [self.norm(out) for out in outputs]
        prefix_tokens = [out[:, 0:self.num_prefix_tokens] for out in outputs]
        outputs = [out[:, self.num_prefix_tokens:] for out in outputs]

        if reshape:
            grid_size = self.patch_embed.grid_size
            outputs = [
                out.reshape(x.shape[0], grid_size[0], grid_size[1], -1).permute(0, 3, 1, 2).contiguous()
                for out in outputs
            ]

        if return_prefix_tokens:
            return tuple(zip(outputs, prefix_tokens))
        return tuple(outputs)

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.patch_drop(x)
        x = self.norm_pre(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.blocks, x)
        else:
            x = self.blocks(x)
        x = self.norm(x)
        return x

    def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
        if self.attn_pool is not None:
            x = self.attn_pool(x)
        elif self.global_pool == 'avg':
            x = x[:, self.num_prefix_tokens:].mean(dim=1)
        elif self.global_pool:
            x = x[:, 0]  # class token
        x = self.fc_norm(x)
        x = self.head_drop(x)
        return x if pre_logits else self.head(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.forward_features(x)
        if not self.ignore_head:
            x = self.forward_head(x)
        return x


@dataclass
class SigLIPVisionCfg:
    width: int = 1152
    layers: Union[Tuple[int, int, int, int], int] = 27
    heads: int = 16
    patch_size: int = 14
    image_size: Union[Tuple[int, int], int] = 336
    global_pool: str = "map"
    mlp_ratio: float = 3.7362
    class_token: bool = False
    num_classes: int = 0
    use_checkpoint: bool = False


SigLIP_MODEL_CONFIG = {
    "siglip_so400m_patch14_384": {
        "image_size": 336,
        "patch_size": 14,
        "width": 1152,
        "layers": 27,
        "heads": 16,
        "mlp_ratio": 3.7362,
        "global_pool": "map",
        "use_checkpoint": False
    },
    "siglip_so400m_patch14_224": {
        "image_size": 224,
        "patch_size": 14,
        "width": 1152,
        "layers": 27,
        "heads": 16,
        "mlp_ratio": 3.7362,
        "global_pool": "map",
        "use_checkpoint": False
    },
    "siglip_large_patch16_384": {
        "image_size": 384,
        "patch_size": 16,
        "width": 1024,
        "layers": 24,
        "heads": 16,
        "mlp_ratio": 4,
        "global_pool": "map",
        "use_checkpoint": False
    }
}


def create_siglip_vit(
        model_name: str = "siglip_so400m_patch14_384",
        image_size: int = 384,
        select_layer: int = -1,
        ckpt_path: str = "",
        **kwargs
):
    assert model_name in SigLIP_MODEL_CONFIG.keys(), f"model name should be in {SigLIP_MODEL_CONFIG.keys()}"

    vision_cfg = SigLIPVisionCfg(**SigLIP_MODEL_CONFIG[model_name])

    if select_layer <= 0:
        layers = min(vision_cfg.layers, vision_cfg.layers + select_layer + 1)
    else:
        layers = min(vision_cfg.layers, select_layer)

    model = VisionTransformer(
        img_size=image_size,
        patch_size=vision_cfg.patch_size,
        embed_dim=vision_cfg.width,
        depth=layers,
        num_heads=vision_cfg.heads,
        mlp_ratio=vision_cfg.mlp_ratio,
        class_token=vision_cfg.class_token,
        global_pool=vision_cfg.global_pool,
        ignore_head=kwargs.get("ignore_head", True),
        weight_init=kwargs.get("weight_init", "skip"),
        num_classes=0
    )

    if ckpt_path:
        state_dict = torch.load(ckpt_path, map_location="cpu")

        incompatible_keys = model.load_state_dict(state_dict, strict=False)
        print(f"SigLIP-ViT restored from {ckpt_path},\n"
              f"\tincompatible_keys: {incompatible_keys}.")

    return model
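For orientation, and likewise not part of the commit, here is a minimal sketch of how `create_siglip_vit` can be exercised with random weights; the import path `deepseek_vlm.models.siglip_vit` matches this file's path in the diff. With the default `ignore_head=True`, `forward` returns the patch-token features from `forward_features` rather than pooled logits.

```python
# Sketch only: build the SigLIP-so400m ViT defined above (random weights, no ckpt_path),
# run a dummy 384x384 image through it, and print the token-feature shape.
import torch

from deepseek_vlm.models.siglip_vit import create_siglip_vit

model = create_siglip_vit(model_name="siglip_so400m_patch14_384", image_size=384).eval()

with torch.no_grad():
    dummy = torch.zeros(1, 3, 384, 384)
    tokens = model(dummy)  # ignore_head defaults to True, so this is forward_features output

print(tokens.shape)  # expected: (1, 729, 1152) with these defaults (a 27x27 grid of 14x14 patches)
```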
0  deepseek_vlm/utils/__init__.py  Normal file
326  deepseek_vlm/utils/conversation.py  Normal file
@@ -0,0 +1,326 @@
"""
|
||||
From https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
from enum import IntEnum, auto
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
class SeparatorStyle(IntEnum):
|
||||
"""Separator styles."""
|
||||
|
||||
ADD_COLON_SINGLE = auto()
|
||||
ADD_COLON_TWO = auto()
|
||||
ADD_COLON_SPACE_SINGLE = auto()
|
||||
NO_COLON_SINGLE = auto()
|
||||
NO_COLON_TWO = auto()
|
||||
ADD_NEW_LINE_SINGLE = auto()
|
||||
LLAMA2 = auto()
|
||||
CHATGLM = auto()
|
||||
CHATML = auto()
|
||||
CHATINTERN = auto()
|
||||
DOLLY = auto()
|
||||
RWKV = auto()
|
||||
PHOENIX = auto()
|
||||
ROBIN = auto()
|
||||
DeepSeek = auto()
|
||||
PLAIN = auto()
|
||||
ALIGNMENT = auto()
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Conversation:
|
||||
"""A class that manages prompt templates and keeps all conversation history."""
|
||||
|
||||
# The name of this template
|
||||
name: str
|
||||
# The template of the system prompt
|
||||
system_template: str = "{system_message}"
|
||||
# The system message
|
||||
system_message: str = ""
|
||||
# The names of two roles
|
||||
roles: List[str] = (("USER", "ASSISTANT"),)
|
||||
# All messages. Each item is (role, message).
|
||||
messages: List[List[str]] = ()
|
||||
# The number of few shot examples
|
||||
offset: int = 0
|
||||
# The separator style and configurations
|
||||
sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
|
||||
sep: str = "\n"
|
||||
sep2: str = None
|
||||
# Stop criteria (the default one is EOS token)
|
||||
stop_str: str = None
|
||||
# Stops generation if meeting any token in this list
|
||||
stop_token_ids: List[int] = None
|
||||
|
||||
def get_prompt(self) -> str:
|
||||
"""Get the prompt for generation."""
|
||||
system_prompt = self.system_template.format(system_message=self.system_message)
|
||||
|
||||
if self.sep_style == SeparatorStyle.DeepSeek:
|
||||
seps = [self.sep, self.sep2]
|
||||
if system_prompt == "" or system_prompt is None:
|
||||
ret = ""
|
||||
else:
|
||||
ret = system_prompt + seps[0]
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
ret += role + ": " + message + seps[i % 2]
|
||||
else:
|
||||
ret += role + ":"
|
||||
return ret
|
||||
elif self.sep_style == SeparatorStyle.LLAMA2:
|
||||
seps = [self.sep, self.sep2]
|
||||
if self.system_message:
|
||||
ret = system_prompt
|
||||
else:
|
||||
ret = "[INST] "
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
tag = self.roles[i % 2]
|
||||
if message:
|
||||
if type(message) is tuple: # multimodal message
|
||||
message, _ = message
|
||||
if i == 0:
|
||||
ret += message + " "
|
||||
else:
|
||||
ret += tag + " " + message + seps[i % 2]
|
||||
else:
|
||||
ret += tag
|
||||
return ret
|
||||
elif self.sep_style == SeparatorStyle.PLAIN:
|
||||
seps = [self.sep, self.sep2]
|
||||
ret = ""
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
if type(message) is tuple:
|
||||
message, _, _ = message
|
||||
if i % 2 == 0:
|
||||
ret += message + seps[i % 2]
|
||||
else:
|
||||
ret += message + seps[i % 2]
|
||||
else:
|
||||
ret += ""
|
||||
return ret
|
||||
elif self.sep_style == SeparatorStyle.ALIGNMENT:
|
||||
seps = [self.sep, self.sep2]
|
||||
ret = ""
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
if type(message) is tuple:
|
||||
message, _, _ = message
|
||||
if i % 2 == 0:
|
||||
ret += '<image>\n' + seps[i % 2]
|
||||
else:
|
||||
ret += message + seps[i % 2]
|
||||
else:
|
||||
ret += ""
|
||||
return ret
|
||||
else:
|
||||
raise ValueError(f"Invalid style: {self.sep_style}")
|
||||
|
||||
def get_prompt_for_current_round(self, content=None):
|
||||
"""Get current round formatted question prompt during sft training"""
|
||||
if self.sep_style == SeparatorStyle.PLAIN:
|
||||
formatted_question = "<image>\n"
|
||||
elif self.sep_style == SeparatorStyle.DeepSeek:
|
||||
formatted_question = f"{self.roles[0]}: " +content.strip() + self.sep + f"{self.roles[1]}:"
|
||||
else:
|
||||
raise ValueError(f"Unsupported sep_style: {self.sep_style}")
|
||||
return formatted_question
|
||||
|
||||
def set_system_message(self, system_message: str):
|
||||
"""Set the system message."""
|
||||
self.system_message = system_message
|
||||
|
||||
def append_message(self, role: str, message: str):
|
||||
"""Append a new message."""
|
||||
self.messages.append([role, message])
|
||||
|
||||
def reset_message(self):
|
||||
"""Reset a new message."""
|
||||
self.messages = []
|
||||
|
||||
def update_last_message(self, message: str):
|
||||
"""Update the last output.
|
||||
|
||||
The last message is typically set to be None when constructing the prompt,
|
||||
so we need to update it in-place after getting the response from a model.
|
||||
"""
|
||||
self.messages[-1][1] = message
|
||||
|
||||
def to_gradio_chatbot(self):
|
||||
"""Convert the conversation to gradio chatbot format."""
|
||||
ret = []
|
||||
for i, (role, msg) in enumerate(self.messages[self.offset :]):
|
||||
if i % 2 == 0:
|
||||
ret.append([msg, None])
|
||||
else:
|
||||
ret[-1][-1] = msg
|
||||
return ret
|
||||
|
||||
def to_openai_api_messages(self):
|
||||
"""Convert the conversation to OpenAI chat completion format."""
|
||||
system_prompt = self.system_template.format(system_message=self.system_message)
|
||||
ret = [{"role": "system", "content": system_prompt}]
|
||||
|
||||
for i, (_, msg) in enumerate(self.messages[self.offset :]):
|
||||
if i % 2 == 0:
|
||||
ret.append({"role": "user", "content": msg})
|
||||
else:
|
||||
if msg is not None:
|
||||
ret.append({"role": "assistant", "content": msg})
|
||||
return ret
|
||||
|
||||
def copy(self):
|
||||
return Conversation(
|
||||
name=self.name,
|
||||
system_template=self.system_template,
|
||||
system_message=self.system_message,
|
||||
roles=self.roles,
|
||||
messages=[[x, y] for x, y in self.messages],
|
||||
offset=self.offset,
|
||||
sep_style=self.sep_style,
|
||||
sep=self.sep,
|
||||
sep2=self.sep2,
|
||||
stop_str=self.stop_str,
|
||||
stop_token_ids=self.stop_token_ids,
|
||||
)
|
||||
|
||||
def dict(self):
|
||||
return {
|
||||
"template_name": self.name,
|
||||
"system_message": self.system_message,
|
||||
"roles": self.roles,
|
||||
"messages": self.messages,
|
||||
"offset": self.offset,
|
||||
}
|
||||
|
||||
|
||||
# A global registry for all conversation templates
|
||||
conv_templates: Dict[str, Conversation] = {}
|
||||
|
||||
|
||||
def register_conv_template(template: Conversation, override: bool = False):
|
||||
"""Register a new conversation template."""
|
||||
if not override:
|
||||
assert template.name not in conv_templates, f"{template.name} has been registered."
|
||||
|
||||
conv_templates[template.name] = template
|
||||
|
||||
|
||||
def get_conv_template(name: str) -> Conversation:
|
||||
"""Get a conversation template."""
|
||||
return conv_templates[name].copy()
|
||||
|
||||
|
||||
# llava_llama2 template
|
||||
register_conv_template(
|
||||
Conversation(
|
||||
name="llava_llama2",
|
||||
system_message="You are a helpful language and vision assistant. "
|
||||
"You are able to understand the visual content that the user provides, "
|
||||
"and assist the user with a variety of tasks using natural language.",
|
||||
system_template="[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n",
|
||||
roles=("[INST]", "[/INST]"),
|
||||
messages=(),
|
||||
offset=0,
|
||||
sep_style=SeparatorStyle.LLAMA2,
|
||||
sep=" ",
|
||||
sep2=" </s><s>",
|
||||
stop_token_ids=[2],
|
||||
)
|
||||
)
|
||||
|
||||
# llama2 template
|
||||
# reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
|
||||
register_conv_template(
|
||||
Conversation(
|
||||
name="llama-2",
|
||||
system_template="[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n",
|
||||
roles=("[INST]", "[/INST]"),
|
||||
messages=(),
|
||||
offset=0,
|
||||
sep_style=SeparatorStyle.LLAMA2,
|
||||
sep=" ",
|
||||
sep2=" </s><s>",
|
||||
stop_token_ids=[2],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# deepseek template
|
||||
register_conv_template(
|
||||
Conversation(
|
||||
name="deepseek",
|
||||
system_template="{system_message}",
|
||||
# system_message="You are a helpful assistant. Please answer truthfully and write out your "
|
||||
# "thinking step by step to be sure you get the right answer.",
|
||||
system_message="",
|
||||
roles=("User", "Assistant"),
|
||||
messages=(),
|
||||
offset=0,
|
||||
sep_style=SeparatorStyle.DeepSeek,
|
||||
sep="\n\n",
|
||||
sep2="<|end▁of▁sentence|>",
|
||||
stop_token_ids=[100001],
|
||||
stop_str=["User:", "<|end▁of▁sentence|>"]
|
||||
)
|
||||
)
|
||||
|
||||
register_conv_template(
|
||||
Conversation(
|
||||
name="plain",
|
||||
system_template="",
|
||||
system_message="",
|
||||
roles=("", ""),
|
||||
messages=(),
|
||||
offset=0,
|
||||
sep_style=SeparatorStyle.PLAIN,
|
||||
sep="",
|
||||
sep2="",
|
||||
stop_token_ids=[2],
|
||||
stop_str=['</s>'],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
register_conv_template(
|
||||
Conversation(
|
||||
name="alignment",
|
||||
system_template="",
|
||||
system_message="",
|
||||
roles=("", ""),
|
||||
messages=(),
|
||||
offset=0,
|
||||
sep_style=SeparatorStyle.ALIGNMENT,
|
||||
sep="",
|
||||
sep2="",
|
||||
stop_token_ids=[2],
|
||||
stop_str=['</s>'],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# print("Llama-2 template:")
|
||||
# conv = get_conv_template("llama-2")
|
||||
# conv.set_system_message("You are a helpful, respectful and honest assistant.")
|
||||
# conv.append_message(conv.roles[0], "Hello!")
|
||||
# conv.append_message(conv.roles[1], "Hi!")
|
||||
# conv.append_message(conv.roles[0], "How are you?")
|
||||
# conv.append_message(conv.roles[1], None)
|
||||
# print(conv.get_prompt())
|
||||
|
||||
# print("\n")
|
||||
|
||||
print("deepseek template:")
|
||||
conv = get_conv_template("deepseek")
|
||||
conv.append_message(conv.roles[0], "Hello!")
|
||||
conv.append_message(conv.roles[1], "Hi! This is Tony.")
|
||||
conv.append_message(conv.roles[0], "Who are you?")
|
||||
conv.append_message(conv.roles[1], "I am a helpful assistant.")
|
||||
conv.append_message(conv.roles[0], "How are you?")
|
||||
conv.append_message(conv.roles[1], None)
|
||||
print(conv.get_prompt())
|
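For reference, and not part of the original file, tracing `get_prompt()` for the `deepseek` template registered above, the `__main__` example should print something close to the following; `\n\n` follows each user message and `<|end▁of▁sentence|>` closes each assistant message:

```
deepseek template:
User: Hello!

Assistant: Hi! This is Tony.<|end▁of▁sentence|>User: Who are you?

Assistant: I am a helpful assistant.<|end▁of▁sentence|>User: How are you?

Assistant:
```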
55  deepseek_vlm/utils/io.py  Normal file
@@ -0,0 +1,55 @@
import json
import PIL.Image
from typing import Dict, List
import torch
from transformers import AutoModelForCausalLM
from deepseek_vlm.models import VLChatProcessor, MultiModalityCausalLM


def load_pretrained_model(model_path: str):
    vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
    tokenizer = vl_chat_processor.tokenizer

    vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
    vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()

    return tokenizer, vl_chat_processor, vl_gpt


def load_pil_images(conversations: List[Dict[str, str]]) -> List[PIL.Image.Image]:
    """
    Args:
        conversations (List[Dict[str, str]]): the conversations with a list of messages. An example is:
            [
                {
                    "role": "User",
                    "content": "<image_placeholder>\nExtract all information from this image and convert them into markdown format.",
                    "images": ["./examples/table_datasets.png"]
                },
                {"role": "Assistant", "content": ""},
            ]

    Returns:
        pil_images (List[PIL.Image.Image]): the list of PIL images.
    """

    pil_images = []

    for message in conversations:
        if "images" not in message:
            continue

        for image_path in message["images"]:
            pil_img = PIL.Image.open(image_path)
            pil_img = pil_img.convert("RGB")
            pil_images.append(pil_img)

    return pil_images


def load_json(filepath):
    with open(filepath, "r") as f:
        data = json.load(f)
        return data
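A short sketch, not part of the commit, of how the `load_pretrained_model` helper above can replace the manual setup used in `inference.py` later in this diff; the model ID and `images/monday.jpg` are taken from elsewhere in this commit, and a CUDA device is assumed because the helper calls `.cuda()`.

```python
# Sketch only: one-call model setup via the helper above, then the usual processor flow.
from deepseek_vlm.utils.io import load_pretrained_model, load_pil_images

tokenizer, vl_chat_processor, vl_gpt = load_pretrained_model("deepseek-ai/deepseek-vl-7b-chat")

conversation = [
    {
        "role": "User",
        "content": "<image_placeholder>Describe this image.",
        "images": ["./images/monday.jpg"],
    },
    {"role": "Assistant", "content": ""},
]

prepare_inputs = vl_chat_processor(
    conversations=conversation,
    images=load_pil_images(conversation),
    force_batchify=True,
).to(vl_gpt.device)
```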
BIN  images/latex_01.jpg  Normal file (binary file not shown; 49 KiB)
BIN  images/monday.jpg  Normal file (binary file not shown; 48 KiB)
BIN  images/training_pipelines.png  Normal file (binary file not shown; 101 KiB)
53  inference.py  Normal file
@@ -0,0 +1,53 @@
import torch
from transformers import AutoModelForCausalLM

from deepseek_vlm.models import VLChatProcessor, MultiModalityCausalLM
from deepseek_vlm.utils.io import load_pil_images


# specify the path to the model
model_path = "deepseek-ai/deepseek-vl-7b-chat"
vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer

vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()

conversation = [
    {
        "role": "User",
        "content": "<image_placeholder>Describe each stage of this image.",
        "images": ["./images/training_pipelines.png"]
    },
    {
        "role": "Assistant",
        "content": ""
    }
]


# load images and prepare for inputs
pil_images = load_pil_images(conversation)
prepare_inputs = vl_chat_processor(
    conversations=conversation,
    images=pil_images,
    force_batchify=True
).to(vl_gpt.device)

# run image encoder to get the image embeddings
inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)

# run the model to get the response
outputs = vl_gpt.language_model.generate(
    inputs_embeds=inputs_embeds,
    attention_mask=prepare_inputs.attention_mask,
    pad_token_id=tokenizer.eos_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    max_new_tokens=512,
    do_sample=False,
    use_cache=True
)

answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
print(f"{prepare_inputs['sft_format'][0]}", answer)
requirements.txt
@@ -1,8 +1,5 @@
torch>=2.0
tokenizers>=0.14.0
transformers>=4.35.0
transformers>=4.38.2
timm>=0.9.16
gradio>=4.13.0
accelerate
sympy==1.12
pebble
timeout-decorator
attrdict
sentencepiece
17  setup.py  Normal file
@@ -0,0 +1,17 @@
from setuptools import setup, find_packages


version = '1.0.0'
print(version)

setup(
    name='deepseek_vlm',
    version=version,
    description='DeepSeek-VLM',
    author='HFAiLab',
    license='MIT',
    url='https://gitlab.deepseek.com/liuwen/deepseek_vl',
    python_requires='>=3.8',
    install_requires=['torch>=2.0'],
    packages=find_packages(exclude=("images",)),
)