diff --git a/janus/models/modeling_vlm.py b/janus/models/modeling_vlm.py
index 8d23bc0..8f52195 100644
--- a/janus/models/modeling_vlm.py
+++ b/janus/models/modeling_vlm.py
@@ -33,8 +33,8 @@ from janus.models.clip_encoder import CLIPVisionTower
 from janus.models.projector import MlpProjector
 
 
-class vision_head(torch.nn.Module):
-    def __init__(self, params):
+class VisionHead(torch.nn.Module):
+    def __init__(self, params: AttrDict):
         super().__init__()
         self.output_mlp_projector = torch.nn.Linear(
             params.n_embed, params.image_token_embed
@@ -44,135 +44,82 @@ class vision_head(torch.nn.Module):
             params.image_token_embed, params.image_token_size
         )
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.output_mlp_projector(x)
         x = self.vision_activation(x)
         x = self.vision_head(x)
         return x
 
 
-def model_name_to_cls(cls_name):
+def model_name_to_cls(cls_name: str) -> type:
     if "MlpProjector" in cls_name:
         cls = MlpProjector
-
     elif "CLIPVisionTower" in cls_name:
         cls = CLIPVisionTower
-
     elif "VQ" in cls_name:
         from janus.models.vq_model import VQ_models
 
         cls = VQ_models[cls_name]
     elif "vision_head" in cls_name:
-        cls = vision_head
+        # Maintain backward compatibility with existing configs using "vision_head"
+        cls = VisionHead
     else:
-        raise ValueError(f"class_name {cls_name} is invalid.")
+        raise ValueError(f"Invalid class name: {cls_name}")
 
     return cls
 
 
-class VisionConfig(PretrainedConfig):
+class BaseSubConfig(PretrainedConfig):
+    model_type: str = "base"
+    cls: str = ""
+    params: AttrDict = AttrDict()
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.cls = kwargs.get("cls", "")
+        if not isinstance(self.cls, str):
+            self.cls = self.cls.__name__
+        self.params = AttrDict(kwargs.get("params", {}))
+
+
+class VisionConfig(BaseSubConfig):
     model_type = "vision"
-    cls: str = ""
-    params: AttrDict = {}
-
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-
-        self.cls = kwargs.get("cls", "")
-        if not isinstance(self.cls, str):
-            self.cls = self.cls.__name__
-
-        self.params = AttrDict(kwargs.get("params", {}))
 
 
-class AlignerConfig(PretrainedConfig):
+class AlignerConfig(BaseSubConfig):
     model_type = "aligner"
-    cls: str = ""
-    params: AttrDict = {}
-
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-
-        self.cls = kwargs.get("cls", "")
-        if not isinstance(self.cls, str):
-            self.cls = self.cls.__name__
-
-        self.params = AttrDict(kwargs.get("params", {}))
 
 
-class GenVisionConfig(PretrainedConfig):
+class GenVisionConfig(BaseSubConfig):
     model_type = "gen_vision"
-    cls: str = ""
-    params: AttrDict = {}
-
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-
-        self.cls = kwargs.get("cls", "")
-        if not isinstance(self.cls, str):
-            self.cls = self.cls.__name__
-
-        self.params = AttrDict(kwargs.get("params", {}))
 
 
-class GenAlignerConfig(PretrainedConfig):
+class GenAlignerConfig(BaseSubConfig):
     model_type = "gen_aligner"
-    cls: str = ""
-    params: AttrDict = {}
-
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-
-        self.cls = kwargs.get("cls", "")
-        if not isinstance(self.cls, str):
-            self.cls = self.cls.__name__
-
-        self.params = AttrDict(kwargs.get("params", {}))
 
 
-class GenHeadConfig(PretrainedConfig):
+class GenHeadConfig(BaseSubConfig):
     model_type = "gen_head"
-    cls: str = ""
-    params: AttrDict = {}
-
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-
-        self.cls = kwargs.get("cls", "")
-        if not isinstance(self.cls, str):
-            self.cls = self.cls.__name__
-
-        self.params = AttrDict(kwargs.get("params", {}))
 
 
 class MultiModalityConfig(PretrainedConfig):
     model_type = "multi_modality"
     vision_config: VisionConfig
     aligner_config: AlignerConfig
-
     gen_vision_config: GenVisionConfig
     gen_aligner_config: GenAlignerConfig
     gen_head_config: GenHeadConfig
-
     language_config: LlamaConfig
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
-        vision_config = kwargs.get("vision_config", {})
-        self.vision_config = VisionConfig(**vision_config)
-
-        aligner_config = kwargs.get("aligner_config", {})
-        self.aligner_config = AlignerConfig(**aligner_config)
-
-        gen_vision_config = kwargs.get("gen_vision_config", {})
-        self.gen_vision_config = GenVisionConfig(**gen_vision_config)
-
-        gen_aligner_config = kwargs.get("gen_aligner_config", {})
-        self.gen_aligner_config = GenAlignerConfig(**gen_aligner_config)
-
-        gen_head_config = kwargs.get("gen_head_config", {})
-        self.gen_head_config = GenHeadConfig(**gen_head_config)
-
+        self.vision_config = VisionConfig(**kwargs.get("vision_config", {}))
+        self.aligner_config = AlignerConfig(**kwargs.get("aligner_config", {}))
+        self.gen_vision_config = GenVisionConfig(**kwargs.get("gen_vision_config", {}))
+        self.gen_aligner_config = GenAlignerConfig(
+            **kwargs.get("gen_aligner_config", {})
+        )
+        self.gen_head_config = GenHeadConfig(**kwargs.get("gen_head_config", {}))
         language_config = kwargs.get("language_config", {})
         if isinstance(language_config, LlamaConfig):
             self.language_config = language_config
