# Gradio demo for Janus-1.3B: multimodal understanding and text-to-image generation.
# Source: https://github.com/deepseek-ai/Janus.git
import gradio as gr
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor
from PIL import Image
import numpy as np
import random


# Load model and processor
model_path = "deepseek-ai/Janus-1.3B"
config = AutoConfig.from_pretrained(model_path)
language_config = config.language_config
language_config._attn_implementation = 'eager'
vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
                                              language_config=language_config,
                                              trust_remote_code=True)
vl_gpt = vl_gpt.to(torch.bfloat16).cuda()

vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'


# Helper function to set the random seed
def set_seed(seed):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.cuda.manual_seed(seed)

# Multimodal Understanding function
@torch.inference_mode()
def multimodal_understanding(image, question, seed, top_p, temperature):
    # Clear CUDA cache before generating
    torch.cuda.empty_cache()

    set_seed(seed)

    conversation = [
        {
            "role": "User",
            "content": f"<image_placeholder>\n{question}",
            "images": [image],
        },
        {"role": "Assistant", "content": ""},
    ]

    pil_images = [Image.fromarray(image)]
    prepare_inputs = vl_chat_processor(
        conversations=conversation, images=pil_images, force_batchify=True
    ).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)

    inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)

    outputs = vl_gpt.language_model.generate(
        inputs_embeds=inputs_embeds,
        attention_mask=prepare_inputs.attention_mask,
        pad_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        max_new_tokens=512,
        do_sample=False if temperature == 0 else True,
        use_cache=True,
        temperature=temperature,
        top_p=top_p,
    )

    answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
    return answer

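
# NOTE: `generate` and `unpack` are called by `generate_image` below but their
# definitions are missing from this excerpt. The sketch that follows is modeled on
# the upstream Janus demo and is consistent with the call sites in this file;
# attribute names such as `gen_head`, `prepare_gen_img_embeds`,
# `gen_vision_model.decode_code` and `vl_chat_processor.pad_id` are assumptions
# that should be verified against the installed janus package.
@torch.inference_mode()
def generate(input_ids,
             width,
             height,
             temperature=1.0,
             parallel_size=5,
             cfg_weight=5,
             image_token_num_per_image=576,
             patch_size=16):
    torch.cuda.empty_cache()
    # Interleave conditional (even rows) and unconditional (odd rows, prompt padded out)
    # copies of the prompt for classifier-free guidance.
    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
    for i in range(parallel_size * 2):
        tokens[i, :] = input_ids
        if i % 2 != 0:
            tokens[i, 1:-1] = vl_chat_processor.pad_id
    inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(cuda_device)

    past_key_values = None
    for i in range(image_token_num_per_image):
        outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds,
                                              use_cache=True,
                                              past_key_values=past_key_values)
        past_key_values = outputs.past_key_values
        hidden_states = outputs.last_hidden_state
        logits = vl_gpt.gen_head(hidden_states[:, -1, :])
        # Classifier-free guidance: push the conditional logits away from the unconditional ones.
        logit_cond = logits[0::2, :]
        logit_uncond = logits[1::2, :]
        logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
        probs = torch.softmax(logits / temperature, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        generated_tokens[:, i] = next_token.squeeze(dim=-1)
        # Feed the sampled image token back in for both the conditional and unconditional rows.
        next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
        inputs_embeds = vl_gpt.prepare_gen_img_embeds(next_token).unsqueeze(dim=1)

    # Decode the discrete image tokens into image patches with the generation vision model.
    patches = vl_gpt.gen_vision_model.decode_code(
        generated_tokens.to(dtype=torch.int),
        shape=[parallel_size, 8, width // patch_size, height // patch_size])
    return generated_tokens.to(dtype=torch.int), patches


def unpack(dec, width, height, parallel_size=5):
    # Convert decoded patches from [-1, 1] float CHW tensors to uint8 HWC images.
    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
    dec = np.clip((dec + 1) / 2 * 255, 0, 255)
    visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
    visual_img[:, :, :] = dec
    return visual_img
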
# Generate images function
@torch.inference_mode()
def generate_image(prompt, seed=None, guidance=5):
    # Clear CUDA cache and avoid tracking gradients
    torch.cuda.empty_cache()
    # Set the seed for reproducible results
    if seed is not None:
        set_seed(seed)
    width = 384
    height = 384
    parallel_size = 5

    with torch.no_grad():
        messages = [{'role': 'User', 'content': prompt},
                    {'role': 'Assistant', 'content': ''}]
        text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(conversations=messages,
                                                                           sft_format=vl_chat_processor.sft_format,
                                                                           system_prompt='')
        text = text + vl_chat_processor.image_start_tag
        input_ids = torch.LongTensor(tokenizer.encode(text))
        output, patches = generate(input_ids,
                                   width // 16 * 16,
                                   height // 16 * 16,
                                   cfg_weight=guidance,
                                   parallel_size=parallel_size)
        images = unpack(patches,
                        width // 16 * 16,
                        height // 16 * 16)

    return [Image.fromarray(images[i]).resize((1024, 1024), Image.LANCZOS) for i in range(parallel_size)]

# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown(value="# Multimodal Understanding")

    with gr.Row():
        image_input = gr.Image()
        with gr.Column():
            question_input = gr.Textbox(label="Question")
            und_seed_input = gr.Number(label="Seed", precision=0, value=42)
            top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="top_p")
            temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="temperature")

    understanding_button = gr.Button("Chat")
    understanding_output = gr.Textbox(label="Response")

    examples_inpainting = gr.Examples(
        label="Multimodal Understanding examples",
        examples=[
            [
                "explain this meme",
                "images/doge.png",
            ],
            [
                "Convert the formula into latex code.",
                "images/equation.png",
            ],
        ],
        inputs=[question_input, image_input],
    )

    gr.Markdown(value="# Text-to-Image Generation")

    with gr.Row():
        cfg_weight_input = gr.Slider(minimum=1, maximum=10, value=5, step=0.5, label="CFG Weight")

    prompt_input = gr.Textbox(label="Prompt")
    seed_input = gr.Number(label="Seed (Optional)", precision=0, value=12345)

    generation_button = gr.Button("Generate Images")

    image_output = gr.Gallery(label="Generated Images", columns=2, rows=2, height=300)

    examples_t2i = gr.Examples(
        label="Text to image generation examples",
        examples=[
            "Master shifu racoon wearing drip attire as a street gangster.",
            "A cute and adorable baby fox with big brown eyes...",
        ],
        inputs=prompt_input,
    )

    understanding_button.click(
        multimodal_understanding,
        inputs=[image_input, question_input, und_seed_input, top_p, temperature],
        outputs=understanding_output
    )

    generation_button.click(
        fn=generate_image,
        inputs=[prompt_input, seed_input, cfg_weight_input],
        outputs=image_output
    )

demo.launch(share=True)