import gradio as gr import torch from transformers import AutoConfig, AutoModelForCausalLM from janus.models import MultiModalityCausalLM, VLChatProcessor from PIL import Image import numpy as np import random # Load model and processor model_path = "deepseek-ai/Janus-1.3B" config = AutoConfig.from_pretrained(model_path) language_config = config.language_config language_config._attn_implementation = 'eager' vl_gpt = AutoModelForCausalLM.from_pretrained(model_path, language_config=language_config, trust_remote_code=True) vl_gpt = vl_gpt.to(torch.bfloat16).cuda() vl_chat_processor = VLChatProcessor.from_pretrained(model_path) tokenizer = vl_chat_processor.tokenizer cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu' # Helper function to set the random seed def set_seed(seed): torch.manual_seed(seed) np.random.seed(seed) random.seed(seed) torch.cuda.manual_seed(seed) # Multimodal Understanding function @torch.inference_mode() def multimodal_understanding(image, question, seed, top_p, temperature): # Clear CUDA cache before generating torch.cuda.empty_cache() set_seed(seed) conversation = [ { "role": "User", "content": f"<image_placeholder>\n{question}", "images": [image], }, {"role": "Assistant", "content": ""}, ] pil_images = [Image.fromarray(image)] prepare_inputs = vl_chat_processor( conversations=conversation, images=pil_images, force_batchify=True ).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16) inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs) outputs = vl_gpt.language_model.generate( inputs_embeds=inputs_embeds, attention_mask=prepare_inputs.attention_mask, pad_token_id=tokenizer.eos_token_id, bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id, max_new_tokens=512, do_sample=False if temperature == 0 else True, use_cache=True, temperature=temperature, top_p=top_p, ) answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True) return answer # Generate images function @torch.inference_mode() def generate_image(prompt, seed=None, guidance=5): # Clear CUDA cache and avoid tracking gradients torch.cuda.empty_cache() # Set the seed for reproducible results if seed is not None: set_seed(seed) width = 384 height = 384 parallel_size = 5 with torch.no_grad(): messages = [{'role': 'User', 'content': prompt}, {'role': 'Assistant', 'content': ''}] text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(conversations=messages, sft_format=vl_chat_processor.sft_format, system_prompt='') text = text + vl_chat_processor.image_start_tag input_ids = torch.LongTensor(tokenizer.encode(text)) output, patches = generate(input_ids, width // 16 * 16, height // 16 * 16, cfg_weight=guidance, parallel_size=parallel_size) images = unpack(patches, width // 16 * 16, height // 16 * 16) return [Image.fromarray(images[i]).resize((1024, 1024), Image.LANCZOS) for i in range(parallel_size)] # Gradio interface with gr.Blocks() as demo: gr.Markdown(value="# Multimodal Understanding") with gr.Row(): image_input = gr.Image() with gr.Column(): question_input = gr.Textbox(label="Question") und_seed_input = gr.Number(label="Seed", precision=0, value=42) top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="top_p") temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="temperature") understanding_button = gr.Button("Chat") understanding_output = gr.Textbox(label="Response") examples_inpainting = gr.Examples( label="Multimodal Understanding examples", examples=[ [ "explain this meme", "images/doge.png", ], [ "Convert the formula into latex code.", "images/equation.png", ], ], inputs=[question_input, image_input], ) gr.Markdown(value="# Text-to-Image Generation") with gr.Row(): cfg_weight_input = gr.Slider(minimum=1, maximum=10, value=5, step=0.5, label="CFG Weight") prompt_input = gr.Textbox(label="Prompt") seed_input = gr.Number(label="Seed (Optional)", precision=0, value=12345) generation_button = gr.Button("Generate Images") image_output = gr.Gallery(label="Generated Images", columns=2, rows=2, height=300) examples_t2i = gr.Examples( label="Text to image generation examples", examples=[ "Master shifu racoon wearing drip attire as a street gangster.", "A cute and adorable baby fox with big brown eyes...", ], inputs=prompt_input, ) understanding_button.click( multimodal_understanding, inputs=[image_input, question_input, und_seed_input, top_p, temperature], outputs=understanding_output ) generation_button.click( fn=generate_image, inputs=[prompt_input, seed_input, cfg_weight_input], outputs=image_output ) demo.launch(share=True)