diff --git a/janus/models/modeling_vlm.py b/janus/models/modeling_vlm.py index 8d23bc0..146dae6 100644 --- a/janus/models/modeling_vlm.py +++ b/janus/models/modeling_vlm.py @@ -26,23 +26,21 @@ from transformers import ( LlamaConfig, LlamaForCausalLM, PreTrainedModel, + PretrainedConfig, ) -from transformers.configuration_utils import PretrainedConfig - from janus.models.clip_encoder import CLIPVisionTower from janus.models.projector import MlpProjector -class vision_head(torch.nn.Module): +class VisionHead(torch.nn.Module): + """ + Vision head module for processing visual embeddings. + """ def __init__(self, params): super().__init__() - self.output_mlp_projector = torch.nn.Linear( - params.n_embed, params.image_token_embed - ) + self.output_mlp_projector = torch.nn.Linear(params.n_embed, params.image_token_embed) self.vision_activation = torch.nn.GELU() - self.vision_head = torch.nn.Linear( - params.image_token_embed, params.image_token_size - ) + self.vision_head = torch.nn.Linear(params.image_token_embed, params.image_token_size) def forward(self, x): x = self.output_mlp_projector(x) @@ -52,135 +50,91 @@ class vision_head(torch.nn.Module): def model_name_to_cls(cls_name): - if "MlpProjector" in cls_name: - cls = MlpProjector + """ + Maps a class name to its corresponding class. + """ + mapping = { + "MlpProjector": MlpProjector, + "CLIPVisionTower": CLIPVisionTower, + "vision_head": VisionHead, + } - elif "CLIPVisionTower" in cls_name: - cls = CLIPVisionTower - - elif "VQ" in cls_name: + if "VQ" in cls_name: from janus.models.vq_model import VQ_models + return VQ_models[cls_name] - cls = VQ_models[cls_name] - elif "vision_head" in cls_name: - cls = vision_head - else: - raise ValueError(f"class_name {cls_name} is invalid.") - + cls = mapping.get(cls_name) + if cls is None: + raise ValueError(f"Invalid class name: {cls_name}") return cls -class VisionConfig(PretrainedConfig): +class BaseConfig(PretrainedConfig): + """ + Base configuration class for multi-modality components. + """ + model_type = "" + cls: str = "" + params: AttrDict = {} + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.cls = kwargs.get("cls", "") + if not isinstance(self.cls, str): + self.cls = self.cls.__name__ + self.params = AttrDict(kwargs.get("params", {})) + + +class VisionConfig(BaseConfig): model_type = "vision" - cls: str = "" - params: AttrDict = {} - - def __init__(self, **kwargs): - super().__init__(**kwargs) - - self.cls = kwargs.get("cls", "") - if not isinstance(self.cls, str): - self.cls = self.cls.__name__ - - self.params = AttrDict(kwargs.get("params", {})) -class AlignerConfig(PretrainedConfig): +class AlignerConfig(BaseConfig): model_type = "aligner" - cls: str = "" - params: AttrDict = {} - - def __init__(self, **kwargs): - super().__init__(**kwargs) - - self.cls = kwargs.get("cls", "") - if not isinstance(self.cls, str): - self.cls = self.cls.__name__ - - self.params = AttrDict(kwargs.get("params", {})) -class GenVisionConfig(PretrainedConfig): +class GenVisionConfig(BaseConfig): model_type = "gen_vision" - cls: str = "" - params: AttrDict = {} - - def __init__(self, **kwargs): - super().__init__(**kwargs) - - self.cls = kwargs.get("cls", "") - if not isinstance(self.cls, str): - self.cls = self.cls.__name__ - - self.params = AttrDict(kwargs.get("params", {})) -class GenAlignerConfig(PretrainedConfig): +class GenAlignerConfig(BaseConfig): model_type = "gen_aligner" - cls: str = "" - params: AttrDict = {} - - def __init__(self, **kwargs): - super().__init__(**kwargs) - - self.cls = kwargs.get("cls", "") - if not isinstance(self.cls, str): - self.cls = self.cls.__name__ - - self.params = AttrDict(kwargs.get("params", {})) -class GenHeadConfig(PretrainedConfig): +class GenHeadConfig(BaseConfig): model_type = "gen_head" - cls: str = "" - params: AttrDict = {} - - def __init__(self, **kwargs): - super().__init__(**kwargs) - - self.cls = kwargs.get("cls", "") - if not isinstance(self.cls, str): - self.cls = self.cls.__name__ - - self.params = AttrDict(kwargs.get("params", {})) class MultiModalityConfig(PretrainedConfig): + """ + Configuration for the multi-modality model. + """ model_type = "multi_modality" vision_config: VisionConfig aligner_config: AlignerConfig - gen_vision_config: GenVisionConfig gen_aligner_config: GenAlignerConfig gen_head_config: GenHeadConfig - language_config: LlamaConfig def __init__(self, **kwargs): super().__init__(**kwargs) - vision_config = kwargs.get("vision_config", {}) - self.vision_config = VisionConfig(**vision_config) - - aligner_config = kwargs.get("aligner_config", {}) - self.aligner_config = AlignerConfig(**aligner_config) - - gen_vision_config = kwargs.get("gen_vision_config", {}) - self.gen_vision_config = GenVisionConfig(**gen_vision_config) - - gen_aligner_config = kwargs.get("gen_aligner_config", {}) - self.gen_aligner_config = GenAlignerConfig(**gen_aligner_config) - - gen_head_config = kwargs.get("gen_head_config", {}) - self.gen_head_config = GenHeadConfig(**gen_head_config) + self.vision_config = VisionConfig(**kwargs.get("vision_config", {})) + self.aligner_config = AlignerConfig(**kwargs.get("aligner_config", {})) + self.gen_vision_config = GenVisionConfig(**kwargs.get("gen_vision_config", {})) + self.gen_aligner_config = GenAlignerConfig(**kwargs.get("gen_aligner_config", {})) + self.gen_head_config = GenHeadConfig(**kwargs.get("gen_head_config", {})) language_config = kwargs.get("language_config", {}) - if isinstance(language_config, LlamaConfig): - self.language_config = language_config - else: - self.language_config = LlamaConfig(**language_config) + self.language_config = ( + language_config if isinstance(language_config, LlamaConfig) else LlamaConfig(**language_config) + ) class MultiModalityPreTrainedModel(PreTrainedModel): + """ + Base class for multi-modality pre-trained models. + """ config_class = MultiModalityConfig base_model_prefix = "multi_modality" _no_split_modules = [] @@ -188,35 +142,40 @@ class MultiModalityPreTrainedModel(PreTrainedModel): class MultiModalityCausalLM(MultiModalityPreTrainedModel): + """ + Multi-modality causal language model combining vision and language components. + """ def __init__(self, config: MultiModalityConfig): super().__init__(config) - vision_config = config.vision_config - vision_cls = model_name_to_cls(vision_config.cls) - self.vision_model = vision_cls(**vision_config.params) + # Initialize vision model + vision_cls = model_name_to_cls(config.vision_config.cls) + self.vision_model = vision_cls(**config.vision_config.params) - aligner_config = config.aligner_config - aligner_cls = model_name_to_cls(aligner_config.cls) - self.aligner = aligner_cls(aligner_config.params) + # Initialize aligner + aligner_cls = model_name_to_cls(config.aligner_config.cls) + self.aligner = aligner_cls(config.aligner_config.params) - gen_vision_config = config.gen_vision_config - gen_vision_cls = model_name_to_cls(gen_vision_config.cls) + # Initialize generative vision model + gen_vision_cls = model_name_to_cls(config.gen_vision_config.cls) self.gen_vision_model = gen_vision_cls() - gen_aligner_config = config.gen_aligner_config - gen_aligner_cls = model_name_to_cls(gen_aligner_config.cls) - self.gen_aligner = gen_aligner_cls(gen_aligner_config.params) + # Initialize generative aligner + gen_aligner_cls = model_name_to_cls(config.gen_aligner_config.cls) + self.gen_aligner = gen_aligner_cls(config.gen_aligner_config.params) - gen_head_config = config.gen_head_config - gen_head_cls = model_name_to_cls(gen_head_config.cls) - self.gen_head = gen_head_cls(gen_head_config.params) + # Initialize generative head + gen_head_cls = model_name_to_cls(config.gen_head_config.cls) + self.gen_head = gen_head_cls(config.gen_head_config.params) + # Generative embedding layer self.gen_embed = torch.nn.Embedding( - gen_vision_config.params.image_token_size, gen_vision_config.params.n_embed + config.gen_vision_config.params.image_token_size, + config.gen_vision_config.params.n_embed, ) - language_config = config.language_config - self.language_model = LlamaForCausalLM(language_config) + # Language model + self.language_model = LlamaForCausalLM(config.language_config) def prepare_inputs_embeds( self, @@ -227,46 +186,53 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel): **kwargs, ): """ + Prepares input embeddings by combining text and image embeddings. Args: input_ids (torch.LongTensor): [b, T] - pixel_values (torch.FloatTensor): [b, n_images, 3, h, w] + pixel_values (torch.FloatTensor): [b, n_images, 3, h, w] images_seq_mask (torch.BoolTensor): [b, T] images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens] - assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask) - Returns: input_embeds (torch.Tensor): [b, T, D] """ - - bs, n = pixel_values.shape[0:2] + bs, n = pixel_values.shape[:2] images = rearrange(pixel_values, "b n c h w -> (b n) c h w") - # [b x n, T2, D] - images_embeds = self.aligner(self.vision_model(images)) - # [b x n, T2, D] -> [b, n x T2, D] - images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n) - # [b, n, T2] -> [b, n x T2] - images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)") + # Process images through vision model and aligner + images_embeds = self.aligner(self.vision_model(images)) # [b x n, T2, D] + images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n) # [b, n x T2, D] + images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)") # [b, n x T2] - # [b, T, D] - input_ids[input_ids < 0] = 0 # ignore the image embeddings - inputs_embeds = self.language_model.get_input_embeddings()(input_ids) + # Prepare text embeddings + input_ids[input_ids < 0] = 0 # Ignore negative IDs + inputs_embeds = self.language_model.get_input_embeddings()(input_ids) # [b, T, D] - # replace with the image embeddings + # Replace text embeddings with image embeddings where applicable inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask] - return inputs_embeds def prepare_gen_img_embeds(self, image_ids: torch.LongTensor): + """ + Prepares generative image embeddings. + + Args: + image_ids (torch.LongTensor): Image token IDs. + + Returns: + torch.Tensor: Generated image embeddings. + """ return self.gen_aligner(self.gen_embed(image_ids)) +# Register configurations with Hugging Face's AutoConfig AutoConfig.register("vision", VisionConfig) AutoConfig.register("aligner", AlignerConfig) AutoConfig.register("gen_vision", GenVisionConfig) AutoConfig.register("gen_aligner", GenAlignerConfig) AutoConfig.register("gen_head", GenHeadConfig) AutoConfig.register("multi_modality", MultiModalityConfig) + +# Register the multi-modality causal LM model AutoModelForCausalLM.register(MultiModalityConfig, MultiModalityCausalLM)