diff --git a/README.md b/README.md
index 1a5817e..3c06748 100755
--- a/README.md
+++ b/README.md
@@ -17,7 +17,6 @@
-
@@ -66,7 +65,7 @@
**2024.11.13**: JanusFlow is released, a new unified model with rectified flow for image generation. See [paper](https://arxiv.org/abs/2411.07975), [demo](https://huggingface.co/spaces/deepseek-ai/JanusFlow-1.3B) and [usage](https://github.com/deepseek-ai/Janus?tab=readme-ov-file#janusflow).
-**2024.10.23**: Evaluation code for reproducing the multimodal understanding results from the paper has been added to VLMEvalKit. Please refer to [this link]( https://github.com/open-compass/VLMEvalKit/pull/541).
+**2024.10.23**: Evaluation code for reproducing the multimodal understanding results from the paper has been added to VLMEvalKit. Please refer to [this link](https://github.com/open-compass/VLMEvalKit/pull/541).
**2024.10.20**: (1) Fix a bug in [tokenizer_config.json](https://huggingface.co/deepseek-ai/Janus-1.3B/blob/main/tokenizer_config.json). The previous version caused classifier-free guidance to not function properly, resulting in relatively poor visual generation quality. (2) Release Gradio demo ([online demo](https://huggingface.co/spaces/deepseek-ai/Janus-1.3B) and [local](#gradio-demo)).
@@ -165,10 +164,10 @@ prepare_inputs = vl_chat_processor(
conversations=conversation, images=pil_images, force_batchify=True
).to(vl_gpt.device)
-# # run image encoder to get the image embeddings
+# run image encoder to get the image embeddings
inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
-# # run the model to get the response
+# run the model to get the response
outputs = vl_gpt.language_model.generate(
inputs_embeds=inputs_embeds,
attention_mask=prepare_inputs.attention_mask,
diff --git a/janus/models/modeling_vlm.py b/janus/models/modeling_vlm.py
index 8d23bc0..146dae6 100644
--- a/janus/models/modeling_vlm.py
+++ b/janus/models/modeling_vlm.py
@@ -26,23 +26,21 @@ from transformers import (
LlamaConfig,
LlamaForCausalLM,
PreTrainedModel,
+ PretrainedConfig,
)
-from transformers.configuration_utils import PretrainedConfig
-
from janus.models.clip_encoder import CLIPVisionTower
from janus.models.projector import MlpProjector
-class vision_head(torch.nn.Module):
+class VisionHead(torch.nn.Module):
+ """
+ Vision head module for processing visual embeddings.
+ """
def __init__(self, params):
super().__init__()
- self.output_mlp_projector = torch.nn.Linear(
- params.n_embed, params.image_token_embed
- )
+ self.output_mlp_projector = torch.nn.Linear(params.n_embed, params.image_token_embed)
self.vision_activation = torch.nn.GELU()
- self.vision_head = torch.nn.Linear(
- params.image_token_embed, params.image_token_size
- )
+ self.vision_head = torch.nn.Linear(params.image_token_embed, params.image_token_size)
def forward(self, x):
x = self.output_mlp_projector(x)
@@ -52,135 +50,91 @@ class vision_head(torch.nn.Module):
def model_name_to_cls(cls_name):
- if "MlpProjector" in cls_name:
- cls = MlpProjector
+ """
+ Maps a class name to its corresponding class.
+ """
+ mapping = {
+ "MlpProjector": MlpProjector,
+ "CLIPVisionTower": CLIPVisionTower,
+ "vision_head": VisionHead,
+ }
- elif "CLIPVisionTower" in cls_name:
- cls = CLIPVisionTower
-
- elif "VQ" in cls_name:
+ if "VQ" in cls_name:
from janus.models.vq_model import VQ_models
+ return VQ_models[cls_name]
- cls = VQ_models[cls_name]
- elif "vision_head" in cls_name:
- cls = vision_head
- else:
- raise ValueError(f"class_name {cls_name} is invalid.")
-
+ cls = mapping.get(cls_name)
+ if cls is None:
+ raise ValueError(f"Invalid class name: {cls_name}")
return cls
-class VisionConfig(PretrainedConfig):
+class BaseConfig(PretrainedConfig):
+ """
+ Base configuration class for multi-modality components.
+ """
+ model_type = ""
+ cls: str = ""
+ params: AttrDict = {}
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.cls = kwargs.get("cls", "")
+ if not isinstance(self.cls, str):
+ self.cls = self.cls.__name__
+ self.params = AttrDict(kwargs.get("params", {}))
+
+
+class VisionConfig(BaseConfig):
model_type = "vision"
- cls: str = ""
- params: AttrDict = {}
-
- def __init__(self, **kwargs):
- super().__init__(**kwargs)
-
- self.cls = kwargs.get("cls", "")
- if not isinstance(self.cls, str):
- self.cls = self.cls.__name__
-
- self.params = AttrDict(kwargs.get("params", {}))
-class AlignerConfig(PretrainedConfig):
+class AlignerConfig(BaseConfig):
model_type = "aligner"
- cls: str = ""
- params: AttrDict = {}
-
- def __init__(self, **kwargs):
- super().__init__(**kwargs)
-
- self.cls = kwargs.get("cls", "")
- if not isinstance(self.cls, str):
- self.cls = self.cls.__name__
-
- self.params = AttrDict(kwargs.get("params", {}))
-class GenVisionConfig(PretrainedConfig):
+class GenVisionConfig(BaseConfig):
model_type = "gen_vision"
- cls: str = ""
- params: AttrDict = {}
-
- def __init__(self, **kwargs):
- super().__init__(**kwargs)
-
- self.cls = kwargs.get("cls", "")
- if not isinstance(self.cls, str):
- self.cls = self.cls.__name__
-
- self.params = AttrDict(kwargs.get("params", {}))
-class GenAlignerConfig(PretrainedConfig):
+class GenAlignerConfig(BaseConfig):
model_type = "gen_aligner"
- cls: str = ""
- params: AttrDict = {}
-
- def __init__(self, **kwargs):
- super().__init__(**kwargs)
-
- self.cls = kwargs.get("cls", "")
- if not isinstance(self.cls, str):
- self.cls = self.cls.__name__
-
- self.params = AttrDict(kwargs.get("params", {}))
-class GenHeadConfig(PretrainedConfig):
+class GenHeadConfig(BaseConfig):
model_type = "gen_head"
- cls: str = ""
- params: AttrDict = {}
-
- def __init__(self, **kwargs):
- super().__init__(**kwargs)
-
- self.cls = kwargs.get("cls", "")
- if not isinstance(self.cls, str):
- self.cls = self.cls.__name__
-
- self.params = AttrDict(kwargs.get("params", {}))
class MultiModalityConfig(PretrainedConfig):
+ """
+ Configuration for the multi-modality model.
+ """
model_type = "multi_modality"
vision_config: VisionConfig
aligner_config: AlignerConfig
-
gen_vision_config: GenVisionConfig
gen_aligner_config: GenAlignerConfig
gen_head_config: GenHeadConfig
-
language_config: LlamaConfig
def __init__(self, **kwargs):
super().__init__(**kwargs)
- vision_config = kwargs.get("vision_config", {})
- self.vision_config = VisionConfig(**vision_config)
-
- aligner_config = kwargs.get("aligner_config", {})
- self.aligner_config = AlignerConfig(**aligner_config)
-
- gen_vision_config = kwargs.get("gen_vision_config", {})
- self.gen_vision_config = GenVisionConfig(**gen_vision_config)
-
- gen_aligner_config = kwargs.get("gen_aligner_config", {})
- self.gen_aligner_config = GenAlignerConfig(**gen_aligner_config)
-
- gen_head_config = kwargs.get("gen_head_config", {})
- self.gen_head_config = GenHeadConfig(**gen_head_config)
+ self.vision_config = VisionConfig(**kwargs.get("vision_config", {}))
+ self.aligner_config = AlignerConfig(**kwargs.get("aligner_config", {}))
+ self.gen_vision_config = GenVisionConfig(**kwargs.get("gen_vision_config", {}))
+ self.gen_aligner_config = GenAlignerConfig(**kwargs.get("gen_aligner_config", {}))
+ self.gen_head_config = GenHeadConfig(**kwargs.get("gen_head_config", {}))
language_config = kwargs.get("language_config", {})
- if isinstance(language_config, LlamaConfig):
- self.language_config = language_config
- else:
- self.language_config = LlamaConfig(**language_config)
+ self.language_config = (
+ language_config if isinstance(language_config, LlamaConfig) else LlamaConfig(**language_config)
+ )
class MultiModalityPreTrainedModel(PreTrainedModel):
+ """
+ Base class for multi-modality pre-trained models.
+ """
config_class = MultiModalityConfig
base_model_prefix = "multi_modality"
_no_split_modules = []
@@ -188,35 +142,40 @@ class MultiModalityPreTrainedModel(PreTrainedModel):
class MultiModalityCausalLM(MultiModalityPreTrainedModel):
+ """
+ Multi-modality causal language model combining vision and language components.
+ """
def __init__(self, config: MultiModalityConfig):
super().__init__(config)
- vision_config = config.vision_config
- vision_cls = model_name_to_cls(vision_config.cls)
- self.vision_model = vision_cls(**vision_config.params)
+ # Initialize vision model
+ vision_cls = model_name_to_cls(config.vision_config.cls)
+ self.vision_model = vision_cls(**config.vision_config.params)
- aligner_config = config.aligner_config
- aligner_cls = model_name_to_cls(aligner_config.cls)
- self.aligner = aligner_cls(aligner_config.params)
+ # Initialize aligner
+ aligner_cls = model_name_to_cls(config.aligner_config.cls)
+ self.aligner = aligner_cls(config.aligner_config.params)
- gen_vision_config = config.gen_vision_config
- gen_vision_cls = model_name_to_cls(gen_vision_config.cls)
+ # Initialize generative vision model
+ gen_vision_cls = model_name_to_cls(config.gen_vision_config.cls)
self.gen_vision_model = gen_vision_cls()
- gen_aligner_config = config.gen_aligner_config
- gen_aligner_cls = model_name_to_cls(gen_aligner_config.cls)
- self.gen_aligner = gen_aligner_cls(gen_aligner_config.params)
+ # Initialize generative aligner
+ gen_aligner_cls = model_name_to_cls(config.gen_aligner_config.cls)
+ self.gen_aligner = gen_aligner_cls(config.gen_aligner_config.params)
- gen_head_config = config.gen_head_config
- gen_head_cls = model_name_to_cls(gen_head_config.cls)
- self.gen_head = gen_head_cls(gen_head_config.params)
+ # Initialize generative head
+ gen_head_cls = model_name_to_cls(config.gen_head_config.cls)
+ self.gen_head = gen_head_cls(config.gen_head_config.params)
+ # Generative embedding layer
self.gen_embed = torch.nn.Embedding(
- gen_vision_config.params.image_token_size, gen_vision_config.params.n_embed
+ config.gen_vision_config.params.image_token_size,
+ config.gen_vision_config.params.n_embed,
)
- language_config = config.language_config
- self.language_model = LlamaForCausalLM(language_config)
+ # Language model
+ self.language_model = LlamaForCausalLM(config.language_config)
def prepare_inputs_embeds(
self,
@@ -227,46 +186,53 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
**kwargs,
):
"""
+ Prepares input embeddings by combining text and image embeddings.
Args:
input_ids (torch.LongTensor): [b, T]
- pixel_values (torch.FloatTensor): [b, n_images, 3, h, w]
+ pixel_values (torch.FloatTensor): [b, n_images, 3, h, w]
images_seq_mask (torch.BoolTensor): [b, T]
images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens]
- assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask)
-
Returns:
input_embeds (torch.Tensor): [b, T, D]
"""
-
- bs, n = pixel_values.shape[0:2]
+ bs, n = pixel_values.shape[:2]
images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
- # [b x n, T2, D]
- images_embeds = self.aligner(self.vision_model(images))
- # [b x n, T2, D] -> [b, n x T2, D]
- images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
- # [b, n, T2] -> [b, n x T2]
- images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")
+ # Process images through vision model and aligner
+ images_embeds = self.aligner(self.vision_model(images)) # [b x n, T2, D]
+ images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n) # [b, n x T2, D]
+ images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)") # [b, n x T2]
- # [b, T, D]
- input_ids[input_ids < 0] = 0 # ignore the image embeddings
- inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
+ # Prepare text embeddings
+ input_ids[input_ids < 0] = 0 # Ignore negative IDs
+ inputs_embeds = self.language_model.get_input_embeddings()(input_ids) # [b, T, D]
- # replace with the image embeddings
+ # Replace text embeddings with image embeddings where applicable
inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]
-
return inputs_embeds
def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
+ """
+ Prepares generative image embeddings.
+
+ Args:
+ image_ids (torch.LongTensor): Image token IDs.
+
+ Returns:
+ torch.Tensor: Generated image embeddings.
+ """
return self.gen_aligner(self.gen_embed(image_ids))
+# Register configurations with Hugging Face's AutoConfig
AutoConfig.register("vision", VisionConfig)
AutoConfig.register("aligner", AlignerConfig)
AutoConfig.register("gen_vision", GenVisionConfig)
AutoConfig.register("gen_aligner", GenAlignerConfig)
AutoConfig.register("gen_head", GenHeadConfig)
AutoConfig.register("multi_modality", MultiModalityConfig)
+
+# Register the multi-modality causal LM model
AutoModelForCausalLM.register(MultiModalityConfig, MultiModalityCausalLM)