# DeepSeek-VL2/deepseek_vl/models/modeling_deepseek_vl_v2.py

from einops import rearrange, repeat
from typing import Optional, List, Tuple, Callable, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers.configuration_utils import PretrainedConfig
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    PreTrainedModel,
    GenerationConfig,
    LogitsProcessorList,
    StoppingCriteriaList,
)
from transformers.generation.utils import GenerateOutput

from .siglip_vit import VisionTransformer
from .configuration_deepseek import DeepseekV2Config
from .modeling_deepseek import DeepseekV2ForCausalLM


class MlpProjector(nn.Module):

    def __init__(self, cfg):
        super().__init__()

        self.cfg = cfg

        if cfg.projector_type == "identity":
            modules = nn.Identity()

        elif cfg.projector_type == "linear":
            modules = nn.Linear(cfg.input_dim, cfg.n_embed)

        elif cfg.projector_type == "mlp_gelu":
            mlp_depth = cfg.depth
            modules = [nn.Linear(cfg.input_dim, cfg.n_embed)]
            for _ in range(1, mlp_depth):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
            modules = nn.Sequential(*modules)

        elif cfg.projector_type == "downsample_mlp_gelu":
            mlp_depth = cfg.depth
            mlp_ratio = cfg.mlp_ratio
            modules = [nn.Linear(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio, cfg.n_embed * mlp_ratio)]
            for _ in range(1, mlp_depth - 1):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed * mlp_ratio))
            modules.append(nn.GELU())
            modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed))
            modules = nn.Sequential(*modules)

        else:
            raise ValueError(f"Unknown projector type: {cfg.projector_type}")

        if cfg.token_pooling:
            self.token_pooling_layer = nn.Linear(cfg.input_dim * 4, cfg.input_dim)

        self.layers = modules

    def forward(self, x):
        if self.cfg.token_pooling:
            batch_size, wxh, channels = x.shape
            w = h = int(wxh ** 0.5)
            x = x.view(batch_size, w, h, channels)
            x = x.permute(0, 3, 1, 2)
            patches = x.unfold(2, 2, 2).unfold(3, 2, 2)
            batch_size, channels, h_patches, w_patches, _, _ = patches.size()
            # concatenate along the channel dimension
            patches = patches.contiguous().view(batch_size, channels, h_patches * w_patches, -1)

            # pass through the linear layer
            patches = patches.permute(0, 2, 1, 3).contiguous()
            patches = patches.view(batch_size, h_patches * w_patches, channels * 4)

            x = self.token_pooling_layer(patches)

        elif self.cfg.projector_type == 'downsample_mlp_gelu':
            bs, hw, input_dim = x.shape
            h = w = int(hw ** 0.5)

            """compute padding"""
            if h % self.cfg.downsample_ratio:
                pad = self.cfg.downsample_ratio - h % self.cfg.downsample_ratio
            else:
                pad = 0

            x = x.reshape(bs, h, w, input_dim)
            if pad > 0:
                x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0)

            """4 to 1 concat"""
            x = x.permute(0, 3, 1, 2)  # B, C, H, W
            x = F.unfold(x, kernel_size=self.cfg.downsample_ratio, stride=self.cfg.downsample_ratio,
                         padding=0)  # B, C*4, HW // 4
            x = x.permute(0, 2, 1)

        return self.layers(x)
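
    # Illustrative shape walk-through for the default "downsample_mlp_gelu" path
    # (assumed sizes: a 24x24 vision token grid; not tied to any specific
    # released checkpoint):
    #   input x:  [bs, 576, input_dim]        (576 = 24 * 24 tokens per tile)
    #   pad:      24 % downsample_ratio == 0, so no padding with the default ratio of 2
    #   unfold:   [bs, input_dim * 4, 144]    (each 2x2 neighbourhood concatenated)
    #   permute:  [bs, 144, input_dim * 4]
    #   layers:   [bs, 144, n_embed]          (the MLP built in __init__)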


class VisionEncoderConfig(PretrainedConfig):
    model_type: str = "vision"

    model_name: str = "siglip_large_patch16_384"
    image_size: int = 384
    patch_size: int = 16
    width: int = 1024
    layers: int = 24
    heads: int = 16
    mlp_ratio: int = 4
    global_pool: str = "map"
    ignore_head: bool = True
    class_token: bool = False
    num_classes: int = 0
    use_checkpoint: bool = False
    weight_init: str = "skip"
    deterministic: bool = False
    num_recomputing_layers: int = 0

    def __init__(
        self,
        model_name: str = "siglip_large_patch16_384",
        image_size: int = 384,
        patch_size: int = 16,
        width: int = 1024,
        layers: int = 24,
        heads: int = 16,
        mlp_ratio: int = 4,
        global_pool: str = "map",
        ignore_head: bool = True,
        class_token: bool = False,
        num_classes: int = 0,
        use_checkpoint: bool = False,
        **kwargs
    ):
        self.model_name = model_name
        self.image_size = image_size
        self.patch_size = patch_size
        self.width = width
        self.layers = layers
        self.heads = heads
        self.mlp_ratio = mlp_ratio
        self.global_pool = global_pool
        self.ignore_head = ignore_head
        self.class_token = class_token
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint

        super().__init__(**kwargs)


class MlpProjectorConfig(PretrainedConfig):
    model_type = "mlp_projector"

    projector_type: str = "downsample_mlp_gelu"
    input_dim: int = 1152
    n_embed: int = 2048
    depth: int = 2
    mlp_ratio: int = 1
    downsample_ratio: int = 2
    token_pooling: bool = False

    def __init__(
        self,
        projector_type: str = "downsample_mlp_gelu",
        input_dim: int = 1152,
        n_embed: int = 2048,
        depth: int = 2,
        mlp_ratio: int = 1,
        downsample_ratio: int = 2,
        **kwargs
    ):
        self.projector_type = projector_type
        self.input_dim = input_dim
        self.n_embed = n_embed
        self.depth = depth
        self.mlp_ratio = mlp_ratio
        self.downsample_ratio = downsample_ratio

        super().__init__(**kwargs)


class DeepseekVLV2Config(PretrainedConfig):
    model_type = "deepseek_vl_v2"

    vision_config: VisionEncoderConfig
    projector_config: MlpProjectorConfig
    language_config: DeepseekV2Config

    tile_tag: str = "2D"
    global_view_pos: str = "head"
    candidate_resolutions: Tuple[Tuple[int, int]] = ((384, 384),)

    def __init__(
        self,
        tile_tag: str = "2D",
        global_view_pos: str = "head",
        candidate_resolutions: Tuple[Tuple[int, int]] = ((384, 384),),
        **kwargs
    ):
        super().__init__(**kwargs)

        vision_config = kwargs.get("vision_config", {})
        self.vision_config = VisionEncoderConfig(**vision_config)

        projector_config = kwargs.get("projector_config", {})
        self.projector_config = MlpProjectorConfig(**projector_config)

        language_config = kwargs.get("language_config", {})
        if isinstance(language_config, DeepseekV2Config):
            self.language_config = language_config
        else:
            self.language_config = DeepseekV2Config(**language_config)

        self.tile_tag = tile_tag
        self.global_view_pos = global_view_pos
        self.candidate_resolutions = candidate_resolutions
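
    # Illustrative usage with assumed values (not a released checkpoint config):
    #
    #   config = DeepseekVLV2Config(
    #       tile_tag="2D",
    #       global_view_pos="head",
    #       candidate_resolutions=((384, 384),),
    #       vision_config={"image_size": 384, "width": 1024},
    #       projector_config={"input_dim": 1152, "n_embed": 2048},
    #       language_config={},  # falls back to DeepseekV2Config defaults
    #   )
    #
    # The nested dicts are read back out of **kwargs above and wrapped into
    # their respective sub-config classes.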


class DeepseekVLV2PreTrainedModel(PreTrainedModel):
    config_class = DeepseekVLV2Config
    base_model_prefix = "deepseek_vl_v2"
    _no_split_modules = []
    _skip_keys_device_placement = "past_key_values"


class DeepseekVLV2ForCausalLM(DeepseekVLV2PreTrainedModel):

    def __init__(self, config: DeepseekVLV2Config):
        super().__init__(config)

        # ----------- vision encoder ------------
        vision_config = config.vision_config
        self.vision = VisionTransformer(
            img_size=vision_config.image_size,
            patch_size=vision_config.patch_size,
            embed_dim=vision_config.width,
            depth=vision_config.layers,
            num_heads=vision_config.heads,
            mlp_ratio=vision_config.mlp_ratio,
            class_token=vision_config.class_token,
            global_pool=vision_config.global_pool,
            ignore_head=vision_config.ignore_head,
            weight_init=vision_config.weight_init,
            num_classes=0,
            deterministic=vision_config.deterministic,
            num_recomputing_layers=vision_config.num_recomputing_layers
        )

        # ----------- vl projector ------------
        projector_config = config.projector_config
        self.projector = MlpProjector(projector_config)

        # image token sequence format
        # FIXME: the current defaults for tile_tag & global_view_pos come from earlier
        # experiments; the defaults should eventually be removed so that a missing
        # value raises an error instead.
        self.tile_tag = config.tile_tag
        self.global_view_pos = config.global_view_pos

        # special tokens used to format the image token sequence
        embed_std = 1 / torch.sqrt(torch.tensor(projector_config.n_embed, dtype=torch.float32))
        if self.tile_tag == "2D":
            # <|view_separator|>, <|\n|>
            self.image_newline = nn.Parameter(torch.randn(projector_config.n_embed) * embed_std)
            # fix the typo: view_seperater
            self.view_seperator = nn.Parameter(torch.randn(projector_config.n_embed) * embed_std)
        elif self.tile_tag == "1D":
            # <|tile_x|>, <|tile_global|>
            candidate_resolutions = config.candidate_resolutions
            if len(candidate_resolutions) == 0:
                raise ValueError(
                    f"len(candidate_resolutions) should be larger than 0, but got {len(candidate_resolutions)}")
            tile_variants_num = len(candidate_resolutions)
            self.tile_indicators = nn.Parameter(
                torch.randn(size=(tile_variants_num + 1, projector_config.n_embed)) * embed_std
            )
        else:
            raise ValueError(f"tile tag should be either 1D or 2D, but got {self.tile_tag}")

        # ----------- language model ------------
        language_config = config.language_config
        self.language = DeepseekV2ForCausalLM(language_config)

    def prepare_inputs_embeds(
        self,
        input_ids: torch.LongTensor,
        images: torch.FloatTensor,
        images_seq_mask: torch.LongTensor,
        images_spatial_crop: Optional[torch.LongTensor] = None,
        **ignore_kwargs
    ):
        """
        Args:
            input_ids (torch.LongTensor): [b, T]
            images (torch.FloatTensor): [b, max_n_images, 3, height, width]
            images_seq_mask (torch.BoolTensor): [b, T]
            images_spatial_crop (torch.LongTensor): [b, max_n_images, 2]

        Returns:
            input_embeds (torch.Tensor): [b, T, D]
        """

        if images is None or images_spatial_crop.sum() == 0:
            return self.language.get_input_embeddings()(input_ids)

        bs, max_n_images, _ = images_spatial_crop.shape
        batch_num_tiles = [0 for _ in range(bs)]
        total_tiles = []
        for idx in range(bs):
            for jdx in range(max_n_images):
                num_width_tiles, num_height_tiles = images_spatial_crop[idx, jdx]
                if num_width_tiles == 0 or num_height_tiles == 0:
                    break
                batch_num_tiles[idx] += (1 + num_width_tiles * num_height_tiles)

            total_tiles.append(images[idx, :batch_num_tiles[idx]])

        # [batch_all_tiles, 3, height, width]
        total_tiles = torch.cat(total_tiles, dim=0)
        assert total_tiles.shape[0] == sum(batch_num_tiles)

        if total_tiles.shape[0] == 0:
            return self.language.get_input_embeddings()(input_ids)

        # [batch_all_tiles, vit_seq_len, c]
        images_feature = self.vision(total_tiles)

        # [batch_all_tiles, hw, D]
        images_embeds = self.projector(images_feature)
        _, hw, n_dim = images_embeds.shape
        h = w = int(hw ** 0.5)

        # put image tokens into the input_embeds, [b, T, D]
        input_embeds = self.language.get_input_embeddings()(input_ids)

        # fill in the image token sequence according to self.tile_tag & self.global_view_pos
        tile_index = 0
        for idx in range(images_spatial_crop.shape[0]):
            images_in_this_batch = []
            for jdx in range(images_spatial_crop.shape[1]):

                # extract global & local features
                num_width_tiles, num_height_tiles = images_spatial_crop[idx, jdx]
                if num_width_tiles == 0 or num_height_tiles == 0:
                    break

                num_tiles_in_image = num_width_tiles * num_height_tiles

                # [hw, D]
                global_features = images_embeds[tile_index]

                # [num_height_tiles * num_width_tiles, hw, D]
                local_features = images_embeds[tile_index + 1: tile_index + 1 + num_tiles_in_image]

                tile_index += num_tiles_in_image + 1

                # format global and local features
                if self.tile_tag == "2D":

                    # ----------------- global view add newline -----------------
                    # [hw, D] -> [h, w, D]
                    global_features = global_features.view(h, w, n_dim)
                    # [D] -> [h, 1, D]
                    new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h)
                    # cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D]
                    global_features = torch.cat([global_features, new_lines_in_global], dim=1)
                    # [h, w + 1, D] -> [h * (w + 1), D]
                    global_features = global_features.view(-1, n_dim)

                    # ----------------- local view add newline -----------------
                    # [num_height_tiles * num_width_tiles, h * w, D] -> [num_height_tiles * h, num_width_tiles * w, D]
                    local_features = rearrange(
                        local_features,
                        "(th tw) (h w) d -> (th h) (tw w) d",
                        th=num_height_tiles,
                        tw=num_width_tiles,
                        h=h,
                        w=w
                    )

                    # [D] -> [num_height_tiles * h, 1, D]
                    new_lines_in_local = repeat(
                        self.image_newline,
                        "d -> (th h) 1 d",
                        th=num_height_tiles,
                        h=h
                    )

                    # [num_height_tiles * h, num_width_tiles * w + 1, D]
                    local_features = torch.cat([local_features, new_lines_in_local], dim=1)

                    # [num_height_tiles * h, num_width_tiles * w + 1, D]
                    #   --> [(num_height_tiles * h) * (num_width_tiles * w + 1), D]
                    local_features = local_features.view(-1, n_dim)

                    # ----------------- merge global and local tiles -----------------
                    if self.global_view_pos == "head":
                        global_local_features = torch.cat(
                            [global_features, self.view_seperator[None, :], local_features], dim=0)
                    else:
                        global_local_features = torch.cat(
                            [local_features, self.view_seperator[None, :], global_features], dim=0)

                else:
                    # abandoned: in practice this branch is never taken
                    global_features = torch.cat(
                        [self.tile_indicators[0:1], global_features], dim=0
                    )
                    local_features = torch.cat(
                        [self.tile_indicators[1:num_tiles_in_image + 1].unsqueeze(1), local_features], dim=1
                    )
                    local_features = rearrange(local_features, 'crop_num hw d -> (crop_num hw) d')

                    if self.global_view_pos == "head":
                        global_local_features = torch.cat([global_features, local_features], dim=0)
                    else:
                        global_local_features = torch.cat([local_features, global_features], dim=0)

                images_in_this_batch.append(global_local_features)

            if len(images_in_this_batch) > 0:
                images_in_this_batch = torch.cat(images_in_this_batch, dim=0)
                input_embeds[idx].masked_scatter_(images_seq_mask[idx].unsqueeze(-1), images_in_this_batch)

        return input_embeds
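
    # Token bookkeeping for the "2D" tile tag, as implied by the code above:
    # an image cropped into (num_width_tiles, num_height_tiles) local tiles contributes
    #     h * (w + 1)                                          global-view tokens
    #   + 1                                                    view separator
    #   + (num_height_tiles * h) * (num_width_tiles * w + 1)   local-view tokens
    # where h = w = sqrt(hw) is the per-tile grid size after the projector.
    # The total over all images in a sample should match the number of True
    # entries in images_seq_mask for that sample so that masked_scatter_
    # places the image embeddings on the intended positions.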

    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        synced_gpus: Optional[bool] = None,
        assistant_model: Optional["PreTrainedModel"] = None,
        streamer: Optional["BaseStreamer"] = None,
        negative_prompt_ids: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        r"""
        Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
        beam-search decoding, sampling with temperature, and sampling with top-k or nucleus sampling. Beam-search
        decoding is controlled by the `num_beams` and `num_return_sequences` parameters.

        Parameters:
            - `inputs` (optional) -- `torch.LongTensor` of shape `(batch, sequence_length)`:
                The sequence used as a prompt for the generation. If `None`, generation starts from the model's
                default prompt.
            - `generation_config` (optional) -- `GenerationConfig`:
                The generation config of the model.
            - `logits_processor` (optional) -- `LogitsProcessorList`:
                A list of instances of :class:`~transformers.LogitsProcessor` used to modify the prediction scores
                at each generation step. All remaining arguments are forwarded unchanged to
                ``self.language.generate``.
        """
        return self.language.generate(
            inputs=inputs,
            generation_config=generation_config,
            logits_processor=logits_processor,
            stopping_criteria=stopping_criteria,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            synced_gpus=synced_gpus,
            assistant_model=assistant_model,
            streamer=streamer,
            negative_prompt_ids=negative_prompt_ids,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
            **kwargs,
        )


AutoConfig.register("vision", VisionEncoderConfig)
AutoConfig.register("mlp_projector", MlpProjectorConfig)
AutoConfig.register("deepseek_vl_v2", DeepseekVLV2Config)
AutoModelForCausalLM.register(DeepseekVLV2Config, DeepseekVLV2ForCausalLM)
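

if __name__ == "__main__":
    # Minimal, self-contained sanity check for the projector only. This is a
    # sketch with assumed sizes; it does not load any released DeepSeek-VL2
    # checkpoint and does not build the vision tower or the language model.
    cfg = MlpProjectorConfig()  # defaults: downsample_mlp_gelu, input_dim=1152, n_embed=2048
    projector = MlpProjector(cfg)

    # Pretend the vision tower produced a 24x24 grid of tokens (576 per tile).
    dummy_vision_tokens = torch.randn(2, 576, cfg.input_dim)
    projected = projector(dummy_vision_tokens)

    # With downsample_ratio=2 the 24x24 grid becomes 12x12 = 144 tokens of width n_embed.
    print(projected.shape)  # expected: torch.Size([2, 144, 2048])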