from attrdict import AttrDict
from einops import rearrange, repeat
from typing import Optional, List, Tuple, Callable, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers.configuration_utils import PretrainedConfig
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    PreTrainedModel, GenerationConfig, LogitsProcessorList, StoppingCriteriaList,
)
from transformers.generation.utils import GenerateOutput

from .siglip_vit import VisionTransformer
from .configuration_deepseek import DeepseekV2Config
from .modeling_deepseek import DeepseekV2ForCausalLM


class MlpProjector(nn.Module):

    def __init__(self, cfg):
        super().__init__()

        self.cfg = cfg

        if cfg.projector_type == "identity":
            modules = nn.Identity()

        elif cfg.projector_type == "linear":
            modules = nn.Linear(cfg.input_dim, cfg.n_embed)

        elif cfg.projector_type == "mlp_gelu":
            mlp_depth = cfg.depth
            modules = [nn.Linear(cfg.input_dim, cfg.n_embed)]
            for _ in range(1, mlp_depth):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
            modules = nn.Sequential(*modules)

        elif cfg.projector_type == "downsample_mlp_gelu":
            mlp_depth = cfg.depth
            mlp_ratio = cfg.mlp_ratio
            modules = [nn.Linear(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio, cfg.n_embed * mlp_ratio)]
            for _ in range(1, mlp_depth - 1):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed * mlp_ratio))
            modules.append(nn.GELU())
            modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed))
            modules = nn.Sequential(*modules)

        else:
            raise ValueError(f"Unknown projector type: {cfg.projector_type}")

        if cfg.token_pooling:
            self.token_pooling_layer = nn.Linear(cfg.input_dim * 4, cfg.input_dim)

        self.layers = modules

    def forward(self, x):
        if self.cfg.token_pooling:
            batch_size, wxh, channels = x.shape
            w = h = int(wxh ** 0.5)
            x = x.view(batch_size, w, h, channels)
            x = x.permute(0, 3, 1, 2)
            # split the feature map into non-overlapping 2x2 patches
            patches = x.unfold(2, 2, 2).unfold(3, 2, 2)
            batch_size, channels, h_patches, w_patches, _, _ = patches.size()
            # concatenate along the channel dimension
            patches = patches.contiguous().view(batch_size, channels, h_patches * w_patches, -1)

            # pass through the linear pooling layer
            patches = patches.permute(0, 2, 1, 3).contiguous()
            patches = patches.view(batch_size, h_patches * w_patches, channels * 4)

            x = self.token_pooling_layer(patches)

        elif self.cfg.projector_type == 'downsample_mlp_gelu':
            bs, hw, input_dim = x.shape
            h = w = int(hw ** 0.5)

            # compute padding
            if h % self.cfg.downsample_ratio:
                pad = self.cfg.downsample_ratio - h % self.cfg.downsample_ratio
            else:
                pad = 0
            x = x.reshape(bs, h, w, input_dim)
            if pad > 0:
                x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0)

            # 4-to-1 concat: merge each downsample_ratio x downsample_ratio patch into one token
            x = x.permute(0, 3, 1, 2)  # B, C, H, W
            x = F.unfold(x, kernel_size=self.cfg.downsample_ratio, stride=self.cfg.downsample_ratio,
                         padding=0)  # B, C * ratio^2, HW / ratio^2
            x = x.permute(0, 2, 1)

        return self.layers(x)
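
# A minimal usage sketch of the projector (illustrative values only;
# MlpProjectorConfig is defined further below):
#
#   cfg = MlpProjectorConfig(projector_type="downsample_mlp_gelu", input_dim=1152,
#                            n_embed=2048, depth=2, mlp_ratio=1, downsample_ratio=2)
#   projector = MlpProjector(cfg)
#   feats = torch.randn(1, 576, 1152)   # 24x24 patch tokens from the vision tower
#   out = projector(feats)              # -> [1, 144, 2048]: 2x2 downsample, then MLP
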
class VisionEncoderConfig(PretrainedConfig):
    model_type: str = "vision"

    model_name: str = "siglip_large_patch16_384"
    image_size: int = 384
    patch_size: int = 16
    width: int = 1024
    layers: int = 24
    heads: int = 16
    mlp_ratio: int = 4
    global_pool: str = "map"
    ignore_head: bool = True
    class_token: bool = False
    num_classes: int = 0
    use_checkpoint: bool = False
    weight_init: str = "skip"
    deterministic: bool = False
    num_recomputing_layers: int = 0

    def __init__(
        self,
        model_name: str = "siglip_large_patch16_384",
        image_size: int = 384,
        patch_size: int = 16,
        width: int = 1024,
        layers: int = 24,
        heads: int = 16,
        mlp_ratio: int = 4,
        global_pool: str = "map",
        ignore_head: bool = True,
        class_token: bool = False,
        num_classes: int = 0,
        use_checkpoint: bool = False,
        **kwargs
    ):
        self.model_name = model_name
        self.image_size = image_size
        self.patch_size = patch_size
        self.width = width
        self.layers = layers
        self.heads = heads
        self.mlp_ratio = mlp_ratio
        self.global_pool = global_pool
        self.ignore_head = ignore_head
        self.class_token = class_token
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint

        super().__init__(**kwargs)
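
# Note (illustrative): with the defaults above, the vision tower yields
# (image_size / patch_size) ** 2 = (384 / 16) ** 2 = 576 patch tokens per
# 384x384 tile, each of dimension `width` (1024), before the projector.
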
class MlpProjectorConfig(PretrainedConfig):
    model_type = "mlp_projector"
    projector_type: str = "downsample_mlp_gelu"
    input_dim: int = 1152
    n_embed: int = 2048
    depth: int = 2
    mlp_ratio: int = 1
    downsample_ratio: int = 2
    token_pooling: bool = False

    def __init__(
        self,
        projector_type: str = "downsample_mlp_gelu",
        input_dim: int = 1152,
        n_embed: int = 2048,
        depth: int = 2,
        mlp_ratio: int = 1,
        downsample_ratio: int = 2,
        **kwargs
    ):
        self.projector_type = projector_type
        self.input_dim = input_dim
        self.n_embed = n_embed
        self.depth = depth
        self.mlp_ratio = mlp_ratio
        self.downsample_ratio = downsample_ratio

        super().__init__(**kwargs)
class DeepseekVLV2Config(PretrainedConfig):
    model_type = "deepseek_vl_v2"
    vision_config: VisionEncoderConfig
    projector_config: MlpProjectorConfig
    language_config: DeepseekV2Config

    tile_tag: str = "2D"
    global_view_pos: str = "head"
    candidate_resolutions: Tuple[Tuple[int, int], ...] = ((384, 384),)

    def __init__(
        self,
        tile_tag: str = "2D",
        global_view_pos: str = "head",
        candidate_resolutions: Tuple[Tuple[int, int], ...] = ((384, 384),),
        **kwargs
    ):
        super().__init__(**kwargs)

        vision_config = kwargs.get("vision_config", {})
        self.vision_config = VisionEncoderConfig(**vision_config)

        projector_config = kwargs.get("projector_config", {})
        self.projector_config = MlpProjectorConfig(**projector_config)

        language_config = kwargs.get("language_config", {})
        if isinstance(language_config, DeepseekV2Config):
            self.language_config = language_config
        else:
            self.language_config = DeepseekV2Config(**language_config)

        self.tile_tag = tile_tag
        self.global_view_pos = global_view_pos
        self.candidate_resolutions = candidate_resolutions
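
# A minimal construction sketch (hypothetical values, illustration only):
#
#   config = DeepseekVLV2Config(
#       vision_config={"model_name": "siglip_large_patch16_384"},
#       projector_config={"projector_type": "downsample_mlp_gelu", "n_embed": 2048},
#       language_config={},  # falls back to DeepseekV2Config defaults
#       tile_tag="2D",
#       global_view_pos="head",
#   )
#
# The nested dicts arrive through **kwargs and are promoted to their typed
# config classes in __init__ above.
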
class DeepseekVLV2PreTrainedModel(PreTrainedModel):
    config_class = DeepseekVLV2Config
    base_model_prefix = "deepseek_vl_v2"
    _no_split_modules = []
    _skip_keys_device_placement = "past_key_values"


class DeepseekVLV2ForCausalLM(DeepseekVLV2PreTrainedModel):

    def __init__(self, config: DeepseekVLV2Config):
        super().__init__(config)

        # ----------- vision encoder ------------
        vision_config = config.vision_config
        self.vision = VisionTransformer(
            img_size=vision_config.image_size,
            patch_size=vision_config.patch_size,
            embed_dim=vision_config.width,
            depth=vision_config.layers,
            num_heads=vision_config.heads,
            mlp_ratio=vision_config.mlp_ratio,
            class_token=vision_config.class_token,
            global_pool=vision_config.global_pool,
            ignore_head=vision_config.ignore_head,
            weight_init=vision_config.weight_init,
            num_classes=0,
            deterministic=vision_config.deterministic,
            num_recomputing_layers=vision_config.num_recomputing_layers
        )

        # ----------- vl projector ------------
        projector_config = config.projector_config
        self.projector = MlpProjector(projector_config)

        # image token format
        # FIXME: tile_tag & global_view_pos currently default to the settings used in
        # earlier experiments; the defaults should eventually be removed so that a
        # missing value raises an error instead.
        self.tile_tag = config.tile_tag
        self.global_view_pos = config.global_view_pos

        # special tokens used to format the image token sequence
        embed_std = 1 / torch.sqrt(torch.tensor(projector_config.n_embed, dtype=torch.float32))
        if self.tile_tag == "2D":
            # <|view_separator|>, <|\n|>
            self.image_newline = nn.Parameter(torch.randn(projector_config.n_embed) * embed_std)
            # NOTE: "seperator" is a typo, kept so parameter names match released checkpoints
            self.view_seperator = nn.Parameter(torch.randn(projector_config.n_embed) * embed_std)
        elif self.tile_tag == "1D":
            # <|tile_x|>, <|tile_global|>
            candidate_resolutions = config.candidate_resolutions
            if len(candidate_resolutions) == 0:
                raise ValueError(
                    f"len(candidate_resolutions) should be larger than 0, but got {len(candidate_resolutions)}")
            tile_variants_num = len(candidate_resolutions)
            self.tile_indicators = nn.Parameter(
                torch.randn(size=(tile_variants_num + 1, projector_config.n_embed)) * embed_std
            )
        else:
            raise ValueError(f"tile tag should be either 1D or 2D, but got {self.tile_tag}")

        # ----------- language model ------------
        language_config = config.language_config
        self.language = DeepseekV2ForCausalLM(language_config)

    def prepare_inputs_embeds(
        self,
        input_ids: torch.LongTensor,
        images: torch.FloatTensor,
        images_seq_mask: torch.LongTensor,
        images_spatial_crop: Optional[torch.LongTensor] = None,
        **ignore_kwargs
    ):
        """
        Args:
            input_ids (torch.LongTensor): [b, T]
            images (torch.FloatTensor): [b, max_n_images, 3, height, width]
            images_seq_mask (torch.BoolTensor): [b, T]
            images_spatial_crop (torch.LongTensor): [b, max_n_images, 2]

        Returns:
            input_embeds (torch.Tensor): [b, T, D]
        """

        if images is None or images_spatial_crop.sum() == 0:
            return self.language.get_input_embeddings()(input_ids)

        bs, max_n_images, _ = images_spatial_crop.shape
        batch_num_tiles = [0 for _ in range(bs)]
        total_tiles = []
        for idx in range(bs):
            for jdx in range(max_n_images):
                num_width_tiles, num_height_tiles = images_spatial_crop[idx, jdx]
                if num_width_tiles == 0 or num_height_tiles == 0:
                    break
                batch_num_tiles[idx] += (1 + num_width_tiles * num_height_tiles)

            total_tiles.append(images[idx, :batch_num_tiles[idx]])

        # [batch_all_tiles, 3, height, width]
        total_tiles = torch.cat(total_tiles, dim=0)
        assert total_tiles.shape[0] == sum(batch_num_tiles)
        if total_tiles.shape[0] == 0:
            return self.language.get_input_embeddings()(input_ids)

        # [batch_all_tiles, vit_seq_len, c]
        images_feature = self.vision(total_tiles)

        # [batch_all_tiles, hw, D]
        images_embeds = self.projector(images_feature)
        _, hw, n_dim = images_embeds.shape
        h = w = int(hw ** 0.5)

        # put image tokens into the input_embeds, [b, T, D]
        input_embeds = self.language.get_input_embeddings()(input_ids)

        # fill the image token sequence according to self.tile_tag & self.global_view_pos
        tile_index = 0
        for idx in range(images_spatial_crop.shape[0]):
            images_in_this_batch = []
            for jdx in range(images_spatial_crop.shape[1]):

                # extract global & local features
                num_width_tiles, num_height_tiles = images_spatial_crop[idx, jdx]
                if num_width_tiles == 0 or num_height_tiles == 0:
                    break

                num_tiles_in_image = num_width_tiles * num_height_tiles

                # [hw, D]
                global_features = images_embeds[tile_index]

                # [num_height_tiles * num_width_tiles, hw, D]
                local_features = images_embeds[tile_index + 1: tile_index + 1 + num_tiles_in_image]

                tile_index += num_tiles_in_image + 1

                # format global and local features
                if self.tile_tag == "2D":

                    # ----------------- global view add newline -----------------
                    # [hw, D] -> [h, w, D]
                    global_features = global_features.view(h, w, n_dim)
                    # [D] -> [h, 1, D]
                    new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h)
                    # cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D]
                    global_features = torch.cat([global_features, new_lines_in_global], dim=1)
                    # [h, w + 1, D] -> [h * (w + 1), D]
                    global_features = global_features.view(-1, n_dim)

                    # ----------------- local view add newline -----------------
                    # [num_height_tiles * num_width_tiles, h * w, D] -> [num_height_tiles * h, num_width_tiles * w, D]
                    local_features = rearrange(
                        local_features,
                        "(th tw) (h w) d -> (th h) (tw w) d",
                        th=num_height_tiles,
                        tw=num_width_tiles,
                        h=h,
                        w=w
                    )

                    # [D] -> [num_height_tiles * h, 1, D]
                    new_lines_in_local = repeat(
                        self.image_newline,
                        "d -> (th h) 1 d",
                        th=num_height_tiles,
                        h=h
                    )

                    # [num_height_tiles * h, num_width_tiles * w + 1, D]
                    local_features = torch.cat([local_features, new_lines_in_local], dim=1)

                    # [num_height_tiles * h, num_width_tiles * w + 1, D]
                    # --> [(num_height_tiles * h) * (num_width_tiles * w + 1), D]
                    local_features = local_features.view(-1, n_dim)

                    # ----------------- merge global and local tiles -----------------
                    if self.global_view_pos == "head":
                        global_local_features = torch.cat(
                            [global_features, self.view_seperator[None, :], local_features], dim=0)
                    else:
                        global_local_features = torch.cat(
                            [local_features, self.view_seperator[None, :], global_features], dim=0)

                else:
                    # abandoned; in practice this branch is never taken
                    global_features = torch.cat(
                        [self.tile_indicators[0:1], global_features], dim=0
                    )
                    local_features = torch.cat(
                        [self.tile_indicators[1:num_tiles_in_image + 1].unsqueeze(1), local_features], dim=1
                    )
                    local_features = rearrange(local_features, 'crop_num hw d -> (crop_num hw) d')

                    if self.global_view_pos == "head":
                        global_local_features = torch.cat([global_features, local_features], dim=0)
                    else:
                        global_local_features = torch.cat([local_features, global_features], dim=0)

                images_in_this_batch.append(global_local_features)

            if len(images_in_this_batch) > 0:
                images_in_this_batch = torch.cat(images_in_this_batch, dim=0)
                input_embeds[idx].masked_scatter_(images_seq_mask[idx].unsqueeze(-1), images_in_this_batch)

        return input_embeds
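
    # Token-layout sketch (illustrative, assuming h = w = 12 after the projector,
    # a 2x2 local tile grid, and global_view_pos == "head"):
    #   global view:  12 rows * (12 + 1) tokens = 156   (one <|\n|> per row)
    #   separator:    1 view_seperator token
    #   local view:   (2 * 12) rows * (2 * 12 + 1) tokens = 600
    # The concatenated length must equal images_seq_mask[idx].sum() for that sample.
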

    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        synced_gpus: Optional[bool] = None,
        assistant_model: Optional["PreTrainedModel"] = None,
        streamer: Optional["BaseStreamer"] = None,
        negative_prompt_ids: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        r"""
        Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
        beam-search decoding, sampling with temperature, and top-k or nucleus sampling. Beam-search decoding
        is controlled by the `num_beams` and `num_return_sequences` parameters.

        Parameters:
        - `inputs` (optional) -- `torch.LongTensor` of shape `(batch, sequence_length)`:
            The sequence used as a prompt for the generation. If `None`, generation starts from `bos_token_id`.
        - `generation_config` (optional) -- `GenerationConfig`:
            The generation configuration of the model.
        - `logits_processor` (optional) -- `LogitsProcessorList`:
            A list of `LogitsProcessor` instances used to modify the prediction scores of the
            language modeling head at each generation step.
        """

        return self.language.generate(
            inputs=inputs,
            generation_config=generation_config,
            logits_processor=logits_processor,
            stopping_criteria=stopping_criteria,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            synced_gpus=synced_gpus,
            assistant_model=assistant_model,
            streamer=streamer,
            negative_prompt_ids=negative_prompt_ids,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
            **kwargs,
        )


AutoConfig.register("vision", VisionEncoderConfig)
AutoConfig.register("mlp_projector", MlpProjectorConfig)
AutoConfig.register("deepseek_vl_v2", DeepseekVLV2Config)
AutoModelForCausalLM.register(DeepseekVLV2Config, DeepseekVLV2ForCausalLM)
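
# A minimal end-to-end sketch (hypothetical tensors; in the released repo a
# DeepseekVLV2Processor builds input_ids, images, images_seq_mask and
# images_spatial_crop from a conversation with image placeholders):
#
#   model = AutoModelForCausalLM.from_pretrained(
#       "deepseek-ai/deepseek-vl2", trust_remote_code=True)
#   inputs_embeds = model.prepare_inputs_embeds(
#       input_ids=input_ids,                      # [b, T]
#       images=images,                            # [b, max_n_images, 3, H, W]
#       images_seq_mask=images_seq_mask,          # [b, T], True at image-token slots
#       images_spatial_crop=images_spatial_crop,  # [b, max_n_images, 2]
#   )
#   outputs = model.generate(inputs_embeds=inputs_embeds,
#                            attention_mask=attention_mask, max_new_tokens=512)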