mirror of
https://github.com/deepseek-ai/Janus.git
synced 2025-02-23 06:08:59 -05:00
fix: Update modeling_vlm.py
This commit is contained in:
parent
45680ae127
commit
3e49061a00
@ -26,23 +26,21 @@ from transformers import (
|
|||||||
LlamaConfig,
|
LlamaConfig,
|
||||||
LlamaForCausalLM,
|
LlamaForCausalLM,
|
||||||
PreTrainedModel,
|
PreTrainedModel,
|
||||||
|
PretrainedConfig,
|
||||||
)
|
)
|
||||||
from transformers.configuration_utils import PretrainedConfig
|
|
||||||
|
|
||||||
from janus.models.clip_encoder import CLIPVisionTower
|
from janus.models.clip_encoder import CLIPVisionTower
|
||||||
from janus.models.projector import MlpProjector
|
from janus.models.projector import MlpProjector
|
||||||
|
|
||||||
|
|
||||||
class vision_head(torch.nn.Module):
|
class VisionHead(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Vision head module for processing visual embeddings.
|
||||||
|
"""
|
||||||
def __init__(self, params):
|
def __init__(self, params):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.output_mlp_projector = torch.nn.Linear(
|
self.output_mlp_projector = torch.nn.Linear(params.n_embed, params.image_token_embed)
|
||||||
params.n_embed, params.image_token_embed
|
|
||||||
)
|
|
||||||
self.vision_activation = torch.nn.GELU()
|
self.vision_activation = torch.nn.GELU()
|
||||||
self.vision_head = torch.nn.Linear(
|
self.vision_head = torch.nn.Linear(params.image_token_embed, params.image_token_size)
|
||||||
params.image_token_embed, params.image_token_size
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.output_mlp_projector(x)
|
x = self.output_mlp_projector(x)
|
||||||
@ -52,135 +50,91 @@ class vision_head(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
def model_name_to_cls(cls_name):
|
def model_name_to_cls(cls_name):
|
||||||
if "MlpProjector" in cls_name:
|
"""
|
||||||
cls = MlpProjector
|
Maps a class name to its corresponding class.
|
||||||
|
"""
|
||||||
|
mapping = {
|
||||||
|
"MlpProjector": MlpProjector,
|
||||||
|
"CLIPVisionTower": CLIPVisionTower,
|
||||||
|
"vision_head": VisionHead,
|
||||||
|
}
|
||||||
|
|
||||||
elif "CLIPVisionTower" in cls_name:
|
if "VQ" in cls_name:
|
||||||
cls = CLIPVisionTower
|
|
||||||
|
|
||||||
elif "VQ" in cls_name:
|
|
||||||
from janus.models.vq_model import VQ_models
|
from janus.models.vq_model import VQ_models
|
||||||
|
return VQ_models[cls_name]
|
||||||
|
|
||||||
cls = VQ_models[cls_name]
|
cls = mapping.get(cls_name)
|
||||||
elif "vision_head" in cls_name:
|
if cls is None:
|
||||||
cls = vision_head
|
raise ValueError(f"Invalid class name: {cls_name}")
|
||||||
else:
|
|
||||||
raise ValueError(f"class_name {cls_name} is invalid.")
|
|
||||||
|
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
|
|
||||||
class VisionConfig(PretrainedConfig):
|
class BaseConfig(PretrainedConfig):
|
||||||
|
"""
|
||||||
|
Base configuration class for multi-modality components.
|
||||||
|
"""
|
||||||
|
model_type = ""
|
||||||
|
cls: str = ""
|
||||||
|
params: AttrDict = {}
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.cls = kwargs.get("cls", "")
|
||||||
|
if not isinstance(self.cls, str):
|
||||||
|
self.cls = self.cls.__name__
|
||||||
|
self.params = AttrDict(kwargs.get("params", {}))
|
||||||
|
|
||||||
|
|
||||||
|
class VisionConfig(BaseConfig):
|
||||||
model_type = "vision"
|
model_type = "vision"
|
||||||
cls: str = ""
|
|
||||||
params: AttrDict = {}
|
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
|
|
||||||
self.cls = kwargs.get("cls", "")
|
|
||||||
if not isinstance(self.cls, str):
|
|
||||||
self.cls = self.cls.__name__
|
|
||||||
|
|
||||||
self.params = AttrDict(kwargs.get("params", {}))
|
|
||||||
|
|
||||||
|
|
||||||
class AlignerConfig(PretrainedConfig):
|
class AlignerConfig(BaseConfig):
|
||||||
model_type = "aligner"
|
model_type = "aligner"
|
||||||
cls: str = ""
|
|
||||||
params: AttrDict = {}
|
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
|
|
||||||
self.cls = kwargs.get("cls", "")
|
|
||||||
if not isinstance(self.cls, str):
|
|
||||||
self.cls = self.cls.__name__
|
|
||||||
|
|
||||||
self.params = AttrDict(kwargs.get("params", {}))
|
|
||||||
|
|
||||||
|
|
||||||
class GenVisionConfig(PretrainedConfig):
|
class GenVisionConfig(BaseConfig):
|
||||||
model_type = "gen_vision"
|
model_type = "gen_vision"
|
||||||
cls: str = ""
|
|
||||||
params: AttrDict = {}
|
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
|
|
||||||
self.cls = kwargs.get("cls", "")
|
|
||||||
if not isinstance(self.cls, str):
|
|
||||||
self.cls = self.cls.__name__
|
|
||||||
|
|
||||||
self.params = AttrDict(kwargs.get("params", {}))
|
|
||||||
|
|
||||||
|
|
||||||
class GenAlignerConfig(PretrainedConfig):
|
class GenAlignerConfig(BaseConfig):
|
||||||
model_type = "gen_aligner"
|
model_type = "gen_aligner"
|
||||||
cls: str = ""
|
|
||||||
params: AttrDict = {}
|
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
|
|
||||||
self.cls = kwargs.get("cls", "")
|
|
||||||
if not isinstance(self.cls, str):
|
|
||||||
self.cls = self.cls.__name__
|
|
||||||
|
|
||||||
self.params = AttrDict(kwargs.get("params", {}))
|
|
||||||
|
|
||||||
|
|
||||||
class GenHeadConfig(PretrainedConfig):
|
class GenHeadConfig(BaseConfig):
|
||||||
model_type = "gen_head"
|
model_type = "gen_head"
|
||||||
cls: str = ""
|
|
||||||
params: AttrDict = {}
|
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
|
|
||||||
self.cls = kwargs.get("cls", "")
|
|
||||||
if not isinstance(self.cls, str):
|
|
||||||
self.cls = self.cls.__name__
|
|
||||||
|
|
||||||
self.params = AttrDict(kwargs.get("params", {}))
|
|
||||||
|
|
||||||
|
|
||||||
class MultiModalityConfig(PretrainedConfig):
|
class MultiModalityConfig(PretrainedConfig):
|
||||||
|
"""
|
||||||
|
Configuration for the multi-modality model.
|
||||||
|
"""
|
||||||
model_type = "multi_modality"
|
model_type = "multi_modality"
|
||||||
vision_config: VisionConfig
|
vision_config: VisionConfig
|
||||||
aligner_config: AlignerConfig
|
aligner_config: AlignerConfig
|
||||||
|
|
||||||
gen_vision_config: GenVisionConfig
|
gen_vision_config: GenVisionConfig
|
||||||
gen_aligner_config: GenAlignerConfig
|
gen_aligner_config: GenAlignerConfig
|
||||||
gen_head_config: GenHeadConfig
|
gen_head_config: GenHeadConfig
|
||||||
|
|
||||||
language_config: LlamaConfig
|
language_config: LlamaConfig
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
vision_config = kwargs.get("vision_config", {})
|
self.vision_config = VisionConfig(**kwargs.get("vision_config", {}))
|
||||||
self.vision_config = VisionConfig(**vision_config)
|
self.aligner_config = AlignerConfig(**kwargs.get("aligner_config", {}))
|
||||||
|
self.gen_vision_config = GenVisionConfig(**kwargs.get("gen_vision_config", {}))
|
||||||
aligner_config = kwargs.get("aligner_config", {})
|
self.gen_aligner_config = GenAlignerConfig(**kwargs.get("gen_aligner_config", {}))
|
||||||
self.aligner_config = AlignerConfig(**aligner_config)
|
self.gen_head_config = GenHeadConfig(**kwargs.get("gen_head_config", {}))
|
||||||
|
|
||||||
gen_vision_config = kwargs.get("gen_vision_config", {})
|
|
||||||
self.gen_vision_config = GenVisionConfig(**gen_vision_config)
|
|
||||||
|
|
||||||
gen_aligner_config = kwargs.get("gen_aligner_config", {})
|
|
||||||
self.gen_aligner_config = GenAlignerConfig(**gen_aligner_config)
|
|
||||||
|
|
||||||
gen_head_config = kwargs.get("gen_head_config", {})
|
|
||||||
self.gen_head_config = GenHeadConfig(**gen_head_config)
|
|
||||||
|
|
||||||
language_config = kwargs.get("language_config", {})
|
language_config = kwargs.get("language_config", {})
|
||||||
if isinstance(language_config, LlamaConfig):
|
self.language_config = (
|
||||||
self.language_config = language_config
|
language_config if isinstance(language_config, LlamaConfig) else LlamaConfig(**language_config)
|
||||||
else:
|
)
|
||||||
self.language_config = LlamaConfig(**language_config)
|
|
||||||
|
|
||||||
|
|
||||||
class MultiModalityPreTrainedModel(PreTrainedModel):
|
class MultiModalityPreTrainedModel(PreTrainedModel):
|
||||||
|
"""
|
||||||
|
Base class for multi-modality pre-trained models.
|
||||||
|
"""
|
||||||
config_class = MultiModalityConfig
|
config_class = MultiModalityConfig
|
||||||
base_model_prefix = "multi_modality"
|
base_model_prefix = "multi_modality"
|
||||||
_no_split_modules = []
|
_no_split_modules = []
|
||||||
@ -188,35 +142,40 @@ class MultiModalityPreTrainedModel(PreTrainedModel):
|
|||||||
|
|
||||||
|
|
||||||
class MultiModalityCausalLM(MultiModalityPreTrainedModel):
|
class MultiModalityCausalLM(MultiModalityPreTrainedModel):
|
||||||
|
"""
|
||||||
|
Multi-modality causal language model combining vision and language components.
|
||||||
|
"""
|
||||||
def __init__(self, config: MultiModalityConfig):
|
def __init__(self, config: MultiModalityConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
vision_config = config.vision_config
|
# Initialize vision model
|
||||||
vision_cls = model_name_to_cls(vision_config.cls)
|
vision_cls = model_name_to_cls(config.vision_config.cls)
|
||||||
self.vision_model = vision_cls(**vision_config.params)
|
self.vision_model = vision_cls(**config.vision_config.params)
|
||||||
|
|
||||||
aligner_config = config.aligner_config
|
# Initialize aligner
|
||||||
aligner_cls = model_name_to_cls(aligner_config.cls)
|
aligner_cls = model_name_to_cls(config.aligner_config.cls)
|
||||||
self.aligner = aligner_cls(aligner_config.params)
|
self.aligner = aligner_cls(config.aligner_config.params)
|
||||||
|
|
||||||
gen_vision_config = config.gen_vision_config
|
# Initialize generative vision model
|
||||||
gen_vision_cls = model_name_to_cls(gen_vision_config.cls)
|
gen_vision_cls = model_name_to_cls(config.gen_vision_config.cls)
|
||||||
self.gen_vision_model = gen_vision_cls()
|
self.gen_vision_model = gen_vision_cls()
|
||||||
|
|
||||||
gen_aligner_config = config.gen_aligner_config
|
# Initialize generative aligner
|
||||||
gen_aligner_cls = model_name_to_cls(gen_aligner_config.cls)
|
gen_aligner_cls = model_name_to_cls(config.gen_aligner_config.cls)
|
||||||
self.gen_aligner = gen_aligner_cls(gen_aligner_config.params)
|
self.gen_aligner = gen_aligner_cls(config.gen_aligner_config.params)
|
||||||
|
|
||||||
gen_head_config = config.gen_head_config
|
# Initialize generative head
|
||||||
gen_head_cls = model_name_to_cls(gen_head_config.cls)
|
gen_head_cls = model_name_to_cls(config.gen_head_config.cls)
|
||||||
self.gen_head = gen_head_cls(gen_head_config.params)
|
self.gen_head = gen_head_cls(config.gen_head_config.params)
|
||||||
|
|
||||||
|
# Generative embedding layer
|
||||||
self.gen_embed = torch.nn.Embedding(
|
self.gen_embed = torch.nn.Embedding(
|
||||||
gen_vision_config.params.image_token_size, gen_vision_config.params.n_embed
|
config.gen_vision_config.params.image_token_size,
|
||||||
|
config.gen_vision_config.params.n_embed,
|
||||||
)
|
)
|
||||||
|
|
||||||
language_config = config.language_config
|
# Language model
|
||||||
self.language_model = LlamaForCausalLM(language_config)
|
self.language_model = LlamaForCausalLM(config.language_config)
|
||||||
|
|
||||||
def prepare_inputs_embeds(
|
def prepare_inputs_embeds(
|
||||||
self,
|
self,
|
||||||
@ -227,46 +186,53 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
Prepares input embeddings by combining text and image embeddings.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_ids (torch.LongTensor): [b, T]
|
input_ids (torch.LongTensor): [b, T]
|
||||||
pixel_values (torch.FloatTensor): [b, n_images, 3, h, w]
|
pixel_values (torch.FloatTensor): [b, n_images, 3, h, w]
|
||||||
images_seq_mask (torch.BoolTensor): [b, T]
|
images_seq_mask (torch.BoolTensor): [b, T]
|
||||||
images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens]
|
images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens]
|
||||||
|
|
||||||
assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask)
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
input_embeds (torch.Tensor): [b, T, D]
|
input_embeds (torch.Tensor): [b, T, D]
|
||||||
"""
|
"""
|
||||||
|
bs, n = pixel_values.shape[:2]
|
||||||
bs, n = pixel_values.shape[0:2]
|
|
||||||
images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
|
images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
|
||||||
# [b x n, T2, D]
|
|
||||||
images_embeds = self.aligner(self.vision_model(images))
|
|
||||||
|
|
||||||
# [b x n, T2, D] -> [b, n x T2, D]
|
# Process images through vision model and aligner
|
||||||
images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
|
images_embeds = self.aligner(self.vision_model(images)) # [b x n, T2, D]
|
||||||
# [b, n, T2] -> [b, n x T2]
|
images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n) # [b, n x T2, D]
|
||||||
images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")
|
images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)") # [b, n x T2]
|
||||||
|
|
||||||
# [b, T, D]
|
# Prepare text embeddings
|
||||||
input_ids[input_ids < 0] = 0 # ignore the image embeddings
|
input_ids[input_ids < 0] = 0 # Ignore negative IDs
|
||||||
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
|
inputs_embeds = self.language_model.get_input_embeddings()(input_ids) # [b, T, D]
|
||||||
|
|
||||||
# replace with the image embeddings
|
# Replace text embeddings with image embeddings where applicable
|
||||||
inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]
|
inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]
|
||||||
|
|
||||||
return inputs_embeds
|
return inputs_embeds
|
||||||
|
|
||||||
def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
|
def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
|
||||||
|
"""
|
||||||
|
Prepares generative image embeddings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_ids (torch.LongTensor): Image token IDs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Generated image embeddings.
|
||||||
|
"""
|
||||||
return self.gen_aligner(self.gen_embed(image_ids))
|
return self.gen_aligner(self.gen_embed(image_ids))
|
||||||
|
|
||||||
|
|
||||||
|
# Register configurations with Hugging Face's AutoConfig
|
||||||
AutoConfig.register("vision", VisionConfig)
|
AutoConfig.register("vision", VisionConfig)
|
||||||
AutoConfig.register("aligner", AlignerConfig)
|
AutoConfig.register("aligner", AlignerConfig)
|
||||||
AutoConfig.register("gen_vision", GenVisionConfig)
|
AutoConfig.register("gen_vision", GenVisionConfig)
|
||||||
AutoConfig.register("gen_aligner", GenAlignerConfig)
|
AutoConfig.register("gen_aligner", GenAlignerConfig)
|
||||||
AutoConfig.register("gen_head", GenHeadConfig)
|
AutoConfig.register("gen_head", GenHeadConfig)
|
||||||
AutoConfig.register("multi_modality", MultiModalityConfig)
|
AutoConfig.register("multi_modality", MultiModalityConfig)
|
||||||
|
|
||||||
|
# Register the multi-modality causal LM model
|
||||||
AutoModelForCausalLM.register(MultiModalityConfig, MultiModalityCausalLM)
|
AutoModelForCausalLM.register(MultiModalityConfig, MultiModalityCausalLM)
|
||||||
|
Loading…
Reference in New Issue
Block a user