mirror of
https://github.com/deepseek-ai/Janus.git
synced 2025-04-19 01:59:02 -04:00
Merge 3e49061a00
into 1daa72fa40
This commit is contained in:
commit
31bba0b62e
@ -17,7 +17,6 @@
|
||||
<a href="https://www.deepseek.com/" target="_blank">
|
||||
<img alt="Homepage" src="images/badge.svg" />
|
||||
</a>
|
||||
</a>
|
||||
<a href="https://huggingface.co/deepseek-ai" target="_blank">
|
||||
<img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-DeepSeek%20AI-ffc107?color=ffc107&logoColor=white" />
|
||||
</a>
|
||||
@ -66,7 +65,7 @@
|
||||
|
||||
**2024.11.13**: JanusFlow is released, a new unified model with rectified flow for image generation. See [paper](https://arxiv.org/abs/2411.07975), [demo](https://huggingface.co/spaces/deepseek-ai/JanusFlow-1.3B) and [usage](https://github.com/deepseek-ai/Janus?tab=readme-ov-file#janusflow).
|
||||
|
||||
**2024.10.23**: Evaluation code for reproducing the multimodal understanding results from the paper has been added to VLMEvalKit. Please refer to [this link]( https://github.com/open-compass/VLMEvalKit/pull/541).
|
||||
**2024.10.23**: Evaluation code for reproducing the multimodal understanding results from the paper has been added to VLMEvalKit. Please refer to [this link](https://github.com/open-compass/VLMEvalKit/pull/541).
|
||||
|
||||
**2024.10.20**: (1) Fix a bug in [tokenizer_config.json](https://huggingface.co/deepseek-ai/Janus-1.3B/blob/main/tokenizer_config.json). The previous version caused classifier-free guidance to not function properly, resulting in relatively poor visual generation quality. (2) Release Gradio demo ([online demo](https://huggingface.co/spaces/deepseek-ai/Janus-1.3B) and [local](#gradio-demo)).
|
||||
|
||||
@ -165,10 +164,10 @@ prepare_inputs = vl_chat_processor(
|
||||
conversations=conversation, images=pil_images, force_batchify=True
|
||||
).to(vl_gpt.device)
|
||||
|
||||
# # run image encoder to get the image embeddings
|
||||
# run image encoder to get the image embeddings
|
||||
inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
|
||||
|
||||
# # run the model to get the response
|
||||
# run the model to get the response
|
||||
outputs = vl_gpt.language_model.generate(
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=prepare_inputs.attention_mask,
|
||||
|
@ -26,23 +26,21 @@ from transformers import (
|
||||
LlamaConfig,
|
||||
LlamaForCausalLM,
|
||||
PreTrainedModel,
|
||||
PretrainedConfig,
|
||||
)
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
from janus.models.clip_encoder import CLIPVisionTower
|
||||
from janus.models.projector import MlpProjector
|
||||
|
||||
|
||||
class vision_head(torch.nn.Module):
|
||||
class VisionHead(torch.nn.Module):
|
||||
"""
|
||||
Vision head module for processing visual embeddings.
|
||||
"""
|
||||
def __init__(self, params):
|
||||
super().__init__()
|
||||
self.output_mlp_projector = torch.nn.Linear(
|
||||
params.n_embed, params.image_token_embed
|
||||
)
|
||||
self.output_mlp_projector = torch.nn.Linear(params.n_embed, params.image_token_embed)
|
||||
self.vision_activation = torch.nn.GELU()
|
||||
self.vision_head = torch.nn.Linear(
|
||||
params.image_token_embed, params.image_token_size
|
||||
)
|
||||
self.vision_head = torch.nn.Linear(params.image_token_embed, params.image_token_size)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.output_mlp_projector(x)
|
||||
@ -52,135 +50,91 @@ class vision_head(torch.nn.Module):
|
||||
|
||||
|
||||
def model_name_to_cls(cls_name):
|
||||
if "MlpProjector" in cls_name:
|
||||
cls = MlpProjector
|
||||
"""
|
||||
Maps a class name to its corresponding class.
|
||||
"""
|
||||
mapping = {
|
||||
"MlpProjector": MlpProjector,
|
||||
"CLIPVisionTower": CLIPVisionTower,
|
||||
"vision_head": VisionHead,
|
||||
}
|
||||
|
||||
elif "CLIPVisionTower" in cls_name:
|
||||
cls = CLIPVisionTower
|
||||
|
||||
elif "VQ" in cls_name:
|
||||
if "VQ" in cls_name:
|
||||
from janus.models.vq_model import VQ_models
|
||||
return VQ_models[cls_name]
|
||||
|
||||
cls = VQ_models[cls_name]
|
||||
elif "vision_head" in cls_name:
|
||||
cls = vision_head
|
||||
else:
|
||||
raise ValueError(f"class_name {cls_name} is invalid.")
|
||||
|
||||
cls = mapping.get(cls_name)
|
||||
if cls is None:
|
||||
raise ValueError(f"Invalid class name: {cls_name}")
|
||||
return cls
|
||||
|
||||
|
||||
class VisionConfig(PretrainedConfig):
|
||||
class BaseConfig(PretrainedConfig):
|
||||
"""
|
||||
Base configuration class for multi-modality components.
|
||||
"""
|
||||
model_type = ""
|
||||
cls: str = ""
|
||||
params: AttrDict = {}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.cls = kwargs.get("cls", "")
|
||||
if not isinstance(self.cls, str):
|
||||
self.cls = self.cls.__name__
|
||||
self.params = AttrDict(kwargs.get("params", {}))
|
||||
|
||||
|
||||
class VisionConfig(BaseConfig):
|
||||
model_type = "vision"
|
||||
cls: str = ""
|
||||
params: AttrDict = {}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.cls = kwargs.get("cls", "")
|
||||
if not isinstance(self.cls, str):
|
||||
self.cls = self.cls.__name__
|
||||
|
||||
self.params = AttrDict(kwargs.get("params", {}))
|
||||
|
||||
|
||||
class AlignerConfig(PretrainedConfig):
|
||||
class AlignerConfig(BaseConfig):
|
||||
model_type = "aligner"
|
||||
cls: str = ""
|
||||
params: AttrDict = {}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.cls = kwargs.get("cls", "")
|
||||
if not isinstance(self.cls, str):
|
||||
self.cls = self.cls.__name__
|
||||
|
||||
self.params = AttrDict(kwargs.get("params", {}))
|
||||
|
||||
|
||||
class GenVisionConfig(PretrainedConfig):
|
||||
class GenVisionConfig(BaseConfig):
|
||||
model_type = "gen_vision"
|
||||
cls: str = ""
|
||||
params: AttrDict = {}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.cls = kwargs.get("cls", "")
|
||||
if not isinstance(self.cls, str):
|
||||
self.cls = self.cls.__name__
|
||||
|
||||
self.params = AttrDict(kwargs.get("params", {}))
|
||||
|
||||
|
||||
class GenAlignerConfig(PretrainedConfig):
|
||||
class GenAlignerConfig(BaseConfig):
|
||||
model_type = "gen_aligner"
|
||||
cls: str = ""
|
||||
params: AttrDict = {}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.cls = kwargs.get("cls", "")
|
||||
if not isinstance(self.cls, str):
|
||||
self.cls = self.cls.__name__
|
||||
|
||||
self.params = AttrDict(kwargs.get("params", {}))
|
||||
|
||||
|
||||
class GenHeadConfig(PretrainedConfig):
|
||||
class GenHeadConfig(BaseConfig):
|
||||
model_type = "gen_head"
|
||||
cls: str = ""
|
||||
params: AttrDict = {}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.cls = kwargs.get("cls", "")
|
||||
if not isinstance(self.cls, str):
|
||||
self.cls = self.cls.__name__
|
||||
|
||||
self.params = AttrDict(kwargs.get("params", {}))
|
||||
|
||||
|
||||
class MultiModalityConfig(PretrainedConfig):
|
||||
"""
|
||||
Configuration for the multi-modality model.
|
||||
"""
|
||||
model_type = "multi_modality"
|
||||
vision_config: VisionConfig
|
||||
aligner_config: AlignerConfig
|
||||
|
||||
gen_vision_config: GenVisionConfig
|
||||
gen_aligner_config: GenAlignerConfig
|
||||
gen_head_config: GenHeadConfig
|
||||
|
||||
language_config: LlamaConfig
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
vision_config = kwargs.get("vision_config", {})
|
||||
self.vision_config = VisionConfig(**vision_config)
|
||||
|
||||
aligner_config = kwargs.get("aligner_config", {})
|
||||
self.aligner_config = AlignerConfig(**aligner_config)
|
||||
|
||||
gen_vision_config = kwargs.get("gen_vision_config", {})
|
||||
self.gen_vision_config = GenVisionConfig(**gen_vision_config)
|
||||
|
||||
gen_aligner_config = kwargs.get("gen_aligner_config", {})
|
||||
self.gen_aligner_config = GenAlignerConfig(**gen_aligner_config)
|
||||
|
||||
gen_head_config = kwargs.get("gen_head_config", {})
|
||||
self.gen_head_config = GenHeadConfig(**gen_head_config)
|
||||
self.vision_config = VisionConfig(**kwargs.get("vision_config", {}))
|
||||
self.aligner_config = AlignerConfig(**kwargs.get("aligner_config", {}))
|
||||
self.gen_vision_config = GenVisionConfig(**kwargs.get("gen_vision_config", {}))
|
||||
self.gen_aligner_config = GenAlignerConfig(**kwargs.get("gen_aligner_config", {}))
|
||||
self.gen_head_config = GenHeadConfig(**kwargs.get("gen_head_config", {}))
|
||||
|
||||
language_config = kwargs.get("language_config", {})
|
||||
if isinstance(language_config, LlamaConfig):
|
||||
self.language_config = language_config
|
||||
else:
|
||||
self.language_config = LlamaConfig(**language_config)
|
||||
self.language_config = (
|
||||
language_config if isinstance(language_config, LlamaConfig) else LlamaConfig(**language_config)
|
||||
)
|
||||
|
||||
|
||||
class MultiModalityPreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
Base class for multi-modality pre-trained models.
|
||||
"""
|
||||
config_class = MultiModalityConfig
|
||||
base_model_prefix = "multi_modality"
|
||||
_no_split_modules = []
|
||||
@ -188,35 +142,40 @@ class MultiModalityPreTrainedModel(PreTrainedModel):
|
||||
|
||||
|
||||
class MultiModalityCausalLM(MultiModalityPreTrainedModel):
|
||||
"""
|
||||
Multi-modality causal language model combining vision and language components.
|
||||
"""
|
||||
def __init__(self, config: MultiModalityConfig):
|
||||
super().__init__(config)
|
||||
|
||||
vision_config = config.vision_config
|
||||
vision_cls = model_name_to_cls(vision_config.cls)
|
||||
self.vision_model = vision_cls(**vision_config.params)
|
||||
# Initialize vision model
|
||||
vision_cls = model_name_to_cls(config.vision_config.cls)
|
||||
self.vision_model = vision_cls(**config.vision_config.params)
|
||||
|
||||
aligner_config = config.aligner_config
|
||||
aligner_cls = model_name_to_cls(aligner_config.cls)
|
||||
self.aligner = aligner_cls(aligner_config.params)
|
||||
# Initialize aligner
|
||||
aligner_cls = model_name_to_cls(config.aligner_config.cls)
|
||||
self.aligner = aligner_cls(config.aligner_config.params)
|
||||
|
||||
gen_vision_config = config.gen_vision_config
|
||||
gen_vision_cls = model_name_to_cls(gen_vision_config.cls)
|
||||
# Initialize generative vision model
|
||||
gen_vision_cls = model_name_to_cls(config.gen_vision_config.cls)
|
||||
self.gen_vision_model = gen_vision_cls()
|
||||
|
||||
gen_aligner_config = config.gen_aligner_config
|
||||
gen_aligner_cls = model_name_to_cls(gen_aligner_config.cls)
|
||||
self.gen_aligner = gen_aligner_cls(gen_aligner_config.params)
|
||||
# Initialize generative aligner
|
||||
gen_aligner_cls = model_name_to_cls(config.gen_aligner_config.cls)
|
||||
self.gen_aligner = gen_aligner_cls(config.gen_aligner_config.params)
|
||||
|
||||
gen_head_config = config.gen_head_config
|
||||
gen_head_cls = model_name_to_cls(gen_head_config.cls)
|
||||
self.gen_head = gen_head_cls(gen_head_config.params)
|
||||
# Initialize generative head
|
||||
gen_head_cls = model_name_to_cls(config.gen_head_config.cls)
|
||||
self.gen_head = gen_head_cls(config.gen_head_config.params)
|
||||
|
||||
# Generative embedding layer
|
||||
self.gen_embed = torch.nn.Embedding(
|
||||
gen_vision_config.params.image_token_size, gen_vision_config.params.n_embed
|
||||
config.gen_vision_config.params.image_token_size,
|
||||
config.gen_vision_config.params.n_embed,
|
||||
)
|
||||
|
||||
language_config = config.language_config
|
||||
self.language_model = LlamaForCausalLM(language_config)
|
||||
# Language model
|
||||
self.language_model = LlamaForCausalLM(config.language_config)
|
||||
|
||||
def prepare_inputs_embeds(
|
||||
self,
|
||||
@ -227,6 +186,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Prepares input embeddings by combining text and image embeddings.
|
||||
|
||||
Args:
|
||||
input_ids (torch.LongTensor): [b, T]
|
||||
@ -234,39 +194,45 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
|
||||
images_seq_mask (torch.BoolTensor): [b, T]
|
||||
images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens]
|
||||
|
||||
assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask)
|
||||
|
||||
Returns:
|
||||
input_embeds (torch.Tensor): [b, T, D]
|
||||
"""
|
||||
|
||||
bs, n = pixel_values.shape[0:2]
|
||||
bs, n = pixel_values.shape[:2]
|
||||
images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
|
||||
# [b x n, T2, D]
|
||||
images_embeds = self.aligner(self.vision_model(images))
|
||||
|
||||
# [b x n, T2, D] -> [b, n x T2, D]
|
||||
images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
|
||||
# [b, n, T2] -> [b, n x T2]
|
||||
images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")
|
||||
# Process images through vision model and aligner
|
||||
images_embeds = self.aligner(self.vision_model(images)) # [b x n, T2, D]
|
||||
images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n) # [b, n x T2, D]
|
||||
images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)") # [b, n x T2]
|
||||
|
||||
# [b, T, D]
|
||||
input_ids[input_ids < 0] = 0 # ignore the image embeddings
|
||||
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
|
||||
# Prepare text embeddings
|
||||
input_ids[input_ids < 0] = 0 # Ignore negative IDs
|
||||
inputs_embeds = self.language_model.get_input_embeddings()(input_ids) # [b, T, D]
|
||||
|
||||
# replace with the image embeddings
|
||||
# Replace text embeddings with image embeddings where applicable
|
||||
inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]
|
||||
|
||||
return inputs_embeds
|
||||
|
||||
def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
|
||||
"""
|
||||
Prepares generative image embeddings.
|
||||
|
||||
Args:
|
||||
image_ids (torch.LongTensor): Image token IDs.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Generated image embeddings.
|
||||
"""
|
||||
return self.gen_aligner(self.gen_embed(image_ids))
|
||||
|
||||
|
||||
# Register configurations with Hugging Face's AutoConfig
|
||||
AutoConfig.register("vision", VisionConfig)
|
||||
AutoConfig.register("aligner", AlignerConfig)
|
||||
AutoConfig.register("gen_vision", GenVisionConfig)
|
||||
AutoConfig.register("gen_aligner", GenAlignerConfig)
|
||||
AutoConfig.register("gen_head", GenHeadConfig)
|
||||
AutoConfig.register("multi_modality", MultiModalityConfig)
|
||||
|
||||
# Register the multi-modality causal LM model
|
||||
AutoModelForCausalLM.register(MultiModalityConfig, MultiModalityCausalLM)
|
||||
|
Loading…
Reference in New Issue
Block a user