# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

import torch
from attrdict import AttrDict
from einops import rearrange
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    LlamaConfig,
    LlamaForCausalLM,
    PreTrainedModel,
)
from transformers.configuration_utils import PretrainedConfig

from janus.models.clip_encoder import CLIPVisionTower
from janus.models.projector import MlpProjector

class VisionHead(torch.nn.Module):
    def __init__(self, params: AttrDict):
        super().__init__()
        self.output_mlp_projector = torch.nn.Linear(
            params.n_embed, params.image_token_embed
        )
        self.vision_activation = torch.nn.GELU()
        self.vision_head = torch.nn.Linear(
            params.image_token_embed, params.image_token_size
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.output_mlp_projector(x)
        x = self.vision_activation(x)
        x = self.vision_head(x)
        return x

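# Illustrative shapes for VisionHead (the values here are hypothetical, not
# taken from any released checkpoint): with n_embed=2048,
# image_token_embed=2048, and image_token_size=16384, the head maps language
# model hidden states of shape [batch, seq, 2048] to logits over the image
# codebook, [batch, seq, 16384].
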
def model_name_to_cls(cls_name: str) -> type:
    if "MlpProjector" in cls_name:
        cls = MlpProjector
    elif "CLIPVisionTower" in cls_name:
        cls = CLIPVisionTower
    elif "VQ" in cls_name:
        from janus.models.vq_model import VQ_models

        cls = VQ_models[cls_name]
    elif "vision_head" in cls_name:
        # Maintain backward compatibility with existing configs using "vision_head"
        cls = VisionHead
    else:
        raise ValueError(f"Invalid class name: {cls_name}")

    return cls

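# Illustrative mapping (mirrors the branches above; any other string raises
# ValueError):
#   model_name_to_cls("MlpProjector")     # -> MlpProjector
#   model_name_to_cls("CLIPVisionTower")  # -> CLIPVisionTower
#   model_name_to_cls("vision_head")      # -> VisionHead
#   model_name_to_cls("VQ-16")            # -> VQ_models["VQ-16"], assuming that
#                                         #    key exists in janus.models.vq_model
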
class BaseSubConfig(PretrainedConfig):
    model_type: str = "base"
    cls: str = ""
    params: AttrDict = AttrDict()

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.cls = kwargs.get("cls", "")
        if not isinstance(self.cls, str):
            # Accept a class object and store its name, so the config
            # serializes to a plain string.
            self.cls = self.cls.__name__
        self.params = AttrDict(kwargs.get("params", {}))

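# Illustrative round-trip (the param name below is hypothetical): sub-configs
# keep the component class as a string plus a free-form params dict, so they
# serialize cleanly to config.json.
#   cfg = VisionConfig(cls="CLIPVisionTower", params={"model_name": "siglip"})
#   cfg.cls                # "CLIPVisionTower"
#   cfg.params.model_name  # attribute-style access via AttrDict
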
class VisionConfig(BaseSubConfig):
    model_type = "vision"


class AlignerConfig(BaseSubConfig):
    model_type = "aligner"


class GenVisionConfig(BaseSubConfig):
    model_type = "gen_vision"


class GenAlignerConfig(BaseSubConfig):
    model_type = "gen_aligner"


class GenHeadConfig(BaseSubConfig):
    model_type = "gen_head"

class MultiModalityConfig(PretrainedConfig):
    model_type = "multi_modality"
    vision_config: VisionConfig
    aligner_config: AlignerConfig
    gen_vision_config: GenVisionConfig
    gen_aligner_config: GenAlignerConfig
    gen_head_config: GenHeadConfig
    language_config: LlamaConfig

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.vision_config = VisionConfig(**kwargs.get("vision_config", {}))
        self.aligner_config = AlignerConfig(**kwargs.get("aligner_config", {}))
        self.gen_vision_config = GenVisionConfig(**kwargs.get("gen_vision_config", {}))
        self.gen_aligner_config = GenAlignerConfig(
            **kwargs.get("gen_aligner_config", {})
        )
        self.gen_head_config = GenHeadConfig(**kwargs.get("gen_head_config", {}))

        language_config = kwargs.get("language_config", {})
        if isinstance(language_config, LlamaConfig):
            self.language_config = language_config
        else:
            self.language_config = LlamaConfig(**language_config)

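# Illustrative construction from nested dicts (normally this config is
# deserialized from a checkpoint's config.json; the cls strings and params
# below are placeholders, not a released model's values):
#   config = MultiModalityConfig(
#       vision_config={"cls": "CLIPVisionTower", "params": {...}},
#       aligner_config={"cls": "MlpProjector", "params": {...}},
#       gen_vision_config={"cls": "VQ-16", "params": {...}},
#       gen_aligner_config={"cls": "MlpProjector", "params": {...}},
#       gen_head_config={"cls": "vision_head", "params": {...}},
#       language_config={"hidden_size": 2048, "num_hidden_layers": 24},
#   )
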
class MultiModalityPreTrainedModel(PreTrainedModel):
    config_class = MultiModalityConfig
    base_model_prefix = "multi_modality"
    _no_split_modules = []
    _skip_keys_device_placement = "past_key_values"

class MultiModalityCausalLM(MultiModalityPreTrainedModel):
    def __init__(self, config: MultiModalityConfig):
        super().__init__(config)

        # Initialize vision components
        self.vision_model = model_name_to_cls(config.vision_config.cls)(
            **config.vision_config.params
        )
        self.aligner = model_name_to_cls(config.aligner_config.cls)(
            config.aligner_config.params
        )

        # Initialize generation components
        self.gen_vision_model = model_name_to_cls(config.gen_vision_config.cls)(
            **config.gen_vision_config.params
        )
        self.gen_aligner = model_name_to_cls(config.gen_aligner_config.cls)(
            config.gen_aligner_config.params
        )
        self.gen_head = model_name_to_cls(config.gen_head_config.cls)(
            config.gen_head_config.params
        )

        # Initialize embeddings and language model
        self.gen_embed = torch.nn.Embedding(
            config.gen_vision_config.params.image_token_size,
            config.gen_vision_config.params.n_embed,
        )
        self.language_model = LlamaForCausalLM(config.language_config)

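    # Illustrative data flow (assuming the component classes named above):
    #   understanding: pixel_values -> vision_model -> aligner -> LM input embeds
    #   generation:    image_ids -> gen_embed -> gen_aligner -> LM input embeds;
    #                  LM hidden states -> gen_head -> logits over the image codebook
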
    def prepare_inputs_embeds(
        self,
        input_ids: torch.LongTensor,
        pixel_values: torch.FloatTensor,
        images_seq_mask: torch.LongTensor,
        images_emb_mask: torch.LongTensor,
        **kwargs,
    ) -> torch.Tensor:
        """
        Prepares combined text and image embeddings for the language model.

        Args:
            input_ids: Token IDs for text inputs, shape [batch_size, seq_len]
            pixel_values: Image tensors, shape [batch_size, num_images, channels, height, width]
            images_seq_mask: Boolean mask indicating image positions in the text sequence,
                shape [batch_size, seq_len]
            images_emb_mask: Boolean mask for valid image tokens per image,
                shape [batch_size, num_images, tokens_per_image]

        Returns:
            Combined embeddings tensor of shape [batch_size, seq_len, embedding_dim]
        """
        batch_size, num_images = pixel_values.shape[:2]
        images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
        images_embeds = self.aligner(self.vision_model(images))  # [(b n), tokens, dim]

        # Reshape embeddings and masks
        images_embeds = rearrange(
            images_embeds, "(b n) t d -> b (n t) d", b=batch_size, n=num_images
        )
        images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")

        # Validate mask compatibility
        assert torch.all(
            images_seq_mask.sum(dim=1) == images_emb_mask.sum(dim=1)
        ), "Masks must have matching number of image tokens"

        # Get text embeddings and replace image positions
        input_ids = input_ids.masked_fill(input_ids < 0, 0)  # Replace negative (placeholder) IDs
        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
        inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]

        return inputs_embeds

    def prepare_gen_img_embeds(self, image_ids: torch.LongTensor) -> torch.Tensor:
        """Generates image embeddings from token IDs for generation."""
        return self.gen_aligner(self.gen_embed(image_ids))

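# Illustrative shape walkthrough for prepare_inputs_embeds (hypothetical sizes:
# batch of 2, 1 image per sequence, 576 tokens per image, hidden size 2048):
#   input_ids        [2, 600]             text tokens plus image placeholders
#   pixel_values     [2, 1, 3, 384, 384]  one image per sequence
#   images_seq_mask  [2, 600]             True at the 576 placeholder positions
#   images_emb_mask  [2, 1, 576]          True for every valid image token
#   -> inputs_embeds [2, 600, 2048]       placeholders replaced by image embeds
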
# Configuration registration
AutoConfig.register("vision", VisionConfig)
AutoConfig.register("aligner", AlignerConfig)
AutoConfig.register("gen_vision", GenVisionConfig)
AutoConfig.register("gen_aligner", GenAlignerConfig)
AutoConfig.register("gen_head", GenHeadConfig)
AutoConfig.register("multi_modality", MultiModalityConfig)
AutoModelForCausalLM.register(MultiModalityConfig, MultiModalityCausalLM)
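
# Illustrative usage (assumes a checkpoint in this format, e.g. the
# deepseek-ai/Janus-1.3B repo; once this module is imported, the registrations
# above let the Auto classes resolve "multi_modality" checkpoints directly):
#   model = AutoModelForCausalLM.from_pretrained("deepseek-ai/Janus-1.3B")
#   assert isinstance(model, MultiModalityCausalLM)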