diff --git a/README.md b/README.md index 87e1f8f..aff55b7 100755 --- a/README.md +++ b/README.md @@ -56,12 +56,14 @@ 📜 License | 📖 Citation
- 🤗 Online Demo (Janus, JanusFlow) + 🤗 Online Demo (Janus-Pro-7B, Janus, JanusFlow)

## News +**2025.01.27**: Janus-Pro is released, an advanced version of Janus, improving both multimodal understanding and visual generation significantly. See [paper](./janus_pro_tech_report.pdf) + **2024.11.13**: JanusFlow is released, a new unified model with rectified flow for image generation. See [paper](https://arxiv.org/abs/2411.07975), [demo](https://huggingface.co/spaces/deepseek-ai/JanusFlow-1.3B) and [usage](https://github.com/deepseek-ai/Janus?tab=readme-ov-file#janusflow). **2024.10.23**: Evaluation code for reproducing the multimodal understanding results from the paper has been added to VLMEvalKit. Please refer to [this link]( https://github.com/open-compass/VLMEvalKit/pull/541). @@ -71,6 +73,16 @@ ## 1. Introduction +Janus-Pro: Unified Multimodal Understanding and +Generation with Data and Model Scaling + +**Janus-Pro** is an advanced version of the previous work Janus. Specifically, Janus-Pro incorporates (1) an optimized training strategy, (2) expanded training data, and (3) scaling to larger model size. With these improvements, Janus-Pro achieves significant advancements in both multimodal understanding and text-to-image instruction-following capabilities, while also enhancing the stability of text-to-image generation. + +
+image +
+ + Janus: Decoupling Visual Encoding for Unified Multimodal Understanding and Generation **Janus** is a novel autoregressive framework that unifies multimodal understanding and generation. It addresses the limitations of previous approaches by decoupling visual encoding into separate pathways, while still utilizing a single, unified transformer architecture for processing. The decoupling not only alleviates the conflict between the visual encoder’s roles in understanding and generation, but also enhances the framework’s flexibility. Janus surpasses previous unified model and matches or exceeds the performance of task-specific models. The simplicity, high flexibility, and effectiveness of Janus make it a strong candidate for next-generation unified multimodal models. @@ -100,11 +112,197 @@ permitted under these terms. |-----------------------|-----------------|-----------------------------------------------------------------------------| | Janus-1.3B | 4096 | [🤗 Hugging Face](https://huggingface.co/deepseek-ai/Janus-1.3B) | | JanusFlow-1.3B | 4096 | [🤗 Hugging Face](https://huggingface.co/deepseek-ai/JanusFlow-1.3B) | - +| Janus-Pro-1B | 4096 | [🤗 Hugging Face](https://huggingface.co/deepseek-ai/Janus-Pro-1B) | +| Janus-Pro-7B | 4096 | [🤗 Hugging Face](https://huggingface.co/deepseek-ai/Janus-Pro-7B) | ## 3. Quick Start +
+

Janus-Pro

+ +### Installation + +On the basis of `Python >= 3.8` environment, install the necessary dependencies by running the following command: + +```shell +pip install -e . +``` + + +### Simple Inference Example + +#### Multimodal Understanding +```python + +import torch +from transformers import AutoModelForCausalLM +from janus.models import MultiModalityCausalLM, VLChatProcessor +from janus.utils.io import load_pil_images + +# specify the path to the model +model_path = "deepseek-ai/Janus-Pro-7B" +vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path) +tokenizer = vl_chat_processor.tokenizer + +vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained( + model_path, trust_remote_code=True +) +vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval() + +conversation = [ + { + "role": "<|User|>", + "content": f"\n{question}", + "images": [image], + }, + {"role": "<|Assistant|>", "content": ""}, +] + +# load images and prepare for inputs +pil_images = load_pil_images(conversation) +prepare_inputs = vl_chat_processor( + conversations=conversation, images=pil_images, force_batchify=True +).to(vl_gpt.device) + +# # run image encoder to get the image embeddings +inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs) + +# # run the model to get the response +outputs = vl_gpt.language_model.generate( + inputs_embeds=inputs_embeds, + attention_mask=prepare_inputs.attention_mask, + pad_token_id=tokenizer.eos_token_id, + bos_token_id=tokenizer.bos_token_id, + eos_token_id=tokenizer.eos_token_id, + max_new_tokens=512, + do_sample=False, + use_cache=True, +) + +answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True) +print(f"{prepare_inputs['sft_format'][0]}", answer) + +``` + +#### Text-to-Image Generation +```python +import os +import PIL.Image +import torch +import numpy as np +from transformers import AutoModelForCausalLM +from janus.models import MultiModalityCausalLM, VLChatProcessor + + +# specify the path to the model +model_path = "deepseek-ai/Janus-Pro-7B" +vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path) +tokenizer = vl_chat_processor.tokenizer + +vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained( + model_path, trust_remote_code=True +) +vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval() + +conversation = [ + { + "role": "<|User|>", + "content": "A stunning princess from kabul in red, white traditional clothing, blue eyes, brown hair", + }, + {"role": "<|Assistant|>", "content": ""}, +] + +sft_format = vl_chat_processor.apply_sft_template_for_multi_turn_prompts( + conversations=conversation, + sft_format=vl_chat_processor.sft_format, + system_prompt="", +) +prompt = sft_format + vl_chat_processor.image_start_tag + + +@torch.inference_mode() +def generate( + mmgpt: MultiModalityCausalLM, + vl_chat_processor: VLChatProcessor, + prompt: str, + temperature: float = 1, + parallel_size: int = 16, + cfg_weight: float = 5, + image_token_num_per_image: int = 576, + img_size: int = 384, + patch_size: int = 16, +): + input_ids = vl_chat_processor.tokenizer.encode(prompt) + input_ids = torch.LongTensor(input_ids) + + tokens = torch.zeros((parallel_size*2, len(input_ids)), dtype=torch.int).cuda() + for i in range(parallel_size*2): + tokens[i, :] = input_ids + if i % 2 != 0: + tokens[i, 1:-1] = vl_chat_processor.pad_id + + inputs_embeds = mmgpt.language_model.get_input_embeddings()(tokens) + + generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).cuda() + + for i in range(image_token_num_per_image): + outputs = mmgpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True, past_key_values=outputs.past_key_values if i != 0 else None) + hidden_states = outputs.last_hidden_state + + logits = mmgpt.gen_head(hidden_states[:, -1, :]) + logit_cond = logits[0::2, :] + logit_uncond = logits[1::2, :] + + logits = logit_uncond + cfg_weight * (logit_cond-logit_uncond) + probs = torch.softmax(logits / temperature, dim=-1) + + next_token = torch.multinomial(probs, num_samples=1) + generated_tokens[:, i] = next_token.squeeze(dim=-1) + + next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1) + img_embeds = mmgpt.prepare_gen_img_embeds(next_token) + inputs_embeds = img_embeds.unsqueeze(dim=1) + + + dec = mmgpt.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int), shape=[parallel_size, 8, img_size//patch_size, img_size//patch_size]) + dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1) + + dec = np.clip((dec + 1) / 2 * 255, 0, 255) + + visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8) + visual_img[:, :, :] = dec + + os.makedirs('generated_samples', exist_ok=True) + for i in range(parallel_size): + save_path = os.path.join('generated_samples', "img_{}.jpg".format(i)) + PIL.Image.fromarray(visual_img[i]).save(save_path) + + +generate( + vl_gpt, + vl_chat_processor, + prompt, +) +``` + +### Gradio Demo +We have deployed online demo in [Huggingface](https://huggingface.co/spaces/deepseek-ai/Janus-Pro-7B). + + +For the local gradio demo, you can run with the following command: + +``` +pip install -e .[gradio] + +python demo/app_januspro.py +``` + +Have Fun! +``` +
+ +

Janus

@@ -519,6 +717,12 @@ This code repository is licensed under [the MIT License](https://github.com/deep ## 5. Citation ```bibtex +@misc{chen2025januspro, + title={Janus-Pro: Unified Multimodal Understanding and Generation with Data and Model Scaling}, + author={Xiaokang Chen and Zhiyu Wu and Xingchao Liu and Zizheng Pan and Wen Liu and Zhenda Xie and Xingkai Yu and Chong Ruan}, + year={2025}, +} + @article{wu2024janus, title={Janus: Decoupling visual encoding for unified multimodal understanding and generation}, author={Wu, Chengyue and Chen, Xiaokang and Wu, Zhiyu and Ma, Yiyang and Liu, Xingchao and Pan, Zizheng and Liu, Wen and Xie, Zhenda and Yu, Xingkai and Ruan, Chong and others}, diff --git a/demo/app_januspro.py b/demo/app_januspro.py new file mode 100644 index 0000000..702e58e --- /dev/null +++ b/demo/app_januspro.py @@ -0,0 +1,245 @@ +import gradio as gr +import torch +from transformers import AutoConfig, AutoModelForCausalLM +from janus.models import MultiModalityCausalLM, VLChatProcessor +from janus.utils.io import load_pil_images +from PIL import Image + +import numpy as np +import os +import time +# import spaces # Import spaces for ZeroGPU compatibility + + +# Load model and processor +model_path = "deepseek-ai/Janus-Pro-7B" +config = AutoConfig.from_pretrained(model_path) +language_config = config.language_config +language_config._attn_implementation = 'eager' +vl_gpt = AutoModelForCausalLM.from_pretrained(model_path, + language_config=language_config, + trust_remote_code=True) +if torch.cuda.is_available(): + vl_gpt = vl_gpt.to(torch.bfloat16).cuda() +else: + vl_gpt = vl_gpt.to(torch.float16) + +vl_chat_processor = VLChatProcessor.from_pretrained(model_path) +tokenizer = vl_chat_processor.tokenizer +cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu' + +@torch.inference_mode() +# @spaces.GPU(duration=120) +# Multimodal Understanding function +def multimodal_understanding(image, question, seed, top_p, temperature): + # Clear CUDA cache before generating + torch.cuda.empty_cache() + + # set seed + torch.manual_seed(seed) + np.random.seed(seed) + torch.cuda.manual_seed(seed) + + conversation = [ + { + "role": "<|User|>", + "content": f"\n{question}", + "images": [image], + }, + {"role": "<|Assistant|>", "content": ""}, + ] + + pil_images = [Image.fromarray(image)] + prepare_inputs = vl_chat_processor( + conversations=conversation, images=pil_images, force_batchify=True + ).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16) + + + inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs) + + outputs = vl_gpt.language_model.generate( + inputs_embeds=inputs_embeds, + attention_mask=prepare_inputs.attention_mask, + pad_token_id=tokenizer.eos_token_id, + bos_token_id=tokenizer.bos_token_id, + eos_token_id=tokenizer.eos_token_id, + max_new_tokens=512, + do_sample=False if temperature == 0 else True, + use_cache=True, + temperature=temperature, + top_p=top_p, + ) + + answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True) + return answer + + +def generate(input_ids, + width, + height, + temperature: float = 1, + parallel_size: int = 5, + cfg_weight: float = 5, + image_token_num_per_image: int = 576, + patch_size: int = 16): + # Clear CUDA cache before generating + torch.cuda.empty_cache() + + tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device) + for i in range(parallel_size * 2): + tokens[i, :] = input_ids + if i % 2 != 0: + tokens[i, 1:-1] = vl_chat_processor.pad_id + inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens) + generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(cuda_device) + + pkv = None + for i in range(image_token_num_per_image): + with torch.no_grad(): + outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds, + use_cache=True, + past_key_values=pkv) + pkv = outputs.past_key_values + hidden_states = outputs.last_hidden_state + logits = vl_gpt.gen_head(hidden_states[:, -1, :]) + logit_cond = logits[0::2, :] + logit_uncond = logits[1::2, :] + logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond) + probs = torch.softmax(logits / temperature, dim=-1) + next_token = torch.multinomial(probs, num_samples=1) + generated_tokens[:, i] = next_token.squeeze(dim=-1) + next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1) + + img_embeds = vl_gpt.prepare_gen_img_embeds(next_token) + inputs_embeds = img_embeds.unsqueeze(dim=1) + + + + patches = vl_gpt.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int), + shape=[parallel_size, 8, width // patch_size, height // patch_size]) + + return generated_tokens.to(dtype=torch.int), patches + +def unpack(dec, width, height, parallel_size=5): + dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1) + dec = np.clip((dec + 1) / 2 * 255, 0, 255) + + visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8) + visual_img[:, :, :] = dec + + return visual_img + + + +@torch.inference_mode() +# @spaces.GPU(duration=120) # Specify a duration to avoid timeout +def generate_image(prompt, + seed=None, + guidance=5, + t2i_temperature=1.0): + # Clear CUDA cache and avoid tracking gradients + torch.cuda.empty_cache() + # Set the seed for reproducible results + if seed is not None: + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + np.random.seed(seed) + width = 384 + height = 384 + parallel_size = 5 + + with torch.no_grad(): + messages = [{'role': '<|User|>', 'content': prompt}, + {'role': '<|Assistant|>', 'content': ''}] + text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(conversations=messages, + sft_format=vl_chat_processor.sft_format, + system_prompt='') + text = text + vl_chat_processor.image_start_tag + + input_ids = torch.LongTensor(tokenizer.encode(text)) + output, patches = generate(input_ids, + width // 16 * 16, + height // 16 * 16, + cfg_weight=guidance, + parallel_size=parallel_size, + temperature=t2i_temperature) + images = unpack(patches, + width // 16 * 16, + height // 16 * 16, + parallel_size=parallel_size) + + return [Image.fromarray(images[i]).resize((768, 768), Image.LANCZOS) for i in range(parallel_size)] + + +# Gradio interface +with gr.Blocks() as demo: + gr.Markdown(value="# Multimodal Understanding") + with gr.Row(): + image_input = gr.Image() + with gr.Column(): + question_input = gr.Textbox(label="Question") + und_seed_input = gr.Number(label="Seed", precision=0, value=42) + top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="top_p") + temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="temperature") + + understanding_button = gr.Button("Chat") + understanding_output = gr.Textbox(label="Response") + + examples_inpainting = gr.Examples( + label="Multimodal Understanding examples", + examples=[ + [ + "explain this meme", + "images/doge.png", + ], + [ + "Convert the formula into latex code.", + "images/equation.png", + ], + ], + inputs=[question_input, image_input], + ) + + + gr.Markdown(value="# Text-to-Image Generation") + + + + with gr.Row(): + cfg_weight_input = gr.Slider(minimum=1, maximum=10, value=5, step=0.5, label="CFG Weight") + t2i_temperature = gr.Slider(minimum=0, maximum=1, value=1.0, step=0.05, label="temperature") + + prompt_input = gr.Textbox(label="Prompt. (Prompt in more detail can help produce better images!)") + seed_input = gr.Number(label="Seed (Optional)", precision=0, value=12345) + + generation_button = gr.Button("Generate Images") + + image_output = gr.Gallery(label="Generated Images", columns=2, rows=2, height=300) + + examples_t2i = gr.Examples( + label="Text to image generation examples.", + examples=[ + "Master shifu racoon wearing drip attire as a street gangster.", + "The face of a beautiful girl", + "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", + "A glass of red wine on a reflective surface.", + "A cute and adorable baby fox with big brown eyes, autumn leaves in the background enchanting,immortal,fluffy, shiny mane,Petals,fairyism,unreal engine 5 and Octane Render,highly detailed, photorealistic, cinematic, natural colors.", + "The image features an intricately designed eye set against a circular backdrop adorned with ornate swirl patterns that evoke both realism and surrealism. At the center of attention is a strikingly vivid blue iris surrounded by delicate veins radiating outward from the pupil to create depth and intensity. The eyelashes are long and dark, casting subtle shadows on the skin around them which appears smooth yet slightly textured as if aged or weathered over time.\n\nAbove the eye, there's a stone-like structure resembling part of classical architecture, adding layers of mystery and timeless elegance to the composition. This architectural element contrasts sharply but harmoniously with the organic curves surrounding it. Below the eye lies another decorative motif reminiscent of baroque artistry, further enhancing the overall sense of eternity encapsulated within each meticulously crafted detail. \n\nOverall, the atmosphere exudes a mysterious aura intertwined seamlessly with elements suggesting timelessness, achieved through the juxtaposition of realistic textures and surreal artistic flourishes. Each component\u2014from the intricate designs framing the eye to the ancient-looking stone piece above\u2014contributes uniquely towards creating a visually captivating tableau imbued with enigmatic allure.", + ], + inputs=prompt_input, + ) + + understanding_button.click( + multimodal_understanding, + inputs=[image_input, question_input, und_seed_input, top_p, temperature], + outputs=understanding_output + ) + + generation_button.click( + fn=generate_image, + inputs=[prompt_input, seed_input, cfg_weight_input, t2i_temperature], + outputs=image_output + ) + +demo.launch(share=True) +# demo.queue(concurrency_count=1, max_size=10).launch(server_name="0.0.0.0", server_port=37906, root_path="/path") \ No newline at end of file diff --git a/images/teaser_januspro.png b/images/teaser_januspro.png new file mode 100644 index 0000000..c203a0a Binary files /dev/null and b/images/teaser_januspro.png differ diff --git a/janus/models/processing_vlm.py b/janus/models/processing_vlm.py index 97003d0..eba6895 100644 --- a/janus/models/processing_vlm.py +++ b/janus/models/processing_vlm.py @@ -88,6 +88,7 @@ class VLChatProcessor(ProcessorMixin): image_tag: str = "", image_start_tag: str = "", image_end_tag: str = "", + pad_tag: str = "<|▁pad▁|>", num_image_tokens: int = 576, add_special_token: bool = False, sft_format: str = "deepseek", @@ -108,6 +109,7 @@ class VLChatProcessor(ProcessorMixin): self.image_tag = image_tag self.image_start_tag = image_start_tag self.image_end_tag = image_end_tag + self.pad_tag = pad_tag self.num_image_tokens = num_image_tokens self.add_special_token = add_special_token @@ -203,9 +205,10 @@ class VLChatProcessor(ProcessorMixin): @property def pad_id(self): - pad_id = self.tokenizer.pad_token_id - if pad_id is None: - pad_id = self.tokenizer.eos_token_id + pad_id = self.tokenizer.vocab.get(self.pad_tag) + # pad_id = self.tokenizer.pad_token_id + # if pad_id is None: + # pad_id = self.tokenizer.eos_token_id return pad_id diff --git a/janus/utils/conversation.py b/janus/utils/conversation.py index 711512f..98609d0 100644 --- a/janus/utils/conversation.py +++ b/janus/utils/conversation.py @@ -275,7 +275,7 @@ register_conv_template( # deepseek template register_conv_template( Conversation( - name="deepseek", + name="deepseek_old", system_template="{system_message}", # system_message="You are a helpful assistant. Please answer truthfully and write out your " # "thinking step by step to be sure you get the right answer.", @@ -290,6 +290,23 @@ register_conv_template( stop_str=["User:", "<|end▁of▁sentence|>"], ) ) +register_conv_template( + Conversation( + name="deepseek", + system_template="{system_message}", + # system_message="You are a helpful assistant. Please answer truthfully and write out your " + # "thinking step by step to be sure you get the right answer.", + system_message="", + roles=("<|User|>", "<|Assistant|>"), + messages=(), + offset=0, + sep_style=SeparatorStyle.DeepSeek, + sep="\n\n", + sep2="<|end▁of▁sentence|>", + stop_token_ids=[100001], + stop_str=["<|User|>", "<|end▁of▁sentence|>"] + ) +) register_conv_template( Conversation(