From e23df8feccb82a242bcd1304234328c81c027307 Mon Sep 17 00:00:00 2001 From: Aaditya Sharma <163049815+AadiSharma49@users.noreply.github.com> Date: Tue, 28 Jan 2025 22:53:23 +0530 Subject: [PATCH] Update modeling_vlm.py --- janus/models/modeling_vlm.py | 209 ++++++++++++++--------------------- 1 file changed, 80 insertions(+), 129 deletions(-) diff --git a/janus/models/modeling_vlm.py b/janus/models/modeling_vlm.py index 8d23bc0..8f52195 100644 --- a/janus/models/modeling_vlm.py +++ b/janus/models/modeling_vlm.py @@ -33,8 +33,8 @@ from janus.models.clip_encoder import CLIPVisionTower from janus.models.projector import MlpProjector -class vision_head(torch.nn.Module): - def __init__(self, params): +class VisionHead(torch.nn.Module): + def __init__(self, params: AttrDict): super().__init__() self.output_mlp_projector = torch.nn.Linear( params.n_embed, params.image_token_embed @@ -44,135 +44,82 @@ class vision_head(torch.nn.Module): params.image_token_embed, params.image_token_size ) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.output_mlp_projector(x) x = self.vision_activation(x) x = self.vision_head(x) return x -def model_name_to_cls(cls_name): +def model_name_to_cls(cls_name: str) -> type: if "MlpProjector" in cls_name: cls = MlpProjector - elif "CLIPVisionTower" in cls_name: cls = CLIPVisionTower - elif "VQ" in cls_name: from janus.models.vq_model import VQ_models cls = VQ_models[cls_name] elif "vision_head" in cls_name: - cls = vision_head + # Maintain backward compatibility with existing configs using "vision_head" + cls = VisionHead else: - raise ValueError(f"class_name {cls_name} is invalid.") + raise ValueError(f"Invalid class name: {cls_name}") return cls -class VisionConfig(PretrainedConfig): +class BaseSubConfig(PretrainedConfig): + model_type: str = "base" + cls: str = "" + params: AttrDict = AttrDict() + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.cls = kwargs.get("cls", "") + if not isinstance(self.cls, str): + self.cls = self.cls.__name__ + self.params = AttrDict(kwargs.get("params", {})) + + +class VisionConfig(BaseSubConfig): model_type = "vision" - cls: str = "" - params: AttrDict = {} - - def __init__(self, **kwargs): - super().__init__(**kwargs) - - self.cls = kwargs.get("cls", "") - if not isinstance(self.cls, str): - self.cls = self.cls.__name__ - - self.params = AttrDict(kwargs.get("params", {})) -class AlignerConfig(PretrainedConfig): +class AlignerConfig(BaseSubConfig): model_type = "aligner" - cls: str = "" - params: AttrDict = {} - - def __init__(self, **kwargs): - super().__init__(**kwargs) - - self.cls = kwargs.get("cls", "") - if not isinstance(self.cls, str): - self.cls = self.cls.__name__ - - self.params = AttrDict(kwargs.get("params", {})) -class GenVisionConfig(PretrainedConfig): +class GenVisionConfig(BaseSubConfig): model_type = "gen_vision" - cls: str = "" - params: AttrDict = {} - - def __init__(self, **kwargs): - super().__init__(**kwargs) - - self.cls = kwargs.get("cls", "") - if not isinstance(self.cls, str): - self.cls = self.cls.__name__ - - self.params = AttrDict(kwargs.get("params", {})) -class GenAlignerConfig(PretrainedConfig): +class GenAlignerConfig(BaseSubConfig): model_type = "gen_aligner" - cls: str = "" - params: AttrDict = {} - - def __init__(self, **kwargs): - super().__init__(**kwargs) - - self.cls = kwargs.get("cls", "") - if not isinstance(self.cls, str): - self.cls = self.cls.__name__ - - self.params = AttrDict(kwargs.get("params", {})) -class GenHeadConfig(PretrainedConfig): +class GenHeadConfig(BaseSubConfig): model_type = "gen_head" - cls: str = "" - params: AttrDict = {} - - def __init__(self, **kwargs): - super().__init__(**kwargs) - - self.cls = kwargs.get("cls", "") - if not isinstance(self.cls, str): - self.cls = self.cls.__name__ - - self.params = AttrDict(kwargs.get("params", {})) class MultiModalityConfig(PretrainedConfig): model_type = "multi_modality" vision_config: VisionConfig aligner_config: AlignerConfig - gen_vision_config: GenVisionConfig gen_aligner_config: GenAlignerConfig gen_head_config: GenHeadConfig - language_config: LlamaConfig def __init__(self, **kwargs): super().__init__(**kwargs) - vision_config = kwargs.get("vision_config", {}) - self.vision_config = VisionConfig(**vision_config) - - aligner_config = kwargs.get("aligner_config", {}) - self.aligner_config = AlignerConfig(**aligner_config) - - gen_vision_config = kwargs.get("gen_vision_config", {}) - self.gen_vision_config = GenVisionConfig(**gen_vision_config) - - gen_aligner_config = kwargs.get("gen_aligner_config", {}) - self.gen_aligner_config = GenAlignerConfig(**gen_aligner_config) - - gen_head_config = kwargs.get("gen_head_config", {}) - self.gen_head_config = GenHeadConfig(**gen_head_config) - + self.vision_config = VisionConfig(**kwargs.get("vision_config", {})) + self.aligner_config = AlignerConfig(**kwargs.get("aligner_config", {})) + self.gen_vision_config = GenVisionConfig(**kwargs.get("gen_vision_config", {})) + self.gen_aligner_config = GenAlignerConfig( + **kwargs.get("gen_aligner_config", {}) + ) + self.gen_head_config = GenHeadConfig(**kwargs.get("gen_head_config", {})) language_config = kwargs.get("language_config", {}) if isinstance(language_config, LlamaConfig): self.language_config = language_config @@ -191,32 +138,31 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel): def __init__(self, config: MultiModalityConfig): super().__init__(config) - vision_config = config.vision_config - vision_cls = model_name_to_cls(vision_config.cls) - self.vision_model = vision_cls(**vision_config.params) - - aligner_config = config.aligner_config - aligner_cls = model_name_to_cls(aligner_config.cls) - self.aligner = aligner_cls(aligner_config.params) - - gen_vision_config = config.gen_vision_config - gen_vision_cls = model_name_to_cls(gen_vision_config.cls) - self.gen_vision_model = gen_vision_cls() - - gen_aligner_config = config.gen_aligner_config - gen_aligner_cls = model_name_to_cls(gen_aligner_config.cls) - self.gen_aligner = gen_aligner_cls(gen_aligner_config.params) - - gen_head_config = config.gen_head_config - gen_head_cls = model_name_to_cls(gen_head_config.cls) - self.gen_head = gen_head_cls(gen_head_config.params) - - self.gen_embed = torch.nn.Embedding( - gen_vision_config.params.image_token_size, gen_vision_config.params.n_embed + # Initialize vision components + self.vision_model = model_name_to_cls(config.vision_config.cls)( + **config.vision_config.params + ) + self.aligner = model_name_to_cls(config.aligner_config.cls)( + config.aligner_config.params ) - language_config = config.language_config - self.language_model = LlamaForCausalLM(language_config) + # Initialize generation components + self.gen_vision_model = model_name_to_cls(config.gen_vision_config.cls)( + **config.gen_vision_config.params + ) + self.gen_aligner = model_name_to_cls(config.gen_aligner_config.cls)( + config.gen_aligner_config.params + ) + self.gen_head = model_name_to_cls(config.gen_head_config.cls)( + config.gen_head_config.params + ) + + # Initialize embeddings and language model + self.gen_embed = torch.nn.Embedding( + config.gen_vision_config.params.image_token_size, + config.gen_vision_config.params.n_embed, + ) + self.language_model = LlamaForCausalLM(config.language_config) def prepare_inputs_embeds( self, @@ -225,44 +171,49 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel): images_seq_mask: torch.LongTensor, images_emb_mask: torch.LongTensor, **kwargs, - ): + ) -> torch.Tensor: """ + Prepares combined text and image embeddings for the language model. Args: - input_ids (torch.LongTensor): [b, T] - pixel_values (torch.FloatTensor): [b, n_images, 3, h, w] - images_seq_mask (torch.BoolTensor): [b, T] - images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens] - - assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask) + input_ids: Token IDs for text inputs, shape [batch_size, seq_len] + pixel_values: Image tensors, shape [batch_size, num_images, channels, height, width] + images_seq_mask: Boolean mask indicating image positions in the text sequence, + shape [batch_size, seq_len] + images_emb_mask: Boolean mask for valid image tokens per image, + shape [batch_size, num_images, tokens_per_image] Returns: - input_embeds (torch.Tensor): [b, T, D] + Combined embeddings tensor of shape [batch_size, seq_len, embedding_dim] """ - - bs, n = pixel_values.shape[0:2] + batch_size, num_images = pixel_values.shape[:2] images = rearrange(pixel_values, "b n c h w -> (b n) c h w") - # [b x n, T2, D] - images_embeds = self.aligner(self.vision_model(images)) + images_embeds = self.aligner(self.vision_model(images)) # [(b n), tokens, dim] - # [b x n, T2, D] -> [b, n x T2, D] - images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n) - # [b, n, T2] -> [b, n x T2] + # Reshape embeddings and masks + images_embeds = rearrange( + images_embeds, "(b n) t d -> b (n t) d", b=batch_size, n=num_images + ) images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)") - # [b, T, D] - input_ids[input_ids < 0] = 0 # ignore the image embeddings - inputs_embeds = self.language_model.get_input_embeddings()(input_ids) + # Validate mask compatibility + assert torch.all( + images_seq_mask.sum(dim=1) == images_emb_mask.sum(dim=1) + ), "Masks must have matching number of image tokens" - # replace with the image embeddings + # Get text embeddings and replace image positions + input_ids = input_ids.masked_fill(input_ids < 0, 0) # Replace negatives + inputs_embeds = self.language_model.get_input_embeddings()(input_ids) inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask] return inputs_embeds - def prepare_gen_img_embeds(self, image_ids: torch.LongTensor): + def prepare_gen_img_embeds(self, image_ids: torch.LongTensor) -> torch.Tensor: + """Generates image embeddings from token IDs for generation.""" return self.gen_aligner(self.gen_embed(image_ids)) +# Configuration registration AutoConfig.register("vision", VisionConfig) AutoConfig.register("aligner", AlignerConfig) AutoConfig.register("gen_vision", GenVisionConfig)