version 1.0.0

This commit is contained in:
StevenLiuWen 2024-03-08 15:37:17 +08:00
parent e24fd228ab
commit b3e7107168
20 changed files with 2842 additions and 46 deletions

103
README.md
View File

@ -72,16 +72,21 @@ Introducing DeepSeek VL,
## 2. Model Downloads
We release the DeepSeek LLM 7B/67B, including both base and chat models, to the public. To support a broader and more diverse range of research within both academic and commercial communities, we are providing access to the intermediate checkpoints of the base model from its training process. Please **note** that the use of this model is subject to the terms outlined in [License section](#8-license). Commercial usage is permitted under these terms.
We release the DeepSeek-VL family, including 1.3B-base, 1.3B-chat, 7b-base and 7b-chat models, to the public.
To support a broader and more diverse range of research within both academic and commercial communities.
Please note that the use of this model is subject to the terms outlined in [License section](#8-license). Commercial usage is
permitted under these terms.
### Huggingface
| Model | Sequence Length | Download |
|:---------------------:|:---------------:|:-----------------------------------------------------------------------:|
| DeepSeek VL 7B Base | 4096 | 🤗 [HuggingFace](https://huggingface.co/deepseek-ai/deepseek-llm-7b-base) |
| DeepSeek VL 7B Chat | 4096 | 🤗 [HuggingFace](https://huggingface.co/deepseek-ai/deepseek-llm-7b-chat) |
| Model | Sequence Length | Download |
|-----------------------|-----------------|-----------------------------------------------------------------------------|
| DeepSeek-VL-1.3B-base | 4096 | [🤗 Hugging Face](https://huggingface.co/deepseek-ai/deepseek-vl-1.3b-base) |
| DeepSeek-VL-1.3B-chat | 4096 | [🤗 Hugging Face](https://huggingface.co/deepseek-ai/deepseek-vl-1.3b-chat) |
| DeepSeek-VL-7B-base | 4096 | [🤗 Hugging Face](https://huggingface.co/deepseek-ai/deepseek-vl-7b-base) |
| DeepSeek-VL-7B-chat | 4096 | [🤗 Hugging Face](https://huggingface.co/deepseek-ai/deepseek-vl-7b-chat) |
## 3. Evaluation Results
# 3. Evaluation Results
### Base Model
@ -139,53 +144,73 @@ We release the training loss curve and several benchmark metrics curves, as deta
On the basis of `Python >= 3.8` environment, install the necessary dependencies by running the following command:
```shell
pip install -r requirements.txt
pip install -r requirements.txt -e .
```
### Inference with Huggingface's Transformers
You can directly employ [Huggingface's Transformers](https://github.com/huggingface/transformers) for model inference.
**Text Completion**
**Simple Inference Example**
```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from transformers import AutoModelForCausalLM
model_name = "deepseek-ai/deepseek-llm-67b-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
model.generation_config = GenerationConfig.from_pretrained(model_name)
model.generation_config.pad_token_id = model.generation_config.eos_token_id
from deepseek_vlm.models import VLChatProcessor, MultiModalityCausalLM
from deepseek_vlm.utils.io import load_pil_images
text = "An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is"
inputs = tokenizer(text, return_tensors="pt")
outputs = model.generate(**inputs.to(model.device), max_new_tokens=100)
result = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(result)
# specify the path to the model
model_path = "deepseek-ai/deepseek-vl-7b-chat"
vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()
conversation = [
{
"role": "User",
"content": "<image_placeholder>Describe each stage of this image.",
"images": ["./images/training_pipelines.png"]
},
{
"role": "Assistant",
"content": ""
}
]
# load images and prepare for inputs
pil_images = load_pil_images(conversation)
prepare_inputs = vl_chat_processor(
conversations=conversation,
images=pil_images,
force_batchify=True
).to(vl_gpt.device)
# run image encoder to get the image embeddings
inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
# run the model to get the response
outputs = vl_gpt.language_model.generate(
inputs_embeds=inputs_embeds,
attention_mask=prepare_inputs.attention_mask,
pad_token_id=tokenizer.eos_token_id,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
max_new_tokens=512,
do_sample=False,
use_cache=True
)
answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
print(f"{prepare_inputs['sft_format'][0]}", answer)
```
**Chat Completion**
```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
model_name = "deepseek-ai/deepseek-llm-67b-chat"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
model.generation_config = GenerationConfig.from_pretrained(model_name)
model.generation_config.pad_token_id = model.generation_config.eos_token_id
messages = [
{"role": "user", "content": "Who are you?"}
]
input_tensor = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
outputs = model.generate(input_tensor.to(model.device), max_new_tokens=100)
result = tokenizer.decode(outputs[0][input_tensor.shape[1]:], skip_special_tokens=True)
print(result)
**CLI Chat**
```bash
python cli_chat.py --model_path deepseek-ai/deepseek-vl-7b-chat
```
Avoiding the use of the provided function `apply_chat_template`, you can also interact with our model following the sample template. Note that `messages` should be replaced by your input.

194
cli_chat.py Normal file
View File

@ -0,0 +1,194 @@
# -*- coding: utf-8 -*-
import argparse
import os
import sys
from PIL import Image
import readline
from threading import Thread
import torch
from transformers import TextIteratorStreamer
from deepseek_vlm.utils.io import load_pretrained_model
def load_image(image_file):
image = Image.open(image_file).convert("RGB")
return image
def get_help_message(image_token):
help_msg = (
f"\t\t DeepSeek-VL-Chat is a chatbot that can answer questions based on the given image. Enjoy it! \n"
f"Usage: \n"
f" 1. type `exit` to quit. \n"
f" 2. type `{image_token}` to indicate there is an image. You can enter multiple images, "
f"e.g '{image_token} is a dot, {image_token} is a cat, and what is it in {image_token}?'. "
f"When you type `{image_token}`, the chatbot will ask you to input image file path. \n"
f" 4. type `help` to get the help messages. \n"
f" 5. type `new` to start a new conversation. \n"
f" Here is an example, you can type: '<image_placeholder>Describe the image.'\n"
)
return help_msg
@torch.inference_mode()
def response(args, conv, pil_images, tokenizer, vl_chat_processor, vl_gpt, generation_config):
prompt = conv.get_prompt()
prepare_inputs = vl_chat_processor.__call__(
prompt=prompt,
images=pil_images,
force_batchify=True
).to(vl_gpt.device)
# run image encoder to get the image embeddings
inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
streamer = TextIteratorStreamer(
tokenizer=tokenizer,
skip_prompt=True,
skip_special_tokens=True
)
generation_config["inputs_embeds"] = inputs_embeds
generation_config["attention_mask"] = prepare_inputs.attention_mask
generation_config["streamer"] = streamer
thread = Thread(target=vl_gpt.language_model.generate, kwargs=generation_config)
thread.start()
yield from streamer
def get_user_input(hint: str):
user_input = ""
while user_input == "":
try:
user_input = input(f"{hint}")
except KeyboardInterrupt:
print()
continue
except EOFError:
user_input = "exit"
return user_input
def chat(args, tokenizer, vl_chat_processor, vl_gpt, generation_config):
image_token = vl_chat_processor.image_token
help_msg = get_help_message(image_token)
while True:
print(help_msg)
pil_images = []
conv = vl_chat_processor.new_chat_template()
roles = conv.roles
while True:
# get user input
user_input = get_user_input(f"{roles[0]} [{image_token} indicates an image]: ")
if user_input == "exit":
print("Chat program exited.")
sys.exit(0)
elif user_input == "help":
print(help_msg)
elif user_input == "new":
os.system("clear")
pil_images = []
conv = vl_chat_processor.new_chat_template()
torch.cuda.empty_cache()
print("New conversation started.")
else:
conv.append_message(conv.roles[0], user_input)
conv.append_message(conv.roles[1], None)
# check if the user input is an image token
num_images = user_input.count(image_token)
cur_img_idx = 0
while cur_img_idx < num_images:
try:
image_file = input(f"({cur_img_idx + 1}/{num_images}) Input the image file path: ")
except KeyboardInterrupt:
print()
continue
except EOFError:
image_file = None
if image_file and os.path.exists(image_file):
pil_image = load_image(image_file)
pil_images.append(pil_image)
cur_img_idx += 1
elif image_file == "exit":
print("Chat program exited.")
sys.exit(0)
else:
print(f"File error, `{image_file}` does not exist. Please input the correct file path.")
# get the answer by the model's prediction
answer = ""
answer_iter = response(args, conv, pil_images, tokenizer, vl_chat_processor, vl_gpt, generation_config)
sys.stdout.write(f"{conv.roles[1]}: ")
for char in answer_iter:
answer += char
sys.stdout.write(char)
sys.stdout.flush()
sys.stdout.write("\n")
sys.stdout.flush()
conv.messages[-1][-1] = answer
def main(args):
# setup
tokenizer, vl_chat_processor, vl_gpt = load_pretrained_model(args.model_path)
generation_config = dict(
pad_token_id=vl_chat_processor.tokenizer.eos_token_id,
bos_token_id=vl_chat_processor.tokenizer.bos_token_id,
eos_token_id=vl_chat_processor.tokenizer.eos_token_id,
max_new_tokens=args.max_gen_len,
use_cache=True,
)
if args.temperature > 0:
generation_config.update({
"do_sample": True,
"top_p": args.top_p,
"temperature": args.temperature,
"repetition_penalty": args.repetition_penalty,
})
else:
generation_config.update({"do_sample": False})
chat(args, tokenizer, vl_chat_processor, vl_gpt, generation_config)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, default="deepseek-ai/deepseek-vl-7b-chat",
help="the huggingface model name or the local path of the downloaded huggingface model.")
parser.add_argument("--temperature", type=float, default=0.2)
parser.add_argument("--top_p", type=float, default=0.95)
parser.add_argument("--repetition_penalty", type=float, default=1.1)
parser.add_argument("--max_gen_len", type=int, default=512)
args = parser.parse_args()
main(args)
"""
CUDA_VISIBLE_DEVICES=4 python cli.py --model_path "/home/liuwen/3fs_shared/ckpts/deepseek-vl-7b-chat"
CUDA_VISIBLE_DEVICES=2 python cli.py --model_path "/home/liuwen/3fs_shared/ckpts/siglip_1B_HF"
"""

0
deepseek_vlm/__init__.py Normal file
View File

View File

@ -0,0 +1,5 @@
from .image_processing_vlm import VLMImageProcessor
from .processing_vlm import VLChatProcessor
from .modeling_vlm import MultiModalityCausalLM

View File

@ -0,0 +1,213 @@
from typing import Tuple, Union, List, Dict, Optional, Literal
import torch
import torch.nn as nn
import torchvision.transforms
from einops import rearrange
from deepseek_vlm.models.siglip_vit import create_siglip_vit
from deepseek_vlm.models.sam import create_sam_vit
class CLIPVisionTower(nn.Module):
def __init__(self,
model_name: str = "siglip_large_patch16_384",
image_size: Union[Tuple[int, int], int] = 336,
select_feature: str = "patch",
select_layer: int = -2,
select_layers: list = None,
ckpt_path: str = "",
pixel_mean: Optional[List[float]] = None,
pixel_std: Optional[List[float]] = None,
**kwargs):
super().__init__()
self.model_name = model_name
self.select_feature = select_feature
self.select_layer = select_layer
self.select_layers = select_layers
vision_tower_params = {
"model_name": model_name,
"image_size": image_size,
"ckpt_path": ckpt_path,
"select_layer": select_layer
}
vision_tower_params.update(kwargs)
self.vision_tower, self.forward_kwargs = self.build_vision_tower(vision_tower_params)
if pixel_mean is not None and pixel_std is not None:
image_norm = torchvision.transforms.Normalize(mean=pixel_mean, std=pixel_std)
else:
image_norm = None
self.image_norm = image_norm
def build_vision_tower(self, vision_tower_params):
if self.model_name.startswith("siglip"):
self.select_feature = "same"
vision_tower = create_siglip_vit(**vision_tower_params)
forward_kwargs = dict()
elif self.model_name.startswith("sam"):
vision_tower = create_sam_vit(**vision_tower_params)
forward_kwargs = dict()
else: # huggingface
from transformers import CLIPVisionModel
vision_tower = CLIPVisionModel.from_pretrained(**vision_tower_params)
forward_kwargs = dict(output_hidden_states=True)
return vision_tower, forward_kwargs
def feature_select(self, image_forward_outs):
if isinstance(image_forward_outs, torch.Tensor):
# the output has been the self.select_layer"s features
image_features = image_forward_outs
else:
image_features = image_forward_outs.hidden_states[self.select_layer]
if self.select_feature == "patch":
# if the output has cls_token
image_features = image_features[:, 1:]
elif self.select_feature == "cls_patch":
image_features = image_features
elif self.select_feature == "same":
image_features = image_features
else:
raise ValueError(f"Unexpected select feature: {self.select_feature}")
return image_features
def forward(self, images):
"""
Args:
images (torch.Tensor): [b, 3, H, W]
Returns:
image_features (torch.Tensor): [b, n_patch, d]
"""
if self.image_norm is not None:
images = self.image_norm(images)
image_forward_outs = self.vision_tower(images, **self.forward_kwargs)
image_features = self.feature_select(image_forward_outs)
return image_features
class HybridVisionTower(nn.Module):
def __init__(self,
high_res_cfg: Dict,
low_res_cfg: Dict,
freeze_high: bool = False,
freeze_low: bool = False,
concat_type: Literal["feature", "sequence", "add", "tuple"] = "tuple",
**ignore_kwargs):
super().__init__()
self.vision_tower_high = CLIPVisionTower(**high_res_cfg)
self.vision_tower_low = CLIPVisionTower(**low_res_cfg)
self.low_res_size = low_res_cfg["image_size"]
self.concat_type = concat_type
self.high_layer_norm = nn.LayerNorm(high_res_cfg.get("output_dim", 1024))
self.low_layer_norm = nn.LayerNorm(low_res_cfg.get("output_dim", 1024))
if freeze_high:
for p_name, p in self.vision_tower_high.named_parameters():
p.requires_grad = False
self.vision_tower_high = self.vision_tower_high.eval()
else:
# train donwsamples and neck
for p_name, p in self.vision_tower_high.named_parameters():
if "downsamples" in p_name or "neck" in p_name:
p.requires_grad = True
else:
p.requires_grad = False
if freeze_low:
for p in self.vision_tower_low.parameters():
p.requires_grad = False
self.vision_tower_low = self.vision_tower_low.eval()
self.resize = torchvision.transforms.Resize(self.low_res_size, antialias=True)
def forward(self, images: torch.Tensor):
"""
Args:
images (torch.Tensor): [bs, 3, H, W]
Returns:
res (torch.Tensor): [bs, t, c]
"""
# [bs, c, h, w]
high_images = images
# [bs, c, h_low, w_low]
low_images = self.resize(images)
# separately run two vision towers
# run high_res vision tower
high_res = self.vision_tower_high(high_images)
# [bs, c, h, w] -> [bs, h*w, c]
high_res = rearrange(high_res, "b c h w -> b (h w) c")
# run low_res vision tower
low_res = self.vision_tower_low(low_images)
if self.concat_type == "feature":
images_features = torch.cat([high_res, low_res], dim=-1)
elif self.concat_type == "sequence":
images_features = torch.cat([high_res, low_res], dim=1)
elif self.concat_type == "add":
images_features = high_res + low_res
elif self.concat_type == "tuple":
images_features = (high_res, low_res)
else:
raise ValueError(f"Currently only support `feature`, `sequence`, `add` and `tuple` concat type.")
return images_features
if __name__ == "__main__":
image_size = 1024
x = torch.zeros(2, 3, image_size, image_size).bfloat16().cuda()
high_res_cfg = dict(
model_name="sam_b_downsample",
select_feature="same",
image_size=image_size,
pixel_mean=(0.48145466, 0.4578275, 0.40821073),
pixel_std=(0.26862954, 0.26130258, 0.27577711),
select_layer=-1,
ckpt_path=""
)
low_res_cfg = dict(
model_name="siglip_large_patch16_384",
select_feature="same",
image_size=384,
pixel_mean=(0.5, 0.5, 0.5),
pixel_std=(0.5, 0.5, 0.5),
select_layer=-1,
ckpt_path=""
)
net = HybridVisionTower(
high_res_cfg=high_res_cfg,
low_res_cfg=low_res_cfg,
freeze_high=True,
freeze_low=True,
concat_type="tuple"
).bfloat16().cuda()
high_x, low_x = net(x)
print(x.shape, high_x.shape, low_x.shape)

View File

@ -0,0 +1,163 @@
from PIL import Image
import numpy as np
import torch
import torchvision
import torchvision.transforms.functional
from typing import List, Union, Tuple
from transformers import PretrainedConfig, AutoImageProcessor
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_utils import to_numpy_array
from transformers.utils import logging
logger = logging.get_logger(__name__)
ImageType = Union[np.ndarray, torch.Tensor, Image.Image]
IMAGENET_MEAN = (0.48145466, 0.4578275, 0.40821073)
IMAGENET_STD = (0.26862954, 0.26130258, 0.27577711)
IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
class VLMImageProcessorConfig(PretrainedConfig):
model_type = "deepseek_vlm"
image_size: int
min_size: int
image_mean: Union[Tuple[float, float, float], List[float]]
image_std: Union[Tuple[float, float, float], List[float]]
rescale_factor: float
do_normalize: bool
def __init__(
self,
image_size: int,
min_size: int = 14,
image_mean: Union[Tuple[float, float, float], List[float]] = (0.48145466, 0.4578275, 0.40821073),
image_std: Union[Tuple[float, float, float], List[float]] = (0.26862954, 0.26130258, 0.27577711),
rescale_factor: float = 1.0 / 255.0,
do_normalize: bool = True, **kwargs
):
self.image_size = image_size
self.min_size = min_size
self.image_mean = image_mean
self.image_std = image_std
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
super().__init__(**kwargs)
class VLMImageProcessor(BaseImageProcessor):
model_input_names = ["pixel_values"]
def __init__(
self,
image_size: int,
min_size: int = 14,
image_mean: Union[Tuple[float, float, float], List[float]] = (0.48145466, 0.4578275, 0.40821073),
image_std: Union[Tuple[float, float, float], List[float]] = (0.26862954, 0.26130258, 0.27577711),
rescale_factor: float = 1.0 / 255.0,
do_normalize: bool = True, **kwargs
):
super().__init__(**kwargs)
self.image_size = image_size
self.rescale_factor = rescale_factor
self.image_mean = image_mean
self.image_std = image_std
self.min_size = min_size
self.do_normalize = do_normalize
if image_mean is None:
self.background_color = (127, 127, 127)
else:
self.background_color = tuple([int(x * 255) for x in image_mean])
def resize(self, pil_img: Image) -> np.ndarray:
"""
Args:
pil_img (PIL.Image): [H, W, 3] in PIL.Image in RGB
Returns:
x (np.ndarray): [3, self.image_size, self.image_size]
"""
width, height = pil_img.size
max_size = max(width, height)
size = [
max(int(height / max_size * self.image_size), self.min_size),
max(int(width / max_size * self.image_size), self.min_size)
]
if width <= 0 or height <= 0 or size[0] <= 0 or size[1] <= 0:
print(f"orig size = {pil_img.size}, new size = {size}")
raise ValueError("Invalid size!")
pil_img = torchvision.transforms.functional.resize(
pil_img, size, interpolation=torchvision.transforms.functional.InterpolationMode.BICUBIC, antialias=True
)
pil_img = expand2square(pil_img, self.background_color)
x = to_numpy_array(pil_img)
# [H, W, 3] -> [3, H, W]
x = np.transpose(x, (2, 0, 1))
return x
def preprocess(self, images, return_tensors: str = "pt", **kwargs) -> BatchFeature:
# resize and pad to [self.image_size, self.image_size]
# then convert from [H, W, 3] to [3, H, W]
images: List[np.ndarray] = [self.resize(image) for image in images]
# resacle from [0, 255] -> [0, 1]
images = [
self.rescale(image=image, scale=self.rescale_factor, input_data_format="channels_first")
for image in images
]
# normalize
if self.do_normalize:
images = [
self.normalize(image=image, mean=self.image_mean, std=self.image_std,
input_data_format="channels_first")
for image in images
]
data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors)
@property
def default_shape(self):
return [3, self.image_size, self.image_size]
# AutoConfig.register("deepseek_vlm", VLMImageProcessorConfig)
AutoImageProcessor.register(VLMImageProcessorConfig, VLMImageProcessor)
if __name__ == "__main__":
image_processor = VLMImageProcessor(
image_size=1024,
image_mean=IMAGENET_INCEPTION_MEAN,
image_std=IMAGENET_INCEPTION_STD,
do_normalize=True
)

View File

@ -0,0 +1,150 @@
from attrdict import AttrDict
from einops import rearrange
import torch
from transformers.configuration_utils import PretrainedConfig
from transformers import (
AutoConfig,
AutoModelForCausalLM,
PreTrainedModel,
LlamaConfig,
LlamaForCausalLM
)
from deepseek_vlm.models.projector import MlpProjector
from deepseek_vlm.models.clip_encoder import CLIPVisionTower, HybridVisionTower
def model_name_to_cls(cls_name):
if "MlpProjector" in cls_name:
cls = MlpProjector
elif "CLIPVisionTower" in cls_name:
cls = CLIPVisionTower
elif "HybridVisionTower" in cls_name:
cls = HybridVisionTower
else:
raise ValueError(f"class_name {cls_name} is invalid.")
return cls
class VisionConfig(PretrainedConfig):
model_type = "vision"
cls: str = ""
params: AttrDict = {}
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.cls = kwargs.get("cls", "")
if not isinstance(self.cls, str):
self.cls = self.cls.__name__
self.params = AttrDict(kwargs.get("params", {}))
class AlignerConfig(PretrainedConfig):
model_type = "aligner"
cls: str = ""
params: AttrDict = {}
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.cls = kwargs.get("cls", "")
if not isinstance(self.cls, str):
self.cls = self.cls.__name__
self.params = AttrDict(kwargs.get("params", {}))
class MultiModalityConfig(PretrainedConfig):
model_type = "multi_modality"
vision_config: VisionConfig
aligner_config: AlignerConfig
language_config: LlamaConfig
def __init__(self, **kwargs):
super().__init__(**kwargs)
vision_config = kwargs.get("vision_config", {})
self.vision_config = VisionConfig(**vision_config)
aligner_config = kwargs.get("aligner_config", {})
self.aligner_config = AlignerConfig(**aligner_config)
language_config = kwargs.get("language_config", {})
if isinstance(language_config, LlamaConfig):
self.language_config = language_config
else:
self.language_config = LlamaConfig(**language_config)
class MultiModalityPreTrainedModel(PreTrainedModel):
config_class = MultiModalityConfig
base_model_prefix = "multi_modality"
_no_split_modules = []
_skip_keys_device_placement = "past_key_values"
class MultiModalityCausalLM(MultiModalityPreTrainedModel):
def __init__(self, config: MultiModalityConfig):
super().__init__(config)
vision_config = config.vision_config
vision_cls = model_name_to_cls(vision_config.cls)
self.vision_model = vision_cls(**vision_config.params)
aligner_config = config.aligner_config
aligner_cls = model_name_to_cls(aligner_config.cls)
self.aligner = aligner_cls(aligner_config.params)
language_config = config.language_config
self.language_model = LlamaForCausalLM(language_config)
def prepare_inputs_embeds(self,
input_ids: torch.LongTensor,
pixel_values: torch.FloatTensor,
images_seq_mask: torch.LongTensor,
images_emb_mask: torch.LongTensor, **kwargs):
"""
Args:
input_ids (torch.LongTensor): [b, T]
pixel_values (torch.FloatTensor): [b, n_images, 3, h, w]
images_seq_mask (torch.BoolTensor): [b, T]
images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens]
assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask)
Returns:
input_embeds (torch.Tensor): [b, T, D]
"""
bs, n = pixel_values.shape[0:2]
images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
# [b x n, T2, D]
images_embeds = self.aligner(self.vision_model(images))
# [b x n, T2, D] -> [b, n x T2, D]
images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
# [b, n, T2] -> [b, n x T2]
images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")
# [b, T, D]
input_ids[input_ids < 0] = 0 # ignore the image embeddings
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
# replace with the image embeddings
inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]
return inputs_embeds
AutoConfig.register("vision", VisionConfig)
AutoConfig.register("aligner", AlignerConfig)
AutoConfig.register("multi_modality", MultiModalityConfig)
AutoModelForCausalLM.register(MultiModalityConfig, MultiModalityCausalLM)

View File

@ -0,0 +1,351 @@
from dataclasses import dataclass
import numpy as np
from PIL.Image import Image
from typing import List, Dict, Union
import torch
from transformers import AutoTokenizer, AutoImageProcessor
from transformers.processing_utils import ProcessorMixin
from transformers import LlamaTokenizerFast
from deepseek_vlm.models.image_processing_vlm import VLMImageProcessor
from deepseek_vlm.utils.conversation import get_conv_template
class DictOutput(object):
def keys(self):
return self.__dict__.keys()
def __getitem__(self, item):
return self.__dict__[item]
def __setitem__(self, key, value):
self.__dict__[key] = value
@dataclass
class VLChatProcessorOutput(DictOutput):
sft_format: str
input_ids: torch.Tensor
pixel_values: torch.Tensor
num_image_tokens: torch.IntTensor
def __len__(self):
return len(self.input_ids)
@dataclass
class BatchedVLChatProcessorOutput(DictOutput):
sft_format: List[str]
input_ids: torch.Tensor
pixel_values: torch.Tensor
attention_mask: torch.Tensor
images_seq_mask: torch.BoolTensor
images_emb_mask: torch.BoolTensor
def to(self, device, dtype=torch.bfloat16):
self.input_ids = self.input_ids.to(device)
self.attention_mask = self.attention_mask.to(device)
self.images_seq_mask = self.images_seq_mask.to(device)
self.images_emb_mask = self.images_emb_mask.to(device)
self.pixel_values = self.pixel_values.to(device=device, dtype=dtype)
return self
class VLChatProcessor(ProcessorMixin):
image_processor_class = "AutoImageProcessor"
tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
attributes = ["image_processor", "tokenizer"]
system_prompt = ("You are a helpful language and vision assistant. "
"You are able to understand the visual content that the user provides, "
"and assist the user with a variety of tasks using natural language.")
def __init__(
self,
image_processor: VLMImageProcessor,
tokenizer: LlamaTokenizerFast,
image_tag: str = "<image_placeholder>",
num_image_tokens: int = 576,
add_special_token: bool = False,
sft_format: str = "deepseek",
mask_prompt: bool = True,
ignore_id: int = -100, **kwargs
):
self.image_processor = image_processor
self.tokenizer = tokenizer
image_id = self.tokenizer.vocab.get(image_tag)
if image_id is None:
special_tokens = [image_tag]
special_tokens_dict = {"additional_special_tokens": special_tokens}
self.tokenizer.add_special_tokens(special_tokens_dict)
print(f"Add image tag = {image_tag} to the tokenizer")
self.image_tag = image_tag
self.num_image_tokens = num_image_tokens
self.add_special_token = add_special_token
self.sft_format = sft_format
self.mask_prompt = mask_prompt
self.ignore_id = ignore_id
super().__init__(image_processor, tokenizer, image_tag, num_image_tokens, add_special_token,
sft_format, mask_prompt, ignore_id, **kwargs)
def new_chat_template(self):
conv = get_conv_template(self.sft_format)
conv.set_system_message(self.system_prompt)
return conv
def apply_sft_template_for_multi_turn_prompts(
self,
conversations: List[Dict[str, str]],
sft_format: str = "deepseek",
system_prompt: str = ""
):
"""
Applies the SFT template to conversation.
An example of conversation:
conversation = [
{
"role": "User",
"content": "<image_placeholder> is Figure 1.\n<image_placeholder> is Figure 2.\nWhich image is brighter?",
"images": [
"./multi-images/attribute_comparison_1.png",
"./multi-images/attribute_comparison_2.png"
]
},
{
"role": "Assistant",
"content": ""
}
]
Args:
conversations (List[Dict]): A conversation with a List of Dict[str, str] text.
sft_format (str, optional): The format of the SFT template to use. Defaults to "deepseek".
system_prompt (str, optional): The system prompt to use in the SFT template. Defaults to "".
Returns:
sft_prompt (str): The formatted text.
"""
conv = get_conv_template(sft_format)
conv.set_system_message(system_prompt)
for message in conversations:
conv.append_message(message["role"], message["content"].strip())
sft_prompt = conv.get_prompt().strip()
return sft_prompt
@property
def image_token(self):
return self.image_tag
@property
def image_id(self):
image_id = self.tokenizer.vocab.get(self.image_tag)
return image_id
@property
def pad_id(self):
pad_id = self.tokenizer.pad_token_id
if pad_id is None:
pad_id = self.tokenizer.eos_token_id
return pad_id
def add_image_token(
self,
image_indices: List[int],
input_ids: torch.LongTensor,
):
"""
Args:
image_indices (List[int]): [index_0, index_1, ..., index_j]
input_ids (torch.LongTensor): [N]
Returns:
input_ids (torch.LongTensor): [N + image tokens]
num_image_tokens (torch.IntTensor): [n_images]
"""
input_slices = []
start = 0
for index in image_indices:
if self.add_special_token:
end = index + 1
else:
end = index
# original text tokens
input_slices.append(input_ids[start: end])
# add image tokens, and set the mask as False
input_slices.append(self.image_id * torch.ones((self.num_image_tokens,), dtype=torch.long))
start = index + 1
# the left part
input_slices.append(input_ids[start:])
# concat all slices
input_ids = torch.cat(input_slices, dim=0)
num_image_tokens = torch.IntTensor([self.num_image_tokens] * len(image_indices))
return input_ids, num_image_tokens
def process_one(
self,
prompt: str = None,
conversations: List[Dict[str, str]] = None,
images: List[Image] = None,
**kwargs
):
"""
Args:
prompt (str): the formatted prompt;
conversations (List[Dict]): conversations with a list of messages;
images (List[ImageType]): the list of images;
**kwargs:
Returns:
outputs (BaseProcessorOutput): the output of the processor,
- input_ids (torch.LongTensor): [N + image tokens]
- target_ids (torch.LongTensor): [N + image tokens]
- images (torch.FloatTensor): [n_images, 3, H, W]
- image_id (int): the id of the image token
- num_image_tokens (List[int]): the number of image tokens
"""
assert prompt is None or conversations is None, "prompt and conversations cannot be used at the same time."
if prompt is None:
# apply sft format
sft_format = self.apply_sft_template_for_multi_turn_prompts(
conversations=conversations,
sft_format=self.sft_format,
system_prompt=self.system_prompt
)
else:
sft_format = prompt
# tokenize
input_ids = self.tokenizer.encode(sft_format)
input_ids = torch.LongTensor(input_ids)
# add image tokens to the input_ids
image_token_mask: torch.BoolTensor = input_ids == self.image_id
image_indices = image_token_mask.nonzero()
input_ids, num_image_tokens = self.add_image_token(
image_indices=image_indices,
input_ids=input_ids,
)
# load images
images_outputs = self.image_processor(images, return_tensors="pt")
prepare = VLChatProcessorOutput(
sft_format=sft_format,
input_ids=input_ids,
pixel_values=images_outputs.pixel_values,
num_image_tokens=num_image_tokens
)
return prepare
def __call__(
self,
*,
prompt: str = None,
conversations: List[Dict[str, str]] = None,
images: List[Image] = None,
force_batchify: bool = True,
**kwargs
):
"""
Args:
prompt (str): the formatted prompt;
conversations (List[Dict]): conversations with a list of messages;
images (List[ImageType]): the list of images;
force_batchify (bool): force batchify the inputs;
**kwargs:
Returns:
outputs (BaseProcessorOutput): the output of the processor,
- input_ids (torch.LongTensor): [N + image tokens]
- images (torch.FloatTensor): [n_images, 3, H, W]
- image_id (int): the id of the image token
- num_image_tokens (List[int]): the number of image tokens
"""
prepare = self.process_one(prompt=prompt, conversations=conversations, images=images)
if force_batchify:
prepare = self.batchify([prepare])
return prepare
def batchify(self, prepare_list: List[VLChatProcessorOutput]) -> BatchedVLChatProcessorOutput:
"""
Preprocesses the inputs for multimodal inference.
Args:
prepare_list (List[VLChatProcessorOutput]): A list of VLChatProcessorOutput.
Returns:
BatchedVLChatProcessorOutput: A dictionary of the inputs to use for multimodal inference.
"""
batch_size = len(prepare_list)
sft_format = []
n_images = []
seq_lens = []
for prepare in prepare_list:
n_images.append(len(prepare.num_image_tokens))
seq_lens.append(len(prepare))
input_token_max_len = max(seq_lens)
max_n_images = max(1, max(n_images))
batched_input_ids = torch.full((batch_size, input_token_max_len), self.pad_id).long() # FIXME
batched_attention_mask = torch.zeros((batch_size, input_token_max_len)).long()
batched_pixel_values = torch.zeros((batch_size, max_n_images, *self.image_processor.default_shape)).float()
batched_images_seq_mask = torch.zeros((batch_size, input_token_max_len)).bool()
batched_images_emb_mask = torch.zeros((batch_size, max_n_images, self.num_image_tokens)).bool()
for i, prepare in enumerate(prepare_list):
input_ids = prepare.input_ids
seq_len = len(prepare)
n_image = len(prepare.num_image_tokens)
# left-padding
batched_attention_mask[i, -seq_len:] = 1
batched_input_ids[i, -seq_len:] = torch.LongTensor(input_ids)
batched_images_seq_mask[i, -seq_len:] = input_ids == self.image_id
if n_image > 0:
batched_pixel_values[i, :n_image] = prepare.pixel_values
for j, n_image_tokens in enumerate(prepare.num_image_tokens):
batched_images_emb_mask[i, j, :n_image_tokens] = True
sft_format.append(prepare.sft_format)
batched_prepares = BatchedVLChatProcessorOutput(
input_ids=batched_input_ids,
attention_mask=batched_attention_mask,
pixel_values=batched_pixel_values,
images_seq_mask=batched_images_seq_mask,
images_emb_mask=batched_images_emb_mask,
sft_format=sft_format
)
return batched_prepares

View File

@ -0,0 +1,80 @@
from attrdict import AttrDict
import torch
import torch.nn as nn
from typing import Union, Tuple
class MlpProjector(nn.Module):
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
if cfg.projector_type == "identity":
modules = nn.Identity()
elif cfg.projector_type == "linear":
modules = nn.Linear(cfg.input_dim, cfg.n_embed)
elif cfg.projector_type == "mlp_gelu":
mlp_depth = cfg.get("depth", 1)
modules = [nn.Linear(cfg.input_dim, cfg.n_embed)]
for _ in range(1, mlp_depth):
modules.append(nn.GELU())
modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
modules = nn.Sequential(*modules)
elif cfg.projector_type == "low_high_hybrid_split_mlp_gelu":
mlp_depth = cfg.get("depth", 1)
self.high_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
self.low_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
modules = []
for _ in range(1, mlp_depth):
modules.append(nn.GELU())
modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
modules = nn.Sequential(*modules)
else:
raise ValueError(f"Unknown projector type: {cfg.projector_type}")
self.layers = modules
def forward(self, x_or_tuple: Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]):
"""
Args:
x_or_tuple (Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: if it is a tuple of torch.Tensor,
then it comes from the hybrid vision encoder, and x = high_res_x, low_res_x);
otherwise it is the feature from the single vision encoder.
Returns:
x (torch.Tensor): [b, s, c]
"""
if isinstance(x_or_tuple, tuple):
# self.cfg.projector_type == "low_high_hybrid_split_mlp_gelu":
high_x, low_x = x_or_tuple
high_x = self.high_up_proj(high_x)
low_x = self.low_up_proj(low_x)
x = torch.concat([high_x, low_x], dim=-1)
else:
x = x_or_tuple
return self.layers(x)
if __name__ == "__main__":
cfg = AttrDict(
input_dim=1024,
n_embed=2048,
depth=2,
projector_type="low_high_hybrid_split_mlp_gelu"
)
inputs = (torch.rand(4, 576, 1024), torch.rand(4, 576, 1024))
m = MlpProjector(cfg)
out = m(inputs)
print(out.shape)

562
deepseek_vlm/models/sam.py Normal file
View File

@ -0,0 +1,562 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import copy
from dataclasses import dataclass
from functools import partial
from typing import List, Optional, Tuple, Type, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
class MLPBlock(nn.Module):
def __init__(
self,
embedding_dim: int,
mlp_dim: int,
act: Type[nn.Module] = nn.GELU,
) -> None:
super().__init__()
self.lin1 = nn.Linear(embedding_dim, mlp_dim)
self.lin2 = nn.Linear(mlp_dim, embedding_dim)
self.act = act()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.lin2(self.act(self.lin1(x)))
# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
class LayerNorm2d(nn.Module):
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(num_channels))
self.bias = nn.Parameter(torch.zeros(num_channels))
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
class ImageEncoderViT(nn.Module):
def __init__(
self,
img_size: int = 1024,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: float = 4.0,
out_chans: int = 256,
qkv_bias: bool = True,
norm_layer: Type[nn.Module] = nn.LayerNorm,
act_layer: Type[nn.Module] = nn.GELU,
use_abs_pos: bool = True,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
window_size: int = 0,
global_attn_indexes: Tuple[int, ...] = (),
downsample_channels: Tuple[int, ...] = (512, 1024),
) -> None:
"""
Args:
img_size (int): Input image size.
patch_size (int): Patch size.
in_chans (int): Number of input image channels.
embed_dim (int): Patch embedding dimension.
depth (int): Depth of ViT.
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_abs_pos (bool): If True, use absolute positional embeddings.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks.
global_attn_indexes (list): Indexes for blocks using global attention.
downsample_channels (list): Channels for downsampling layers.
"""
super().__init__()
self.img_size = img_size
self.patch_embed = PatchEmbed(
kernel_size=(patch_size, patch_size),
stride=(patch_size, patch_size),
in_chans=in_chans,
embed_dim=embed_dim,
)
self.pos_embed: Optional[nn.Parameter] = None
if use_abs_pos:
# Initialize absolute positional embedding with pretrain image size.
self.pos_embed = nn.Parameter(
torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
)
self.blocks = nn.ModuleList()
for i in range(depth):
block = Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
norm_layer=norm_layer,
act_layer=act_layer,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
window_size=window_size if i not in global_attn_indexes else 0,
input_size=(img_size // patch_size, img_size // patch_size),
)
self.blocks.append(block)
self.neck = nn.Sequential(
nn.Conv2d(
embed_dim,
out_chans,
kernel_size=1,
bias=False,
),
LayerNorm2d(out_chans),
nn.Conv2d(
out_chans,
out_chans,
kernel_size=3,
padding=1,
bias=False,
),
LayerNorm2d(out_chans),
)
in_channels = out_chans
downsamples = []
for i in range(len(downsample_channels)):
out_channels = downsample_channels[i]
downsamples.append(
nn.Conv2d(
in_channels,
out_channels,
kernel_size=3,
stride=2,
padding=1,
bias=False,
)
)
in_channels = out_channels
self.downsamples = nn.Sequential(*downsamples)
self.sam_hd = True
if self.sam_hd:
self.hd_alpha_downsamples = nn.Parameter(torch.zeros(1))
# self.neck_hd = nn.Linear(embed_dim, embed_dim)
self.neck_hd = copy.deepcopy(self.neck)
# self.downsamples_hd = copy.deepcopy(self.downsamples)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.patch_embed(x)
if self.pos_embed is not None:
x = x + self.pos_embed
global_features = []
for i, blk in enumerate(self.blocks):
x = blk(x)
if self.sam_hd and blk.window_size == 0:
global_features.append(x)
x = self.neck(x.permute(0, 3, 1, 2))
x_dtype = x.dtype
x = F.interpolate(x.float(), size=(96, 96), mode='bilinear', align_corners=False).to(x_dtype)
x = self.downsamples(x)
if self.sam_hd:
first_global_feature = self.neck_hd(global_features[0].permute(0, 3, 1, 2))
x_dtype = first_global_feature.dtype
first_global_feature = F.interpolate(first_global_feature.float(), size=(96, 96), mode='bilinear', align_corners=False)
first_global_feature = self.downsamples(first_global_feature.to(x_dtype))
x = x + first_global_feature * self.hd_alpha_downsamples
return x
class Block(nn.Module):
"""Transformer blocks with support of window attention and residual propagation blocks"""
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
norm_layer: Type[nn.Module] = nn.LayerNorm,
act_layer: Type[nn.Module] = nn.GELU,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
window_size: int = 0,
input_size: Optional[Tuple[int, int]] = None,
) -> None:
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads in each ViT block.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
norm_layer (nn.Module): Normalization layer.
act_layer (nn.Module): Activation layer.
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks. If it equals 0, then
use global attention.
input_size (tuple(int, int) or None): Input resolution for calculating the relative
positional parameter size.
"""
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
input_size=input_size if window_size == 0 else (window_size, window_size),
)
self.norm2 = norm_layer(dim)
self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
self.window_size = window_size
def forward(self, x: torch.Tensor) -> torch.Tensor:
shortcut = x
x = self.norm1(x)
# Window partition
if self.window_size > 0:
H, W = x.shape[1], x.shape[2]
x, pad_hw = window_partition(x, self.window_size)
x = self.attn(x)
# Reverse window partition
if self.window_size > 0:
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
x = shortcut + x
x = x + self.mlp(self.norm2(x))
return x
class Attention(nn.Module):
"""Multi-head Attention block with relative position embeddings."""
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = True,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
input_size: Optional[Tuple[int, int]] = None,
) -> None:
"""
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
input_size (tuple(int, int) or None): Input resolution for calculating the relative
positional parameter size.
"""
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
self.use_rel_pos = use_rel_pos
if self.use_rel_pos:
assert (
input_size is not None
), "Input size must be provided if using relative positional encoding."
# initialize relative positional embeddings
self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, H, W, _ = x.shape
# qkv with shape (3, B, nHead, H * W, C)
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
# q, k, v with shape (B * nHead, H * W, C)
q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
def do_attention(q, k, v):
attn = (q * self.scale) @ k.transpose(-2, -1)
if self.use_rel_pos:
attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
attn = attn.softmax(dim=-1)
x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
return x
# from haiscale.utils import on_demand_checkpoint
# x = on_demand_checkpoint(do_attention, q, k, v)
x = do_attention(q, k, v)
x = self.proj(x)
return x
def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
"""
Partition into non-overlapping windows with padding if needed.
Args:
x (tensor): input tokens with [B, H, W, C].
window_size (int): window size.
Returns:
windows: windows after partition with [B * num_windows, window_size, window_size, C].
(Hp, Wp): padded height and width before partition
"""
B, H, W, C = x.shape
pad_h = (window_size - H % window_size) % window_size
pad_w = (window_size - W % window_size) % window_size
if pad_h > 0 or pad_w > 0:
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
Hp, Wp = H + pad_h, W + pad_w
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows, (Hp, Wp)
def window_unpartition(
windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> torch.Tensor:
"""
Window unpartition into original sequences and removing padding.
Args:
windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
window_size (int): window size.
pad_hw (Tuple): padded height and width (Hp, Wp).
hw (Tuple): original height and width (H, W) before padding.
Returns:
x: unpartitioned sequences with [B, H, W, C].
"""
Hp, Wp = pad_hw
H, W = hw
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
if Hp > H or Wp > W:
x = x[:, :H, :W, :].contiguous()
return x
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
"""
Get relative positional embeddings according to the relative positions of
query and key sizes.
Args:
q_size (int): size of query q.
k_size (int): size of key k.
rel_pos (Tensor): relative position embeddings (L, C).
Returns:
Extracted positional embeddings according to relative positions.
"""
max_rel_dist = int(2 * max(q_size, k_size) - 1)
# Interpolate rel pos if needed.
if rel_pos.shape[0] != max_rel_dist:
# Interpolate rel pos.
rel_pos_resized = F.interpolate(
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
size=max_rel_dist,
mode="linear",
)
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
else:
rel_pos_resized = rel_pos
# Scale the coords with short length if shapes for q and k are different.
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
return rel_pos_resized[relative_coords.long()]
def add_decomposed_rel_pos(
attn: torch.Tensor,
q: torch.Tensor,
rel_pos_h: torch.Tensor,
rel_pos_w: torch.Tensor,
q_size: Tuple[int, int],
k_size: Tuple[int, int],
) -> torch.Tensor:
"""
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
Args:
attn (Tensor): attention map.
q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
Returns:
attn (Tensor): attention map with added relative positional embeddings.
"""
q_h, q_w = q_size
k_h, k_w = k_size
Rh = get_rel_pos(q_h, k_h, rel_pos_h)
Rw = get_rel_pos(q_w, k_w, rel_pos_w)
B, _, dim = q.shape
r_q = q.reshape(B, q_h, q_w, dim)
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
attn = (
attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
).view(B, q_h * q_w, k_h * k_w)
return attn
class PatchEmbed(nn.Module):
"""
Image to Patch Embedding.
"""
def __init__(
self,
kernel_size: Tuple[int, int] = (16, 16),
stride: Tuple[int, int] = (16, 16),
padding: Tuple[int, int] = (0, 0),
in_chans: int = 3,
embed_dim: int = 768,
) -> None:
"""
Args:
kernel_size (Tuple): kernel size of the projection layer.
stride (Tuple): stride of the projection layer.
padding (Tuple): padding size of the projection layer.
in_chans (int): Number of input image channels.
embed_dim (int): Patch embedding dimension.
"""
super().__init__()
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
# B C H W -> B H W C
x = x.permute(0, 2, 3, 1)
return x
@dataclass
class SAMViTCfg:
image_size: Union[Tuple[int, int], int] = 1024
width: int = 1024
layers: int = 23
heads: int = 16
patch_size: int = 16
window_size: int = 14
prompt_embed_dim: int = 256
global_attn_indexes: Union[List[int], Tuple[int]] = (5, 11, 17, 23)
downsample_channels: Union[List[int], Tuple[int]] = (512, 1024)
SAM_MODEL_CONFIG = {
"sam_vit_b": {
"width": 768,
"layers": 12,
"heads": 12,
"global_attn_indexes": [2, 5, 8, 11],
"downsample_channels": ()
},
"sam_b_downsample": {
"width": 768,
"layers": 12,
"heads": 12,
"global_attn_indexes": [2, 5, 8, 11],
"downsample_channels": (512, 1024)
},
"sam_vit_l": {
"width": 1024,
"layers": 24,
"heads": 16,
"global_attn_indexes": [5, 11, 17, 23],
"downsample_channels": ()
},
"sam_vit_h": {
"width": 1280,
"layers": 32,
"heads": 16,
"global_attn_indexes": [7, 15, 23, 31],
"downsample_channels": ()
},
}
def create_sam_vit(
model_name: str = "sam_b_downsample",
image_size: int = 1024,
ckpt_path: str = "",
**kwargs
):
assert model_name in SAM_MODEL_CONFIG.keys(), f"model name: {model_name} should be in {SAM_MODEL_CONFIG.keys()}"
sam_cfg = SAMViTCfg(**SAM_MODEL_CONFIG[model_name])
image_encoder = ImageEncoderViT(
depth=sam_cfg.layers,
embed_dim=sam_cfg.width,
img_size=image_size,
mlp_ratio=4,
norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
num_heads=sam_cfg.heads,
patch_size=sam_cfg.patch_size,
qkv_bias=True,
use_rel_pos=True,
global_attn_indexes=sam_cfg.global_attn_indexes,
window_size=14,
out_chans=sam_cfg.prompt_embed_dim,
downsample_channels=sam_cfg.downsample_channels
)
if ckpt_path:
state_dict = torch.load(ckpt_path)
image_encoder.load_state_dict(state_dict, strict=False)
print(f"SAM-ViT restores from {ckpt_path}")
return image_encoder
if __name__ == '__main__':
x = torch.zeros(2, 3, 1024, 1024).bfloat16()
# x.permute(0, 3, 1, 2)
net = create_sam_vit().bfloat16()
out = net(x)
print(x.shape, out.shape)

View File

@ -0,0 +1,605 @@
# https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Final, Optional, Callable, Union, Tuple, List, Set, Dict, Type, Literal, Sequence
import math
from functools import partial
import warnings
from timm.layers import (
PatchEmbed, Mlp, DropPath,
AttentionPoolLatent, PatchDropout, resample_abs_pos_embed, LayerType
)
from timm.models._manipulate import named_apply, checkpoint_seq
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1. + math.erf(x / math.sqrt(2.))) / 2.
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2)
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
# type: (torch.Tensor, float, float, float, float) -> torch.Tensor
r""" The original timm.models.layers.weight_init.trunc_normal_ can not handle bfloat16 yet, here we first
convert the tensor to float32, apply the trunc_normal_() in float32, and then convert it back to its orignal dtype.
Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn
from the normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
a: the minimum cutoff value
b: the maximum cutoff value
Examples:
>>> w = torch.empty(3, 5)
>>> nn.init.trunc_normal_(w)
"""
with torch.no_grad():
dtype = tensor.dtype
tensor_fp32 = tensor.float()
tensor_fp32 = _no_grad_trunc_normal_(tensor_fp32, mean, std, a, b)
tensor_dtype = tensor_fp32.to(dtype=dtype)
tensor.copy_(tensor_dtype)
def init_weights(self):
if self.pos_embed is not None:
trunc_normal_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
trunc_normal_(self.latent, std=self.latent_dim ** -0.5)
def init_weights_vit_timm(module: nn.Module, name: str = '') -> None:
""" ViT weight initialization, original timm impl (for reproducibility) """
if isinstance(module, nn.Linear):
trunc_normal_(module.weight, std=.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif hasattr(module, 'init_weights'):
module.init_weights()
class Attention(nn.Module):
fused_attn: Final[bool]
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = False,
attn_drop: float = 0.,
proj_drop: float = 0.,
norm_layer: nn.Module = nn.LayerNorm,
) -> None:
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
# self.fused_attn = use_fused_attn()
self.fused_attn = True
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0. else nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
q, k = self.q_norm(q), self.k_norm(k)
if self.fused_attn:
x = F.scaled_dot_product_attention(
q, k, v,
dropout_p=self.attn_drop.p if self.training else 0.,
)
else:
q = q * self.scale
attn = q @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class LayerScale(nn.Module):
def __init__(
self,
dim: int,
init_values: float = 1e-5,
inplace: bool = False,
) -> None:
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.mul_(self.gamma) if self.inplace else x * self.gamma
class Block(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.,
qkv_bias: bool = False,
qk_norm: bool = False,
proj_drop: float = 0.,
attn_drop: float = 0.,
init_values: Optional[float] = None,
drop_path: float = 0.,
act_layer: nn.Module = nn.GELU,
norm_layer: nn.Module = nn.LayerNorm,
mlp_layer: nn.Module = Mlp,
) -> None:
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
attn_drop=attn_drop,
proj_drop=proj_drop,
norm_layer=norm_layer,
)
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = mlp_layer(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
drop=proj_drop,
)
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
return x
class VisionTransformer(nn.Module):
""" Vision Transformer
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
- https://arxiv.org/abs/2010.11929
"""
dynamic_img_size: Final[bool]
def __init__(
self,
img_size: Union[int, Tuple[int, int]] = 224,
patch_size: Union[int, Tuple[int, int]] = 16,
in_chans: int = 3,
num_classes: int = 1000,
global_pool: Literal['', 'avg', 'token', 'map'] = 'token',
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: float = 4.,
qkv_bias: bool = True,
qk_norm: bool = False,
init_values: Optional[float] = None,
class_token: bool = True,
no_embed_class: bool = False,
reg_tokens: int = 0,
pre_norm: bool = False,
fc_norm: Optional[bool] = None,
dynamic_img_size: bool = False,
dynamic_img_pad: bool = False,
drop_rate: float = 0.,
pos_drop_rate: float = 0.,
patch_drop_rate: float = 0.,
proj_drop_rate: float = 0.,
attn_drop_rate: float = 0.,
drop_path_rate: float = 0.,
weight_init: Literal['skip', 'jax', 'jax_nlhb', 'moco', ''] = '',
embed_layer: Callable = PatchEmbed,
norm_layer: Optional[LayerType] = None,
act_layer: Optional[LayerType] = None,
block_fn: Type[nn.Module] = Block,
mlp_layer: Type[nn.Module] = Mlp,
ignore_head: bool = False
) -> None:
"""
Args:
img_size: Input image size.
patch_size: Patch size.
in_chans: Number of image input channels.
num_classes: Mumber of classes for classification head.
global_pool: Type of global pooling for final sequence (default: 'token').
embed_dim: Transformer embedding dimension.
depth: Depth of transformer.
num_heads: Number of attention heads.
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
qkv_bias: Enable bias for qkv projections if True.
init_values: Layer-scale init values (layer-scale enabled if not None).
class_token: Use class token.
no_embed_class: Don't include position embeddings for class (or reg) tokens.
reg_tokens: Number of register tokens.
fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
drop_rate: Head dropout rate.
pos_drop_rate: Position embedding dropout rate.
attn_drop_rate: Attention dropout rate.
drop_path_rate: Stochastic depth rate.
weight_init: Weight initialization scheme.
embed_layer: Patch embedding layer.
norm_layer: Normalization layer.
act_layer: MLP activation layer.
block_fn: Transformer block layer.
"""
super().__init__()
assert global_pool in ('', 'avg', 'token', 'map')
assert class_token or global_pool != 'token'
use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
# norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
# act_layer = get_act_layer(act_layer) or nn.GELU
norm_layer = partial(nn.LayerNorm, eps=1e-6)
act_layer = nn.GELU
self.num_classes = num_classes
self.global_pool = global_pool
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_prefix_tokens = 1 if class_token else 0
self.num_prefix_tokens += reg_tokens
self.num_reg_tokens = reg_tokens
self.has_class_token = class_token
self.no_embed_class = no_embed_class # don't embed prefix positions (includes reg)
self.dynamic_img_size = dynamic_img_size
self.grad_checkpointing = False
self.ignore_head = ignore_head
embed_args = {}
if dynamic_img_size:
# flatten deferred until after pos embed
embed_args.update(dict(strict_img_size=False, output_fmt='NHWC'))
self.patch_embed = embed_layer(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
dynamic_img_pad=dynamic_img_pad,
**embed_args,
)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
self.pos_drop = nn.Dropout(p=pos_drop_rate)
if patch_drop_rate > 0:
self.patch_drop = PatchDropout(
patch_drop_rate,
num_prefix_tokens=self.num_prefix_tokens,
)
else:
self.patch_drop = nn.Identity()
self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.blocks = nn.Sequential(*[
block_fn(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
init_values=init_values,
proj_drop=proj_drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
act_layer=act_layer,
mlp_layer=mlp_layer,
)
for i in range(depth)])
self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
# Classifier Head
if global_pool == 'map':
AttentionPoolLatent.init_weights = init_weights
self.attn_pool = AttentionPoolLatent(
self.embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
norm_layer=norm_layer,
)
else:
self.attn_pool = None
self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
self.head_drop = nn.Dropout(drop_rate)
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
if weight_init != 'skip':
self.init_weights(weight_init)
def init_weights(self, mode: Literal['jax', 'jax_nlhb', 'moco', ''] = '') -> None:
assert mode in ('jax', 'jax_nlhb', 'moco', '')
head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
trunc_normal_(self.pos_embed, std=.02)
if self.cls_token is not None:
nn.init.normal_(self.cls_token, std=1e-6)
named_apply(init_weights_vit_timm, self)
@torch.jit.ignore
def no_weight_decay(self) -> Set:
return {'pos_embed', 'cls_token', 'dist_token'}
@torch.jit.ignore
def group_matcher(self, coarse: bool = False) -> Dict:
return dict(
stem=r'^cls_token|pos_embed|patch_embed', # stem and embed
blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
)
@torch.jit.ignore
def set_grad_checkpointing(self, enable: bool = True) -> None:
self.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self) -> nn.Module:
return self.head
def reset_classifier(self, num_classes: int, global_pool = None) -> None:
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('', 'avg', 'token', 'map')
if global_pool == 'map' and self.attn_pool is None:
assert False, "Cannot currently add attention pooling in reset_classifier()."
elif global_pool != 'map ' and self.attn_pool is not None:
self.attn_pool = None # remove attention pooling
self.global_pool = global_pool
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
if self.dynamic_img_size:
B, H, W, C = x.shape
pos_embed = resample_abs_pos_embed(
self.pos_embed,
(H, W),
num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
)
x = x.view(B, -1, C)
else:
pos_embed = self.pos_embed
to_cat = []
if self.cls_token is not None:
to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
if self.reg_token is not None:
to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
if self.no_embed_class:
# deit-3, updated JAX (big vision)
# position embedding does not overlap with class token, add then concat
x = x + pos_embed
if to_cat:
x = torch.cat(to_cat + [x], dim=1)
else:
# original timm, JAX, and deit vit impl
# pos_embed has entry for class token, concat then add
if to_cat:
x = torch.cat(to_cat + [x], dim=1)
x = x + pos_embed
return self.pos_drop(x)
def _intermediate_layers(
self,
x: torch.Tensor,
n: Union[int, Sequence] = 1,
) -> List[torch.Tensor]:
outputs, num_blocks = [], len(self.blocks)
take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n)
# forward pass
x = self.patch_embed(x)
x = self._pos_embed(x)
x = self.patch_drop(x)
x = self.norm_pre(x)
for i, blk in enumerate(self.blocks):
x = blk(x)
if i in take_indices:
outputs.append(x)
return outputs
def get_intermediate_layers(
self,
x: torch.Tensor,
n: Union[int, Sequence] = 1,
reshape: bool = False,
return_prefix_tokens: bool = False,
norm: bool = False,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
""" Intermediate layer accessor (NOTE: This is a WIP experiment).
Inspired by DINO / DINOv2 interface
"""
# take last n blocks if n is an int, if in is a sequence, select by matching indices
outputs = self._intermediate_layers(x, n)
if norm:
outputs = [self.norm(out) for out in outputs]
prefix_tokens = [out[:, 0:self.num_prefix_tokens] for out in outputs]
outputs = [out[:, self.num_prefix_tokens:] for out in outputs]
if reshape:
grid_size = self.patch_embed.grid_size
outputs = [
out.reshape(x.shape[0], grid_size[0], grid_size[1], -1).permute(0, 3, 1, 2).contiguous()
for out in outputs
]
if return_prefix_tokens:
return tuple(zip(outputs, prefix_tokens))
return tuple(outputs)
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
x = self.patch_embed(x)
x = self._pos_embed(x)
x = self.patch_drop(x)
x = self.norm_pre(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x)
else:
x = self.blocks(x)
x = self.norm(x)
return x
def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
if self.attn_pool is not None:
x = self.attn_pool(x)
elif self.global_pool == 'avg':
x = x[:, self.num_prefix_tokens:].mean(dim=1)
elif self.global_pool:
x = x[:, 0] # class token
x = self.fc_norm(x)
x = self.head_drop(x)
return x if pre_logits else self.head(x)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.forward_features(x)
if not self.ignore_head:
x = self.forward_head(x)
return x
@dataclass
class SigLIPVisionCfg:
width: int = 1152
layers: Union[Tuple[int, int, int, int], int] = 27
heads: int = 16
patch_size: int = 14
image_size: Union[Tuple[int, int], int] = 336
global_pool: str = "map"
mlp_ratio: float = 3.7362
class_token: bool = False
num_classes: int = 0
use_checkpoint: bool = False
SigLIP_MODEL_CONFIG = {
"siglip_so400m_patch14_384": {
"image_size": 336,
"patch_size": 14,
"width": 1152,
"layers": 27,
"heads": 16,
"mlp_ratio": 3.7362,
"global_pool": "map",
"use_checkpoint": False
},
"siglip_so400m_patch14_224": {
"image_size": 224,
"patch_size": 14,
"width": 1152,
"layers": 27,
"heads": 16,
"mlp_ratio": 3.7362,
"global_pool": "map",
"use_checkpoint": False
},
"siglip_large_patch16_384": {
"image_size": 384,
"patch_size": 16,
"width": 1024,
"layers": 24,
"heads": 16,
"mlp_ratio": 4,
"global_pool": "map",
"use_checkpoint": False
}
}
def create_siglip_vit(
model_name: str = "siglip_so400m_patch14_384",
image_size: int = 384,
select_layer: int = -1,
ckpt_path: str = "",
**kwargs
):
assert model_name in SigLIP_MODEL_CONFIG.keys(), f"model name should be in {SigLIP_MODEL_CONFIG.keys()}"
vision_cfg = SigLIPVisionCfg(**SigLIP_MODEL_CONFIG[model_name])
if select_layer <= 0:
layers = min(vision_cfg.layers, vision_cfg.layers + select_layer + 1)
else:
layers = min(vision_cfg.layers, select_layer)
model = VisionTransformer(
img_size=image_size,
patch_size=vision_cfg.patch_size,
embed_dim=vision_cfg.width,
depth=layers,
num_heads=vision_cfg.heads,
mlp_ratio=vision_cfg.mlp_ratio,
class_token=vision_cfg.class_token,
global_pool=vision_cfg.global_pool,
ignore_head=kwargs.get("ignore_head", True),
weight_init=kwargs.get("weight_init", "skip"),
num_classes=0
)
if ckpt_path:
state_dict = torch.load(ckpt_path, map_location="cpu")
incompatible_keys = model.load_state_dict(state_dict, strict=False)
print(f"SigLIP-ViT restores from {ckpt_path},\n"
f"\tincompatible_keys:', {incompatible_keys}.")
return model

View File

View File

@ -0,0 +1,326 @@
"""
From https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
"""
import dataclasses
from enum import IntEnum, auto
from typing import Dict, List
class SeparatorStyle(IntEnum):
"""Separator styles."""
ADD_COLON_SINGLE = auto()
ADD_COLON_TWO = auto()
ADD_COLON_SPACE_SINGLE = auto()
NO_COLON_SINGLE = auto()
NO_COLON_TWO = auto()
ADD_NEW_LINE_SINGLE = auto()
LLAMA2 = auto()
CHATGLM = auto()
CHATML = auto()
CHATINTERN = auto()
DOLLY = auto()
RWKV = auto()
PHOENIX = auto()
ROBIN = auto()
DeepSeek = auto()
PLAIN = auto()
ALIGNMENT = auto()
@dataclasses.dataclass
class Conversation:
"""A class that manages prompt templates and keeps all conversation history."""
# The name of this template
name: str
# The template of the system prompt
system_template: str = "{system_message}"
# The system message
system_message: str = ""
# The names of two roles
roles: List[str] = (("USER", "ASSISTANT"),)
# All messages. Each item is (role, message).
messages: List[List[str]] = ()
# The number of few shot examples
offset: int = 0
# The separator style and configurations
sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
sep: str = "\n"
sep2: str = None
# Stop criteria (the default one is EOS token)
stop_str: str = None
# Stops generation if meeting any token in this list
stop_token_ids: List[int] = None
def get_prompt(self) -> str:
"""Get the prompt for generation."""
system_prompt = self.system_template.format(system_message=self.system_message)
if self.sep_style == SeparatorStyle.DeepSeek:
seps = [self.sep, self.sep2]
if system_prompt == "" or system_prompt is None:
ret = ""
else:
ret = system_prompt + seps[0]
for i, (role, message) in enumerate(self.messages):
if message:
ret += role + ": " + message + seps[i % 2]
else:
ret += role + ":"
return ret
elif self.sep_style == SeparatorStyle.LLAMA2:
seps = [self.sep, self.sep2]
if self.system_message:
ret = system_prompt
else:
ret = "[INST] "
for i, (role, message) in enumerate(self.messages):
tag = self.roles[i % 2]
if message:
if type(message) is tuple: # multimodal message
message, _ = message
if i == 0:
ret += message + " "
else:
ret += tag + " " + message + seps[i % 2]
else:
ret += tag
return ret
elif self.sep_style == SeparatorStyle.PLAIN:
seps = [self.sep, self.sep2]
ret = ""
for i, (role, message) in enumerate(self.messages):
if message:
if type(message) is tuple:
message, _, _ = message
if i % 2 == 0:
ret += message + seps[i % 2]
else:
ret += message + seps[i % 2]
else:
ret += ""
return ret
elif self.sep_style == SeparatorStyle.ALIGNMENT:
seps = [self.sep, self.sep2]
ret = ""
for i, (role, message) in enumerate(self.messages):
if message:
if type(message) is tuple:
message, _, _ = message
if i % 2 == 0:
ret += '<image>\n' + seps[i % 2]
else:
ret += message + seps[i % 2]
else:
ret += ""
return ret
else:
raise ValueError(f"Invalid style: {self.sep_style}")
def get_prompt_for_current_round(self, content=None):
"""Get current round formatted question prompt during sft training"""
if self.sep_style == SeparatorStyle.PLAIN:
formatted_question = "<image>\n"
elif self.sep_style == SeparatorStyle.DeepSeek:
formatted_question = f"{self.roles[0]}: " +content.strip() + self.sep + f"{self.roles[1]}:"
else:
raise ValueError(f"Unsupported sep_style: {self.sep_style}")
return formatted_question
def set_system_message(self, system_message: str):
"""Set the system message."""
self.system_message = system_message
def append_message(self, role: str, message: str):
"""Append a new message."""
self.messages.append([role, message])
def reset_message(self):
"""Reset a new message."""
self.messages = []
def update_last_message(self, message: str):
"""Update the last output.
The last message is typically set to be None when constructing the prompt,
so we need to update it in-place after getting the response from a model.
"""
self.messages[-1][1] = message
def to_gradio_chatbot(self):
"""Convert the conversation to gradio chatbot format."""
ret = []
for i, (role, msg) in enumerate(self.messages[self.offset :]):
if i % 2 == 0:
ret.append([msg, None])
else:
ret[-1][-1] = msg
return ret
def to_openai_api_messages(self):
"""Convert the conversation to OpenAI chat completion format."""
system_prompt = self.system_template.format(system_message=self.system_message)
ret = [{"role": "system", "content": system_prompt}]
for i, (_, msg) in enumerate(self.messages[self.offset :]):
if i % 2 == 0:
ret.append({"role": "user", "content": msg})
else:
if msg is not None:
ret.append({"role": "assistant", "content": msg})
return ret
def copy(self):
return Conversation(
name=self.name,
system_template=self.system_template,
system_message=self.system_message,
roles=self.roles,
messages=[[x, y] for x, y in self.messages],
offset=self.offset,
sep_style=self.sep_style,
sep=self.sep,
sep2=self.sep2,
stop_str=self.stop_str,
stop_token_ids=self.stop_token_ids,
)
def dict(self):
return {
"template_name": self.name,
"system_message": self.system_message,
"roles": self.roles,
"messages": self.messages,
"offset": self.offset,
}
# A global registry for all conversation templates
conv_templates: Dict[str, Conversation] = {}
def register_conv_template(template: Conversation, override: bool = False):
"""Register a new conversation template."""
if not override:
assert template.name not in conv_templates, f"{template.name} has been registered."
conv_templates[template.name] = template
def get_conv_template(name: str) -> Conversation:
"""Get a conversation template."""
return conv_templates[name].copy()
# llava_llama2 template
register_conv_template(
Conversation(
name="llava_llama2",
system_message="You are a helpful language and vision assistant. "
"You are able to understand the visual content that the user provides, "
"and assist the user with a variety of tasks using natural language.",
system_template="[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n",
roles=("[INST]", "[/INST]"),
messages=(),
offset=0,
sep_style=SeparatorStyle.LLAMA2,
sep=" ",
sep2=" </s><s>",
stop_token_ids=[2],
)
)
# llama2 template
# reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
register_conv_template(
Conversation(
name="llama-2",
system_template="[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n",
roles=("[INST]", "[/INST]"),
messages=(),
offset=0,
sep_style=SeparatorStyle.LLAMA2,
sep=" ",
sep2=" </s><s>",
stop_token_ids=[2],
)
)
# deepseek template
register_conv_template(
Conversation(
name="deepseek",
system_template="{system_message}",
# system_message="You are a helpful assistant. Please answer truthfully and write out your "
# "thinking step by step to be sure you get the right answer.",
system_message="",
roles=("User", "Assistant"),
messages=(),
offset=0,
sep_style=SeparatorStyle.DeepSeek,
sep="\n\n",
sep2="<end▁of▁sentence>",
stop_token_ids=[100001],
stop_str=["User:", "<end▁of▁sentence>"]
)
)
register_conv_template(
Conversation(
name="plain",
system_template="",
system_message="",
roles=("", ""),
messages=(),
offset=0,
sep_style=SeparatorStyle.PLAIN,
sep="",
sep2="",
stop_token_ids=[2],
stop_str=['</s>'],
)
)
register_conv_template(
Conversation(
name="alignment",
system_template="",
system_message="",
roles=("", ""),
messages=(),
offset=0,
sep_style=SeparatorStyle.ALIGNMENT,
sep="",
sep2="",
stop_token_ids=[2],
stop_str=['</s>'],
)
)
if __name__ == "__main__":
# print("Llama-2 template:")
# conv = get_conv_template("llama-2")
# conv.set_system_message("You are a helpful, respectful and honest assistant.")
# conv.append_message(conv.roles[0], "Hello!")
# conv.append_message(conv.roles[1], "Hi!")
# conv.append_message(conv.roles[0], "How are you?")
# conv.append_message(conv.roles[1], None)
# print(conv.get_prompt())
# print("\n")
print("deepseek template:")
conv = get_conv_template("deepseek")
conv.append_message(conv.roles[0], "Hello!")
conv.append_message(conv.roles[1], "Hi! This is Tony.")
conv.append_message(conv.roles[0], "Who are you?")
conv.append_message(conv.roles[1], "I am a helpful assistant.")
conv.append_message(conv.roles[0], "How are you?")
conv.append_message(conv.roles[1], None)
print(conv.get_prompt())

55
deepseek_vlm/utils/io.py Normal file
View File

@ -0,0 +1,55 @@
import json
import PIL.Image
from typing import Dict, List
import torch
from transformers import AutoModelForCausalLM
from deepseek_vlm.models import VLChatProcessor, MultiModalityCausalLM
def load_pretrained_model(model_path: str):
vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()
return tokenizer, vl_chat_processor, vl_gpt
def load_pil_images(conversations: List[Dict[str, str]]) -> List[PIL.Image.Image]:
"""
Args:
conversations (List[Dict[str, str]]): the conversations with a list of messages. An example is :
[
{
"role": "User",
"content": "<image_placeholder>\nExtract all information from this image and convert them into markdown format.",
"images": ["./examples/table_datasets.png"]
},
{"role": "Assistant", "content": ""},
]
Returns:
pil_images (List[PIL.Image.Image]): the list of PIL images.
"""
pil_images = []
for message in conversations:
if "images" not in message:
continue
for image_path in message["images"]:
pil_img = PIL.Image.open(image_path)
pil_img = pil_img.convert("RGB")
pil_images.append(pil_img)
return pil_images
def load_json(filepath):
with open(filepath, "r") as f:
data = json.load(f)
return data

BIN
images/latex_01.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 49 KiB

BIN
images/monday.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 48 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 101 KiB

53
inference.py Normal file
View File

@ -0,0 +1,53 @@
import torch
from transformers import AutoModelForCausalLM
from deepseek_vlm.models import VLChatProcessor, MultiModalityCausalLM
from deepseek_vlm.utils.io import load_pil_images
# specify the path to the model
model_path = "deepseek-ai/deepseek-vl-7b-chat"
vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()
conversation = [
{
"role": "User",
"content": "<image_placeholder>Describe each stage of this image.",
"images": ["./images/training_pipelines.png"]
},
{
"role": "Assistant",
"content": ""
}
]
# load images and prepare for inputs
pil_images = load_pil_images(conversation)
prepare_inputs = vl_chat_processor(
conversations=conversation,
images=pil_images,
force_batchify=True
).to(vl_gpt.device)
# run image encoder to get the image embeddings
inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
# run the model to get the response
outputs = vl_gpt.language_model.generate(
inputs_embeds=inputs_embeds,
attention_mask=prepare_inputs.attention_mask,
pad_token_id=tokenizer.eos_token_id,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
max_new_tokens=512,
do_sample=False,
use_cache=True
)
answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
print(f"{prepare_inputs['sft_format'][0]}", answer)

View File

@ -1,8 +1,5 @@
torch>=2.0
tokenizers>=0.14.0
transformers>=4.35.0
transformers>=4.38.2
timm>=0.9.16
gradio>=4.13.0
accelerate
sympy==1.12
pebble
timeout-decorator
attrdict
sentencepiece

17
setup.py Normal file
View File

@ -0,0 +1,17 @@
from setuptools import setup, find_packages
version = '1.0.0'
print(version)
setup(
name='deepseek_vlm',
version=version,
description='DeekSeel-VLM',
author='HFAiLab',
license='MIT',
url='https://gitlab.deepseek.com/liuwen/deepseek_vl',
python_requires='>=3.8',
install_requires=['torch>=2.0'],
packages=find_packages(exclude=("images",)),
)