@@ -191,32 +138,31 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
     def __init__(self, config: MultiModalityConfig):
         super().__init__(config)
 
-        vision_config = config.vision_config
-        vision_cls = model_name_to_cls(vision_config.cls)
-        self.vision_model = vision_cls(**vision_config.params)
-
-        aligner_config = config.aligner_config
-        aligner_cls = model_name_to_cls(aligner_config.cls)
-        self.aligner = aligner_cls(aligner_config.params)
-
-        gen_vision_config = config.gen_vision_config
-        gen_vision_cls = model_name_to_cls(gen_vision_config.cls)
-        self.gen_vision_model = gen_vision_cls()
-
-        gen_aligner_config = config.gen_aligner_config
-        gen_aligner_cls = model_name_to_cls(gen_aligner_config.cls)
-        self.gen_aligner = gen_aligner_cls(gen_aligner_config.params)
-
-        gen_head_config = config.gen_head_config
-        gen_head_cls = model_name_to_cls(gen_head_config.cls)
-        self.gen_head = gen_head_cls(gen_head_config.params)
-
-        self.gen_embed = torch.nn.Embedding(
-            gen_vision_config.params.image_token_size, gen_vision_config.params.n_embed
+        # Initialize vision components
+        self.vision_model = model_name_to_cls(config.vision_config.cls)(
+            **config.vision_config.params
+        )
+        self.aligner = model_name_to_cls(config.aligner_config.cls)(
+            config.aligner_config.params
         )
 
-        language_config = config.language_config
-        self.language_model = LlamaForCausalLM(language_config)
+        # Initialize generation components
+        self.gen_vision_model = model_name_to_cls(config.gen_vision_config.cls)(
+            **config.gen_vision_config.params
+        )
+        self.gen_aligner = model_name_to_cls(config.gen_aligner_config.cls)(
+            config.gen_aligner_config.params
+        )
+        self.gen_head = model_name_to_cls(config.gen_head_config.cls)(
+            config.gen_head_config.params
+        )
+
+        # Initialize embeddings and language model
+        self.gen_embed = torch.nn.Embedding(
+            config.gen_vision_config.params.image_token_size,
+            config.gen_vision_config.params.n_embed,
+        )
+        self.language_model = LlamaForCausalLM(config.language_config)
 
     def prepare_inputs_embeds(
         self,
@@ -225,44 +171,49 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
         images_seq_mask: torch.LongTensor,
         images_emb_mask: torch.LongTensor,
         **kwargs,
-    ):
+    ) -> torch.Tensor:
         """
+        Prepares combined text and image embeddings for the language model.
 
         Args:
-            input_ids (torch.LongTensor): [b, T]
-            pixel_values (torch.FloatTensor):   [b, n_images, 3, h, w]
-            images_seq_mask (torch.BoolTensor): [b, T]
-            images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens]
-
-            assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask)
+            input_ids: Token IDs for text inputs, shape [batch_size, seq_len]
+            pixel_values: Image tensors, shape [batch_size, num_images, channels, height, width]
+            images_seq_mask: Boolean mask indicating image positions in the text sequence,
+                shape [batch_size, seq_len]
+            images_emb_mask: Boolean mask for valid image tokens per image,
+                shape [batch_size, num_images, tokens_per_image]
 
         Returns:
-            input_embeds (torch.Tensor): [b, T, D]
+            Combined embeddings tensor of shape [batch_size, seq_len, embedding_dim]
         """
-
-        bs, n = pixel_values.shape[0:2]
+        batch_size, num_images = pixel_values.shape[:2]
         images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
-        # [b x n, T2, D]
-        images_embeds = self.aligner(self.vision_model(images))
+        images_embeds = self.aligner(self.vision_model(images))  # [(b n), tokens, dim]
 
-        # [b x n, T2, D] -> [b, n x T2, D]
-        images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
-        # [b, n, T2] -> [b, n x T2]
+        # Reshape embeddings and masks
+        images_embeds = rearrange(
+            images_embeds, "(b n) t d -> b (n t) d", b=batch_size, n=num_images
+        )
         images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")
 
-        # [b, T, D]
-        input_ids[input_ids < 0] = 0  # ignore the image embeddings
-        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
+        # Validate mask compatibility
+        assert torch.all(
+            images_seq_mask.sum(dim=1) == images_emb_mask.sum(dim=1)
+        ), "Masks must have matching number of image tokens"
 
-        # replace with the image embeddings
+        # Get text embeddings and replace image positions
+        input_ids = input_ids.masked_fill(input_ids < 0, 0)  # Replace negatives
+        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
         inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]
 
         return inputs_embeds
 
-    def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
+    def prepare_gen_img_embeds(self, image_ids: torch.LongTensor) -> torch.Tensor:
+        """Generates image embeddings from token IDs for generation."""
         return self.gen_aligner(self.gen_embed(image_ids))
 
 
+# Configuration registration
 AutoConfig.register("vision", VisionConfig)
 AutoConfig.register("aligner", AlignerConfig)
 AutoConfig.register("gen_vision", GenVisionConfig)