Mirror of https://github.com/deepseek-ai/Janus.git (synced 2025-04-19 10:09:00 -04:00)

Merge a897652664 into 1daa72fa40
Commit 4609c4fbda
@@ -11,27 +11,50 @@ import time
 # import spaces # Import spaces for ZeroGPU compatibility
 
 
-# Load model and processor
-model_path = "deepseek-ai/Janus-Pro-7B"
-config = AutoConfig.from_pretrained(model_path)
-language_config = config.language_config
-language_config._attn_implementation = 'eager'
-vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
-                                              language_config=language_config,
-                                              trust_remote_code=True)
-if torch.cuda.is_available():
-    vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
-else:
-    vl_gpt = vl_gpt.to(torch.float16)
-
-vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
-tokenizer = vl_chat_processor.tokenizer
-cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
+# Global variables to store model and processor (initially for 7B)
+vl_gpt = None
+vl_chat_processor = None
+tokenizer = None
+cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
+current_model_path = "deepseek-ai/Janus-Pro-7B"  # Default model
+
+def load_model_components(model_path):
+    global vl_gpt, vl_chat_processor, tokenizer, current_model_path  # Declare current_model_path as global here
+
+    if vl_gpt is not None and current_model_path == model_path:
+        print(f"Using cached model: {model_path}")
+        return vl_gpt, vl_chat_processor, tokenizer
+
+    print(f"Loading model: {model_path}")
+    config = AutoConfig.from_pretrained(model_path)
+    language_config = config.language_config
+    language_config._attn_implementation = 'eager'
+    vl_gpt_local = AutoModelForCausalLM.from_pretrained(model_path,
+                                                        language_config=language_config,
+                                                        trust_remote_code=True)
+    if torch.cuda.is_available():
+        vl_gpt_local = vl_gpt_local.to(torch.bfloat16).cuda()
+    else:
+        vl_gpt_local = vl_gpt_local.to(torch.float16)
+
+    vl_chat_processor_local = VLChatProcessor.from_pretrained(model_path)
+    tokenizer_local = vl_chat_processor_local.tokenizer
+
+    vl_gpt = vl_gpt_local
+    vl_chat_processor = vl_chat_processor_local
+    tokenizer = tokenizer_local
+    current_model_path = model_path
+    print(f"Model loaded: {model_path}")
+    return vl_gpt, vl_chat_processor, tokenizer
 
 @torch.inference_mode()
 # @spaces.GPU(duration=120)
 # Multimodal Understanding function
-def multimodal_understanding(image, question, seed, top_p, temperature):
+def multimodal_understanding(model_name, image, question, seed, top_p, temperature):
+    # Load model based on selection
+    load_model_components(model_name)
+
     # Clear CUDA cache before generating
     torch.cuda.empty_cache()
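
The hunk above swaps the eager module-level load for a lazy loader that caches the last checkpoint by path. Below is a minimal, self-contained sketch of that caching pattern, not code from the PR; DummyModel and load() are hypothetical stand-ins for the real AutoModelForCausalLM / VLChatProcessor loading:

# Sketch of the global-cache pattern used by load_model_components.
# DummyModel and load() are hypothetical stand-ins, not part of the PR.
_cache = {"path": None, "model": None}

class DummyModel:
    def __init__(self, path):
        self.path = path

def load(path):
    if _cache["model"] is not None and _cache["path"] == path:
        print(f"Using cached model: {path}")   # cache hit, no reload
        return _cache["model"]
    print(f"Loading model: {path}")            # real code calls from_pretrained here
    _cache["model"] = DummyModel(path)
    _cache["path"] = path
    return _cache["model"]

m1 = load("deepseek-ai/Janus-Pro-7B")   # loads
m2 = load("deepseek-ai/Janus-Pro-7B")   # reuses m1
m3 = load("deepseek-ai/Janus-Pro-1B")   # different checkpoint, reloads

Switching the Gradio dropdown therefore only pays the reload cost when the selected checkpoint actually changes.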
@@ -114,7 +137,6 @@ def generate(input_ids,
         inputs_embeds = img_embeds.unsqueeze(dim=1)
 
 
-
 
     patches = vl_gpt.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int),
                                                   shape=[parallel_size, 8, width // patch_size, height // patch_size])
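
For orientation, the decode_code shape above implies a fixed token grid per image. Assuming the demo's width = height = 384 and a 16-pixel patch size (patch_size is set elsewhere in the script; its value here is an assumption), each image is decoded from a 24 x 24 = 576-token grid, and parallel_size such grids are decoded per prompt:

# Hypothetical check of the decode_code shape arithmetic.
# patch_size = 16 is assumed from the surrounding demo code, not this hunk.
width = height = 384
patch_size = 16
parallel_size = 5

tokens_per_side = width // patch_size            # 24
tokens_per_image = tokens_per_side ** 2          # 576
decode_shape = [parallel_size, 8, width // patch_size, height // patch_size]
print(tokens_per_image, decode_shape)            # 576 [5, 8, 24, 24]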
@@ -133,10 +155,10 @@ def unpack(dec, width, height, parallel_size=5):
 
 @torch.inference_mode()
 # @spaces.GPU(duration=120)  # Specify a duration to avoid timeout
-def generate_image(prompt,
-                   seed=None,
-                   guidance=5,
-                   t2i_temperature=1.0):
+def generate_image(model_name, prompt, seed, guidance, t2i_temperature, parallel_size_slider):
+    # Load model based on selection
+    load_model_components(model_name)
+
     # Clear CUDA cache and avoid tracking gradients
     torch.cuda.empty_cache()
     # Set the seed for reproducible results
@@ -146,7 +168,7 @@ def generate_image(prompt,
     np.random.seed(seed)
     width = 384
     height = 384
-    parallel_size = 5
+    parallel_size = int(parallel_size_slider)  # Use slider value for parallel_size
 
     with torch.no_grad():
         messages = [{'role': '<|User|>', 'content': prompt},
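
The int() cast matters because the value handed over by the Gradio slider is a plain number that is not guaranteed to be an int, while parallel_size feeds integer shape arithmetic. A small illustrative coercion; the clamp to the slider's [1, 5] range is an assumption added here, not code from the PR:

# Illustrative coercion of a slider value into a usable batch size.
def coerce_parallel_size(slider_value, lo=1, hi=5):
    size = int(slider_value)        # generation code needs an integer
    return max(lo, min(size, hi))   # keep within the slider's assumed range

print(coerce_parallel_size(3.0))    # 3
print(coerce_parallel_size(7))      # 5, clamped to the assumed maximum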
@@ -173,7 +195,14 @@ def generate_image(prompt,
 
 # Gradio interface
 with gr.Blocks() as demo:
-    gr.Markdown(value="# Multimodal Understanding")
+    gr.Markdown(value="# Multimodal Model Demo: Janus-Pro-7B & 1B")
+
+    model_selector = gr.Dropdown(
+        ["deepseek-ai/Janus-Pro-7B", "deepseek-ai/Janus-Pro-1B"],
+        value="deepseek-ai/Janus-Pro-7B", label="Select Model"
+    )
+
+    with gr.Tab("Multimodal Understanding"):
     with gr.Row():
         image_input = gr.Image()
         with gr.Column():
@@ -200,14 +229,11 @@ with gr.Blocks() as demo:
         inputs=[question_input, image_input],
     )
 
-
-    gr.Markdown(value="# Text-to-Image Generation")
-
-
-
+    with gr.Tab("Text-to-Image Generation"):
     with gr.Row():
         cfg_weight_input = gr.Slider(minimum=1, maximum=10, value=5, step=0.5, label="CFG Weight")
         t2i_temperature = gr.Slider(minimum=0, maximum=1, value=1.0, step=0.05, label="temperature")
+        parallel_size_slider = gr.Slider(minimum=1, maximum=5, value=5, step=1, label="Parallel Size")  # New slider
 
     prompt_input = gr.Textbox(label="Prompt. (Prompt in more detail can help produce better images!)")
     seed_input = gr.Number(label="Seed (Optional)", precision=0, value=12345)
@@ -231,13 +257,13 @@ with gr.Blocks() as demo:
 
     understanding_button.click(
         multimodal_understanding,
-        inputs=[image_input, question_input, und_seed_input, top_p, temperature],
+        inputs=[model_selector, image_input, question_input, und_seed_input, top_p, temperature],  # Added model_selector
         outputs=understanding_output
     )
 
     generation_button.click(
         fn=generate_image,
-        inputs=[prompt_input, seed_input, cfg_weight_input, t2i_temperature],
+        inputs=[model_selector, prompt_input, seed_input, cfg_weight_input, t2i_temperature, parallel_size_slider],  # Added model_selector and parallel_size_slider
         outputs=image_output
     )
 
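
The hunks end here, so the launch call at the bottom of the script is not shown. A typical way such a Blocks demo is started; the exact arguments are an assumption, not visible in this diff:

# Assumed launch boilerplate; the real call sits outside the diff hunks.
if __name__ == "__main__":
    demo.launch(share=True)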