commit 4609c4fbda (committed by GitHub)
Author: lastrei
Date:   2025-02-01 21:55:03 +08:00


@@ -11,27 +11,50 @@ import time
 # import spaces # Import spaces for ZeroGPU compatibility
 # Load model and processor
-model_path = "deepseek-ai/Janus-Pro-7B"
-config = AutoConfig.from_pretrained(model_path)
-language_config = config.language_config
-language_config._attn_implementation = 'eager'
-vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
+# Global variables to store model and processor (initially for 7B)
+vl_gpt = None
+vl_chat_processor = None
+tokenizer = None
+cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
+current_model_path = "deepseek-ai/Janus-Pro-7B"  # Default model
+def load_model_components(model_path):
+    global vl_gpt, vl_chat_processor, tokenizer, current_model_path  # Declare current_model_path as global here
+    if vl_gpt is not None and current_model_path == model_path:
+        print(f"Using cached model: {model_path}")
+        return vl_gpt, vl_chat_processor, tokenizer
+    print(f"Loading model: {model_path}")
+    config = AutoConfig.from_pretrained(model_path)
+    language_config = config.language_config
+    language_config._attn_implementation = 'eager'
+    vl_gpt_local = AutoModelForCausalLM.from_pretrained(model_path,
+                                                        language_config=language_config,
+                                                        trust_remote_code=True)
-if torch.cuda.is_available():
-    vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
-else:
-    vl_gpt = vl_gpt.to(torch.float16)
+    if torch.cuda.is_available():
+        vl_gpt_local = vl_gpt_local.to(torch.bfloat16).cuda()
+    else:
+        vl_gpt_local = vl_gpt_local.to(torch.float16)
+    vl_chat_processor_local = VLChatProcessor.from_pretrained(model_path)
+    tokenizer_local = vl_chat_processor_local.tokenizer
+    vl_gpt = vl_gpt_local
+    vl_chat_processor = vl_chat_processor_local
+    tokenizer = tokenizer_local
+    current_model_path = model_path
+    print(f"Model loaded: {model_path}")
+    return vl_gpt, vl_chat_processor, tokenizer
-vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
-tokenizer = vl_chat_processor.tokenizer
-cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
 @torch.inference_mode()
 # @spaces.GPU(duration=120)
 # Multimodal Understanding function
-def multimodal_understanding(image, question, seed, top_p, temperature):
+def multimodal_understanding(model_name, image, question, seed, top_p, temperature):
+    # Load model based on selection
+    load_model_components(model_name)
     # Clear CUDA cache before generating
     torch.cuda.empty_cache()
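
The hunk above replaces the eager module-level model setup with a lazy loader that caches whichever checkpoint was loaded last and only reloads when the selection changes. Below is a minimal sketch of the same cache-by-path pattern; build_model is a placeholder for the AutoConfig/AutoModelForCausalLM/VLChatProcessor calls, not real demo code.

    # Cache-by-path sketch: reload only when a different checkpoint is requested.
    _cache = {"path": None, "model": None}

    def build_model(path):
        # Placeholder for the actual from_pretrained() loading shown in the diff.
        return f"model loaded from {path}"

    def get_model(path):
        if _cache["model"] is not None and _cache["path"] == path:
            return _cache["model"]           # cache hit: no reload
        _cache["model"] = build_model(path)  # miss: load and remember the path
        _cache["path"] = path
        return _cache["model"]

    m7b = get_model("deepseek-ai/Janus-Pro-7B")
    assert get_model("deepseek-ai/Janus-Pro-7B") is m7b   # second call is cached
    m1b = get_model("deepseek-ai/Janus-Pro-1B")           # different path reloads

Only one checkpoint stays resident: switching models drops the old references, and the torch.cuda.empty_cache() calls in the handlers can then release the freed GPU blocks.
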
@@ -114,7 +137,6 @@ def generate(input_ids,
         inputs_embeds = img_embeds.unsqueeze(dim=1)
     patches = vl_gpt.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int),
                                                   shape=[parallel_size, 8, width // patch_size, height // patch_size])
@@ -133,10 +155,10 @@ def unpack(dec, width, height, parallel_size=5):
 @torch.inference_mode()
 # @spaces.GPU(duration=120)  # Specify a duration to avoid timeout
-def generate_image(prompt,
-                   seed=None,
-                   guidance=5,
-                   t2i_temperature=1.0):
+def generate_image(model_name, prompt, seed, guidance, t2i_temperature, parallel_size_slider):
+    # Load model based on selection
+    load_model_components(model_name)
     # Clear CUDA cache and avoid tracking gradients
     torch.cuda.empty_cache()
     # Set the seed for reproducible results
@@ -146,7 +168,7 @@ def generate_image(prompt,
         np.random.seed(seed)
     width = 384
     height = 384
-    parallel_size = 5
+    parallel_size = int(parallel_size_slider)  # Use slider value for parallel_size
     with torch.no_grad():
         messages = [{'role': '<|User|>', 'content': prompt},
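
parallel_size is the number of images decoded in a single pass, so the new slider trades GPU memory for output count: with classifier-free guidance every sample carries a conditional/unconditional pair, and the token buffers in the generation loop scale with 2 * parallel_size. A rough, self-contained illustration of those buffer shapes (the sizes below are assumptions for the sketch, not values read from the model):

    import torch

    parallel_size = 3             # value that would come from the slider
    seq_len, image_tokens = 16, 576

    # CFG duplicates each sample into a (conditional, unconditional) pair ...
    tokens = torch.zeros((parallel_size * 2, seq_len), dtype=torch.int)
    # ... while one image-token sequence is decoded per requested image.
    generated = torch.zeros((parallel_size, image_tokens), dtype=torch.int)
    print(tokens.shape, generated.shape)  # torch.Size([6, 16]) torch.Size([3, 576])
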
@@ -173,7 +195,14 @@ def generate_image(prompt,
 # Gradio interface
 with gr.Blocks() as demo:
-    gr.Markdown(value="# Multimodal Understanding")
+    gr.Markdown(value="# Multimodal Model Demo: Janus-Pro-7B & 1B")
+    model_selector = gr.Dropdown(
+        ["deepseek-ai/Janus-Pro-7B", "deepseek-ai/Janus-Pro-1B"],
+        value="deepseek-ai/Janus-Pro-7B", label="Select Model"
+    )
+    with gr.Tab("Multimodal Understanding"):
         with gr.Row():
             image_input = gr.Image()
             with gr.Column():
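
The flat page is reorganized into a shared model dropdown that sits above two gr.Tab sections, so a single selector feeds both the understanding and the text-to-image workflows. A stripped-down sketch of that layout with a stub handler (the component names and handler below are illustrative, not the demo's actual code):

    import gradio as gr

    def describe(model_name, text):
        # Stub standing in for multimodal_understanding / generate_image.
        return f"[{model_name}] would process: {text}"

    with gr.Blocks() as sketch:
        model_selector = gr.Dropdown(
            ["deepseek-ai/Janus-Pro-7B", "deepseek-ai/Janus-Pro-1B"],
            value="deepseek-ai/Janus-Pro-7B", label="Select Model"
        )
        with gr.Tab("Multimodal Understanding"):
            question = gr.Textbox(label="Question")
            answer = gr.Textbox(label="Response")
            run = gr.Button("Chat")
        # The dropdown lives outside the tabs, so any tab's handler can take it as an input.
        run.click(describe, inputs=[model_selector, question], outputs=answer)

    # sketch.launch()  # uncomment to try it locally
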
@@ -200,14 +229,11 @@ with gr.Blocks() as demo:
         inputs=[question_input, image_input],
     )
-    gr.Markdown(value="# Text-to-Image Generation")
+    with gr.Tab("Text-to-Image Generation"):
         with gr.Row():
             cfg_weight_input = gr.Slider(minimum=1, maximum=10, value=5, step=0.5, label="CFG Weight")
             t2i_temperature = gr.Slider(minimum=0, maximum=1, value=1.0, step=0.05, label="temperature")
+            parallel_size_slider = gr.Slider(minimum=1, maximum=5, value=5, step=1, label="Parallel Size")  # New slider
         prompt_input = gr.Textbox(label="Prompt. (Prompt in more detail can help produce better images!)")
         seed_input = gr.Number(label="Seed (Optional)", precision=0, value=12345)
@@ -231,13 +257,13 @@ with gr.Blocks() as demo:
     understanding_button.click(
         multimodal_understanding,
-        inputs=[image_input, question_input, und_seed_input, top_p, temperature],
+        inputs=[model_selector, image_input, question_input, und_seed_input, top_p, temperature],  # Added model_selector
         outputs=understanding_output
     )
     generation_button.click(
         fn=generate_image,
-        inputs=[prompt_input, seed_input, cfg_weight_input, t2i_temperature],
+        inputs=[model_selector, prompt_input, seed_input, cfg_weight_input, t2i_temperature, parallel_size_slider],  # Added model_selector and parallel_size_slider
         outputs=image_output
     )
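
Gradio hands the inputs components to the callback positionally, which is why model_selector has to come first in both lists to line up with the new model_name parameter, and parallel_size_slider last to match the updated generate_image signature. A tiny self-contained example of that positional mapping, using hypothetical component names:

    import gradio as gr

    def echo_order(model_name, prompt, seed):
        # Arguments arrive in the same order as the inputs list below.
        return f"model={model_name}, prompt={prompt!r}, seed={seed}"

    with gr.Blocks() as wiring_demo:
        model = gr.Dropdown(["7B", "1B"], value="7B", label="Model")
        prompt = gr.Textbox(label="Prompt")
        seed = gr.Number(value=12345, precision=0, label="Seed")
        out = gr.Textbox(label="Result")
        gr.Button("Run").click(echo_order, inputs=[model, prompt, seed], outputs=out)
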