fix: Update modeling_vlm.py

This commit is contained in:
aarsxx 2025-02-11 15:23:04 +07:00 committed by GitHub
parent 45680ae127
commit 3e49061a00
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -26,23 +26,21 @@ from transformers import (
LlamaConfig, LlamaConfig,
LlamaForCausalLM, LlamaForCausalLM,
PreTrainedModel, PreTrainedModel,
PretrainedConfig,
) )
from transformers.configuration_utils import PretrainedConfig
from janus.models.clip_encoder import CLIPVisionTower from janus.models.clip_encoder import CLIPVisionTower
from janus.models.projector import MlpProjector from janus.models.projector import MlpProjector
class vision_head(torch.nn.Module): class VisionHead(torch.nn.Module):
"""
Vision head module for processing visual embeddings.
"""
def __init__(self, params): def __init__(self, params):
super().__init__() super().__init__()
self.output_mlp_projector = torch.nn.Linear( self.output_mlp_projector = torch.nn.Linear(params.n_embed, params.image_token_embed)
params.n_embed, params.image_token_embed
)
self.vision_activation = torch.nn.GELU() self.vision_activation = torch.nn.GELU()
self.vision_head = torch.nn.Linear( self.vision_head = torch.nn.Linear(params.image_token_embed, params.image_token_size)
params.image_token_embed, params.image_token_size
)
def forward(self, x): def forward(self, x):
x = self.output_mlp_projector(x) x = self.output_mlp_projector(x)
@ -52,135 +50,91 @@ class vision_head(torch.nn.Module):
def model_name_to_cls(cls_name): def model_name_to_cls(cls_name):
if "MlpProjector" in cls_name: """
cls = MlpProjector Maps a class name to its corresponding class.
"""
mapping = {
"MlpProjector": MlpProjector,
"CLIPVisionTower": CLIPVisionTower,
"vision_head": VisionHead,
}
elif "CLIPVisionTower" in cls_name: if "VQ" in cls_name:
cls = CLIPVisionTower
elif "VQ" in cls_name:
from janus.models.vq_model import VQ_models from janus.models.vq_model import VQ_models
return VQ_models[cls_name]
cls = VQ_models[cls_name] cls = mapping.get(cls_name)
elif "vision_head" in cls_name: if cls is None:
cls = vision_head raise ValueError(f"Invalid class name: {cls_name}")
else:
raise ValueError(f"class_name {cls_name} is invalid.")
return cls return cls
class VisionConfig(PretrainedConfig): class BaseConfig(PretrainedConfig):
"""
Base configuration class for multi-modality components.
"""
model_type = ""
cls: str = ""
params: AttrDict = {}
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.cls = kwargs.get("cls", "")
if not isinstance(self.cls, str):
self.cls = self.cls.__name__
self.params = AttrDict(kwargs.get("params", {}))
class VisionConfig(BaseConfig):
model_type = "vision" model_type = "vision"
cls: str = ""
params: AttrDict = {}
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.cls = kwargs.get("cls", "")
if not isinstance(self.cls, str):
self.cls = self.cls.__name__
self.params = AttrDict(kwargs.get("params", {}))
class AlignerConfig(PretrainedConfig): class AlignerConfig(BaseConfig):
model_type = "aligner" model_type = "aligner"
cls: str = ""
params: AttrDict = {}
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.cls = kwargs.get("cls", "")
if not isinstance(self.cls, str):
self.cls = self.cls.__name__
self.params = AttrDict(kwargs.get("params", {}))
class GenVisionConfig(PretrainedConfig): class GenVisionConfig(BaseConfig):
model_type = "gen_vision" model_type = "gen_vision"
cls: str = ""
params: AttrDict = {}
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.cls = kwargs.get("cls", "")
if not isinstance(self.cls, str):
self.cls = self.cls.__name__
self.params = AttrDict(kwargs.get("params", {}))
class GenAlignerConfig(PretrainedConfig): class GenAlignerConfig(BaseConfig):
model_type = "gen_aligner" model_type = "gen_aligner"
cls: str = ""
params: AttrDict = {}
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.cls = kwargs.get("cls", "")
if not isinstance(self.cls, str):
self.cls = self.cls.__name__
self.params = AttrDict(kwargs.get("params", {}))
class GenHeadConfig(PretrainedConfig): class GenHeadConfig(BaseConfig):
model_type = "gen_head" model_type = "gen_head"
cls: str = ""
params: AttrDict = {}
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.cls = kwargs.get("cls", "")
if not isinstance(self.cls, str):
self.cls = self.cls.__name__
self.params = AttrDict(kwargs.get("params", {}))
class MultiModalityConfig(PretrainedConfig): class MultiModalityConfig(PretrainedConfig):
"""
Configuration for the multi-modality model.
"""
model_type = "multi_modality" model_type = "multi_modality"
vision_config: VisionConfig vision_config: VisionConfig
aligner_config: AlignerConfig aligner_config: AlignerConfig
gen_vision_config: GenVisionConfig gen_vision_config: GenVisionConfig
gen_aligner_config: GenAlignerConfig gen_aligner_config: GenAlignerConfig
gen_head_config: GenHeadConfig gen_head_config: GenHeadConfig
language_config: LlamaConfig language_config: LlamaConfig
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
vision_config = kwargs.get("vision_config", {}) self.vision_config = VisionConfig(**kwargs.get("vision_config", {}))
self.vision_config = VisionConfig(**vision_config) self.aligner_config = AlignerConfig(**kwargs.get("aligner_config", {}))
self.gen_vision_config = GenVisionConfig(**kwargs.get("gen_vision_config", {}))
aligner_config = kwargs.get("aligner_config", {}) self.gen_aligner_config = GenAlignerConfig(**kwargs.get("gen_aligner_config", {}))
self.aligner_config = AlignerConfig(**aligner_config) self.gen_head_config = GenHeadConfig(**kwargs.get("gen_head_config", {}))
gen_vision_config = kwargs.get("gen_vision_config", {})
self.gen_vision_config = GenVisionConfig(**gen_vision_config)
gen_aligner_config = kwargs.get("gen_aligner_config", {})
self.gen_aligner_config = GenAlignerConfig(**gen_aligner_config)
gen_head_config = kwargs.get("gen_head_config", {})
self.gen_head_config = GenHeadConfig(**gen_head_config)
language_config = kwargs.get("language_config", {}) language_config = kwargs.get("language_config", {})
if isinstance(language_config, LlamaConfig): self.language_config = (
self.language_config = language_config language_config if isinstance(language_config, LlamaConfig) else LlamaConfig(**language_config)
else: )
self.language_config = LlamaConfig(**language_config)
class MultiModalityPreTrainedModel(PreTrainedModel): class MultiModalityPreTrainedModel(PreTrainedModel):
"""
Base class for multi-modality pre-trained models.
"""
config_class = MultiModalityConfig config_class = MultiModalityConfig
base_model_prefix = "multi_modality" base_model_prefix = "multi_modality"
_no_split_modules = [] _no_split_modules = []
@ -188,35 +142,40 @@ class MultiModalityPreTrainedModel(PreTrainedModel):
class MultiModalityCausalLM(MultiModalityPreTrainedModel): class MultiModalityCausalLM(MultiModalityPreTrainedModel):
"""
Multi-modality causal language model combining vision and language components.
"""
def __init__(self, config: MultiModalityConfig): def __init__(self, config: MultiModalityConfig):
super().__init__(config) super().__init__(config)
vision_config = config.vision_config # Initialize vision model
vision_cls = model_name_to_cls(vision_config.cls) vision_cls = model_name_to_cls(config.vision_config.cls)
self.vision_model = vision_cls(**vision_config.params) self.vision_model = vision_cls(**config.vision_config.params)
aligner_config = config.aligner_config # Initialize aligner
aligner_cls = model_name_to_cls(aligner_config.cls) aligner_cls = model_name_to_cls(config.aligner_config.cls)
self.aligner = aligner_cls(aligner_config.params) self.aligner = aligner_cls(config.aligner_config.params)
gen_vision_config = config.gen_vision_config # Initialize generative vision model
gen_vision_cls = model_name_to_cls(gen_vision_config.cls) gen_vision_cls = model_name_to_cls(config.gen_vision_config.cls)
self.gen_vision_model = gen_vision_cls() self.gen_vision_model = gen_vision_cls()
gen_aligner_config = config.gen_aligner_config # Initialize generative aligner
gen_aligner_cls = model_name_to_cls(gen_aligner_config.cls) gen_aligner_cls = model_name_to_cls(config.gen_aligner_config.cls)
self.gen_aligner = gen_aligner_cls(gen_aligner_config.params) self.gen_aligner = gen_aligner_cls(config.gen_aligner_config.params)
gen_head_config = config.gen_head_config # Initialize generative head
gen_head_cls = model_name_to_cls(gen_head_config.cls) gen_head_cls = model_name_to_cls(config.gen_head_config.cls)
self.gen_head = gen_head_cls(gen_head_config.params) self.gen_head = gen_head_cls(config.gen_head_config.params)
# Generative embedding layer
self.gen_embed = torch.nn.Embedding( self.gen_embed = torch.nn.Embedding(
gen_vision_config.params.image_token_size, gen_vision_config.params.n_embed config.gen_vision_config.params.image_token_size,
config.gen_vision_config.params.n_embed,
) )
language_config = config.language_config # Language model
self.language_model = LlamaForCausalLM(language_config) self.language_model = LlamaForCausalLM(config.language_config)
def prepare_inputs_embeds( def prepare_inputs_embeds(
self, self,
@ -227,6 +186,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
**kwargs, **kwargs,
): ):
""" """
Prepares input embeddings by combining text and image embeddings.
Args: Args:
input_ids (torch.LongTensor): [b, T] input_ids (torch.LongTensor): [b, T]
@ -234,39 +194,45 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
images_seq_mask (torch.BoolTensor): [b, T] images_seq_mask (torch.BoolTensor): [b, T]
images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens] images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens]
assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask)
Returns: Returns:
input_embeds (torch.Tensor): [b, T, D] input_embeds (torch.Tensor): [b, T, D]
""" """
bs, n = pixel_values.shape[:2]
bs, n = pixel_values.shape[0:2]
images = rearrange(pixel_values, "b n c h w -> (b n) c h w") images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
# [b x n, T2, D]
images_embeds = self.aligner(self.vision_model(images))
# [b x n, T2, D] -> [b, n x T2, D] # Process images through vision model and aligner
images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n) images_embeds = self.aligner(self.vision_model(images)) # [b x n, T2, D]
# [b, n, T2] -> [b, n x T2] images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n) # [b, n x T2, D]
images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)") images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)") # [b, n x T2]
# [b, T, D] # Prepare text embeddings
input_ids[input_ids < 0] = 0 # ignore the image embeddings input_ids[input_ids < 0] = 0 # Ignore negative IDs
inputs_embeds = self.language_model.get_input_embeddings()(input_ids) inputs_embeds = self.language_model.get_input_embeddings()(input_ids) # [b, T, D]
# replace with the image embeddings # Replace text embeddings with image embeddings where applicable
inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask] inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]
return inputs_embeds return inputs_embeds
def prepare_gen_img_embeds(self, image_ids: torch.LongTensor): def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
"""
Prepares generative image embeddings.
Args:
image_ids (torch.LongTensor): Image token IDs.
Returns:
torch.Tensor: Generated image embeddings.
"""
return self.gen_aligner(self.gen_embed(image_ids)) return self.gen_aligner(self.gen_embed(image_ids))
# Register configurations with Hugging Face's AutoConfig
AutoConfig.register("vision", VisionConfig) AutoConfig.register("vision", VisionConfig)
AutoConfig.register("aligner", AlignerConfig) AutoConfig.register("aligner", AlignerConfig)
AutoConfig.register("gen_vision", GenVisionConfig) AutoConfig.register("gen_vision", GenVisionConfig)
AutoConfig.register("gen_aligner", GenAlignerConfig) AutoConfig.register("gen_aligner", GenAlignerConfig)
AutoConfig.register("gen_head", GenHeadConfig) AutoConfig.register("gen_head", GenHeadConfig)
AutoConfig.register("multi_modality", MultiModalityConfig) AutoConfig.register("multi_modality", MultiModalityConfig)
# Register the multi-modality causal LM model
AutoModelForCausalLM.register(MultiModalityConfig, MultiModalityCausalLM) AutoModelForCausalLM.register(MultiModalityConfig, MultiModalityCausalLM)