mirror of https://github.com/deepseek-ai/Janus.git
synced 2025-04-19 18:18:57 -04:00
Merge a897652664 into 1daa72fa40
This commit is contained in: commit 4609c4fbda

@@ -11,27 +11,50 @@ import time
 # import spaces # Import spaces for ZeroGPU compatibility
 
-# Load model and processor
-model_path = "deepseek-ai/Janus-Pro-7B"
-config = AutoConfig.from_pretrained(model_path)
-language_config = config.language_config
-language_config._attn_implementation = 'eager'
-vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
-                                             language_config=language_config,
-                                             trust_remote_code=True)
-if torch.cuda.is_available():
-    vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
-else:
-    vl_gpt = vl_gpt.to(torch.float16)
-
-vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
-tokenizer = vl_chat_processor.tokenizer
-cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
+# Global variables to store model and processor (initially for 7B)
+vl_gpt = None
+vl_chat_processor = None
+tokenizer = None
+cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
+current_model_path = "deepseek-ai/Janus-Pro-7B"  # Default model
+
+
+def load_model_components(model_path):
+    global vl_gpt, vl_chat_processor, tokenizer, current_model_path  # Declare current_model_path as global here
+    if vl_gpt is not None and current_model_path == model_path:
+        print(f"Using cached model: {model_path}")
+        return vl_gpt, vl_chat_processor, tokenizer
+
+    print(f"Loading model: {model_path}")
+    config = AutoConfig.from_pretrained(model_path)
+    language_config = config.language_config
+    language_config._attn_implementation = 'eager'
+    vl_gpt_local = AutoModelForCausalLM.from_pretrained(model_path,
+                                                        language_config=language_config,
+                                                        trust_remote_code=True)
+    if torch.cuda.is_available():
+        vl_gpt_local = vl_gpt_local.to(torch.bfloat16).cuda()
+    else:
+        vl_gpt_local = vl_gpt_local.to(torch.float16)
+
+    vl_chat_processor_local = VLChatProcessor.from_pretrained(model_path)
+    tokenizer_local = vl_chat_processor_local.tokenizer
+
+    vl_gpt = vl_gpt_local
+    vl_chat_processor = vl_chat_processor_local
+    tokenizer = tokenizer_local
+    current_model_path = model_path
+    print(f"Model loaded: {model_path}")
+    return vl_gpt, vl_chat_processor, tokenizer
 
 
 @torch.inference_mode()
 # @spaces.GPU(duration=120)
 # Multimodal Understanding function
-def multimodal_understanding(image, question, seed, top_p, temperature):
+def multimodal_understanding(model_name, image, question, seed, top_p, temperature):
+    # Load model based on selection
+    load_model_components(model_name)
+
     # Clear CUDA cache before generating
     torch.cuda.empty_cache()
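
Reviewer note (not part of the patch): a minimal sketch of how the caching in the new load_model_components is expected to behave, given the globals initialized above. The call sequence is illustrative only.

    # Illustrative call sequence; the paths are the two checkpoints offered in the UI.
    load_model_components("deepseek-ai/Janus-Pro-7B")   # first call: loads 7B and fills the globals
    load_model_components("deepseek-ai/Janus-Pro-7B")   # same path: prints "Using cached model", returns cached objects
    load_model_components("deepseek-ai/Janus-Pro-1B")   # different path: loads 1B and overwrites the cached components

Note that when switching checkpoints the new weights are loaded into the *_local variables before the old globals are reassigned, so both models are briefly resident in memory.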
@@ -114,7 +137,6 @@ def generate(input_ids,
         inputs_embeds = img_embeds.unsqueeze(dim=1)
 
 
-
 
     patches = vl_gpt.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int),
                                                   shape=[parallel_size, 8, width // patch_size, height // patch_size])
@@ -133,10 +155,10 @@ def unpack(dec, width, height, parallel_size=5):
 
 @torch.inference_mode()
 # @spaces.GPU(duration=120)  # Specify a duration to avoid timeout
-def generate_image(prompt,
-                   seed=None,
-                   guidance=5,
-                   t2i_temperature=1.0):
+def generate_image(model_name, prompt, seed, guidance, t2i_temperature, parallel_size_slider):
+    # Load model based on selection
+    load_model_components(model_name)
+
     # Clear CUDA cache and avoid tracking gradients
     torch.cuda.empty_cache()
     # Set the seed for reproducible results
@@ -146,7 +168,7 @@ def generate_image(prompt,
     np.random.seed(seed)
     width = 384
     height = 384
-    parallel_size = 5
+    parallel_size = int(parallel_size_slider)  # Use slider value for parallel_size
 
     with torch.no_grad():
         messages = [{'role': '<|User|>', 'content': prompt},
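
For context on the new slider, a sketch (not part of the patch) of how parallel_size propagates: it is the leading dimension of the shape passed to decode_code in the earlier hunk, i.e. the number of images decoded per call. The patch_size value of 16 is an assumption about the decoder configuration.

    # Illustrative arithmetic only; patch_size = 16 is assumed, not taken from the patch.
    width = height = 384
    patch_size = 16
    parallel_size = int(3)  # e.g. slider set to 3
    shape = [parallel_size, 8, width // patch_size, height // patch_size]
    print(shape)  # [3, 8, 24, 24] -> three 384x384 images decoded in one batch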
@@ -173,7 +195,14 @@ def generate_image(prompt,
 
 # Gradio interface
 with gr.Blocks() as demo:
-    gr.Markdown(value="# Multimodal Understanding")
-    with gr.Row():
-        image_input = gr.Image()
-        with gr.Column():
+    gr.Markdown(value="# Multimodal Model Demo: Janus-Pro-7B & 1B")
+
+    model_selector = gr.Dropdown(
+        ["deepseek-ai/Janus-Pro-7B", "deepseek-ai/Janus-Pro-1B"],
+        value="deepseek-ai/Janus-Pro-7B", label="Select Model"
+    )
+
+    with gr.Tab("Multimodal Understanding"):
+        with gr.Row():
+            image_input = gr.Image()
+            with gr.Column():
@@ -200,14 +229,11 @@ with gr.Blocks() as demo:
         inputs=[question_input, image_input],
     )
 
-    gr.Markdown(value="# Text-to-Image Generation")
-
-
-
-    with gr.Row():
-        cfg_weight_input = gr.Slider(minimum=1, maximum=10, value=5, step=0.5, label="CFG Weight")
-        t2i_temperature = gr.Slider(minimum=0, maximum=1, value=1.0, step=0.05, label="temperature")
+    with gr.Tab("Text-to-Image Generation"):
+        with gr.Row():
+            cfg_weight_input = gr.Slider(minimum=1, maximum=10, value=5, step=0.5, label="CFG Weight")
+            t2i_temperature = gr.Slider(minimum=0, maximum=1, value=1.0, step=0.05, label="temperature")
+            parallel_size_slider = gr.Slider(minimum=1, maximum=5, value=5, step=1, label="Parallel Size")  # New slider
 
-    prompt_input = gr.Textbox(label="Prompt. (Prompt in more detail can help produce better images!)")
-    seed_input = gr.Number(label="Seed (Optional)", precision=0, value=12345)
+        prompt_input = gr.Textbox(label="Prompt. (Prompt in more detail can help produce better images!)")
+        seed_input = gr.Number(label="Seed (Optional)", precision=0, value=12345)
@@ -231,13 +257,13 @@ with gr.Blocks() as demo:
 
     understanding_button.click(
         multimodal_understanding,
-        inputs=[image_input, question_input, und_seed_input, top_p, temperature],
+        inputs=[model_selector, image_input, question_input, und_seed_input, top_p, temperature],  # Added model_selector
         outputs=understanding_output
     )
 
     generation_button.click(
         fn=generate_image,
-        inputs=[prompt_input, seed_input, cfg_weight_input, t2i_temperature],
+        inputs=[model_selector, prompt_input, seed_input, cfg_weight_input, t2i_temperature, parallel_size_slider],  # Added model_selector and parallel_size_slider
         outputs=image_output
     )
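
A standalone sketch (assumed names, not part of the patch) of the Gradio wiring these hunks rely on: each component listed in inputs=[...] is passed positionally to the handler, so model_selector must stay first to line up with the new model_name parameters of multimodal_understanding and generate_image.

    import gradio as gr

    def handler(model_name, prompt):
        # model_name receives the dropdown's current value because it is listed first in inputs=[...]
        return f"{model_name} would be loaded for: {prompt}"

    with gr.Blocks() as sketch:
        model_selector = gr.Dropdown(["deepseek-ai/Janus-Pro-7B", "deepseek-ai/Janus-Pro-1B"],
                                     value="deepseek-ai/Janus-Pro-7B", label="Select Model")
        prompt_box = gr.Textbox(label="Prompt")
        out = gr.Textbox(label="Output")
        gr.Button("Run").click(fn=handler, inputs=[model_selector, prompt_box], outputs=out)

The same ordering rule explains the updated inputs lists above: adding model_selector (and parallel_size_slider for generation) at the positions matching the new signatures keeps the existing arguments aligned.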