Update modeling_vlm.py

This commit is contained in:
Aaditya Sharma 2025-01-28 22:53:23 +05:30 committed by GitHub
parent a74a59f8a9
commit e23df8fecc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -33,8 +33,8 @@ from janus.models.clip_encoder import CLIPVisionTower
from janus.models.projector import MlpProjector
class vision_head(torch.nn.Module):
def __init__(self, params):
class VisionHead(torch.nn.Module):
def __init__(self, params: AttrDict):
super().__init__()
self.output_mlp_projector = torch.nn.Linear(
params.n_embed, params.image_token_embed
@ -44,135 +44,82 @@ class vision_head(torch.nn.Module):
params.image_token_embed, params.image_token_size
)
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.output_mlp_projector(x)
x = self.vision_activation(x)
x = self.vision_head(x)
return x
def model_name_to_cls(cls_name):
def model_name_to_cls(cls_name: str) -> type:
if "MlpProjector" in cls_name:
cls = MlpProjector
elif "CLIPVisionTower" in cls_name:
cls = CLIPVisionTower
elif "VQ" in cls_name:
from janus.models.vq_model import VQ_models
cls = VQ_models[cls_name]
elif "vision_head" in cls_name:
cls = vision_head
# Maintain backward compatibility with existing configs using "vision_head"
cls = VisionHead
else:
raise ValueError(f"class_name {cls_name} is invalid.")
raise ValueError(f"Invalid class name: {cls_name}")
return cls
class VisionConfig(PretrainedConfig):
class BaseSubConfig(PretrainedConfig):
model_type: str = "base"
cls: str = ""
params: AttrDict = AttrDict()
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.cls = kwargs.get("cls", "")
if not isinstance(self.cls, str):
self.cls = self.cls.__name__
self.params = AttrDict(kwargs.get("params", {}))
class VisionConfig(BaseSubConfig):
model_type = "vision"
cls: str = ""
params: AttrDict = {}
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.cls = kwargs.get("cls", "")
if not isinstance(self.cls, str):
self.cls = self.cls.__name__
self.params = AttrDict(kwargs.get("params", {}))
class AlignerConfig(PretrainedConfig):
class AlignerConfig(BaseSubConfig):
model_type = "aligner"
cls: str = ""
params: AttrDict = {}
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.cls = kwargs.get("cls", "")
if not isinstance(self.cls, str):
self.cls = self.cls.__name__
self.params = AttrDict(kwargs.get("params", {}))
class GenVisionConfig(PretrainedConfig):
class GenVisionConfig(BaseSubConfig):
model_type = "gen_vision"
cls: str = ""
params: AttrDict = {}
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.cls = kwargs.get("cls", "")
if not isinstance(self.cls, str):
self.cls = self.cls.__name__
self.params = AttrDict(kwargs.get("params", {}))
class GenAlignerConfig(PretrainedConfig):
class GenAlignerConfig(BaseSubConfig):
model_type = "gen_aligner"
cls: str = ""
params: AttrDict = {}
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.cls = kwargs.get("cls", "")
if not isinstance(self.cls, str):
self.cls = self.cls.__name__
self.params = AttrDict(kwargs.get("params", {}))
class GenHeadConfig(PretrainedConfig):
class GenHeadConfig(BaseSubConfig):
model_type = "gen_head"
cls: str = ""
params: AttrDict = {}
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.cls = kwargs.get("cls", "")
if not isinstance(self.cls, str):
self.cls = self.cls.__name__
self.params = AttrDict(kwargs.get("params", {}))
class MultiModalityConfig(PretrainedConfig):
model_type = "multi_modality"
vision_config: VisionConfig
aligner_config: AlignerConfig
gen_vision_config: GenVisionConfig
gen_aligner_config: GenAlignerConfig
gen_head_config: GenHeadConfig
language_config: LlamaConfig
def __init__(self, **kwargs):
super().__init__(**kwargs)
vision_config = kwargs.get("vision_config", {})
self.vision_config = VisionConfig(**vision_config)
aligner_config = kwargs.get("aligner_config", {})
self.aligner_config = AlignerConfig(**aligner_config)
gen_vision_config = kwargs.get("gen_vision_config", {})
self.gen_vision_config = GenVisionConfig(**gen_vision_config)
gen_aligner_config = kwargs.get("gen_aligner_config", {})
self.gen_aligner_config = GenAlignerConfig(**gen_aligner_config)
gen_head_config = kwargs.get("gen_head_config", {})
self.gen_head_config = GenHeadConfig(**gen_head_config)
self.vision_config = VisionConfig(**kwargs.get("vision_config", {}))
self.aligner_config = AlignerConfig(**kwargs.get("aligner_config", {}))
self.gen_vision_config = GenVisionConfig(**kwargs.get("gen_vision_config", {}))
self.gen_aligner_config = GenAlignerConfig(
**kwargs.get("gen_aligner_config", {})
)
self.gen_head_config = GenHeadConfig(**kwargs.get("gen_head_config", {}))
language_config = kwargs.get("language_config", {})
if isinstance(language_config, LlamaConfig):
self.language_config = language_config
@ -191,32 +138,31 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
def __init__(self, config: MultiModalityConfig):
super().__init__(config)
vision_config = config.vision_config
vision_cls = model_name_to_cls(vision_config.cls)
self.vision_model = vision_cls(**vision_config.params)
aligner_config = config.aligner_config
aligner_cls = model_name_to_cls(aligner_config.cls)
self.aligner = aligner_cls(aligner_config.params)
gen_vision_config = config.gen_vision_config
gen_vision_cls = model_name_to_cls(gen_vision_config.cls)
self.gen_vision_model = gen_vision_cls()
gen_aligner_config = config.gen_aligner_config
gen_aligner_cls = model_name_to_cls(gen_aligner_config.cls)
self.gen_aligner = gen_aligner_cls(gen_aligner_config.params)
gen_head_config = config.gen_head_config
gen_head_cls = model_name_to_cls(gen_head_config.cls)
self.gen_head = gen_head_cls(gen_head_config.params)
self.gen_embed = torch.nn.Embedding(
gen_vision_config.params.image_token_size, gen_vision_config.params.n_embed
# Initialize vision components
self.vision_model = model_name_to_cls(config.vision_config.cls)(
**config.vision_config.params
)
self.aligner = model_name_to_cls(config.aligner_config.cls)(
config.aligner_config.params
)
language_config = config.language_config
self.language_model = LlamaForCausalLM(language_config)
# Initialize generation components
self.gen_vision_model = model_name_to_cls(config.gen_vision_config.cls)(
**config.gen_vision_config.params
)
self.gen_aligner = model_name_to_cls(config.gen_aligner_config.cls)(
config.gen_aligner_config.params
)
self.gen_head = model_name_to_cls(config.gen_head_config.cls)(
config.gen_head_config.params
)
# Initialize embeddings and language model
self.gen_embed = torch.nn.Embedding(
config.gen_vision_config.params.image_token_size,
config.gen_vision_config.params.n_embed,
)
self.language_model = LlamaForCausalLM(config.language_config)
def prepare_inputs_embeds(
self,
@ -225,44 +171,49 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
images_seq_mask: torch.LongTensor,
images_emb_mask: torch.LongTensor,
**kwargs,
):
) -> torch.Tensor:
"""
Prepares combined text and image embeddings for the language model.
Args:
input_ids (torch.LongTensor): [b, T]
pixel_values (torch.FloatTensor): [b, n_images, 3, h, w]
images_seq_mask (torch.BoolTensor): [b, T]
images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens]
assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask)
input_ids: Token IDs for text inputs, shape [batch_size, seq_len]
pixel_values: Image tensors, shape [batch_size, num_images, channels, height, width]
images_seq_mask: Boolean mask indicating image positions in the text sequence,
shape [batch_size, seq_len]
images_emb_mask: Boolean mask for valid image tokens per image,
shape [batch_size, num_images, tokens_per_image]
Returns:
input_embeds (torch.Tensor): [b, T, D]
Combined embeddings tensor of shape [batch_size, seq_len, embedding_dim]
"""
bs, n = pixel_values.shape[0:2]
batch_size, num_images = pixel_values.shape[:2]
images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
# [b x n, T2, D]
images_embeds = self.aligner(self.vision_model(images))
images_embeds = self.aligner(self.vision_model(images)) # [(b n), tokens, dim]
# [b x n, T2, D] -> [b, n x T2, D]
images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
# [b, n, T2] -> [b, n x T2]
# Reshape embeddings and masks
images_embeds = rearrange(
images_embeds, "(b n) t d -> b (n t) d", b=batch_size, n=num_images
)
images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")
# [b, T, D]
input_ids[input_ids < 0] = 0 # ignore the image embeddings
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
# Validate mask compatibility
assert torch.all(
images_seq_mask.sum(dim=1) == images_emb_mask.sum(dim=1)
), "Masks must have matching number of image tokens"
# replace with the image embeddings
# Get text embeddings and replace image positions
input_ids = input_ids.masked_fill(input_ids < 0, 0) # Replace negatives
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]
return inputs_embeds
def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
def prepare_gen_img_embeds(self, image_ids: torch.LongTensor) -> torch.Tensor:
"""Generates image embeddings from token IDs for generation."""
return self.gen_aligner(self.gen_embed(image_ids))
# Configuration registration
AutoConfig.register("vision", VisionConfig)
AutoConfig.register("aligner", AlignerConfig)
AutoConfig.register("gen_vision", GenVisionConfig)