diff --git a/demo/app.py b/demo/app.py
index 7dbc59f..3a2f91a 100644
--- a/demo/app.py
+++ b/demo/app.py
@@ -6,6 +6,15 @@
 from PIL import Image
 import numpy as np
 
+# Device and dtype configuration
+def get_device_and_dtype():
+    if torch.cuda.is_available():
+        return 'cuda', torch.bfloat16  # CUDA devices use bfloat16
+    elif torch.backends.mps.is_available():
+        return 'mps', torch.float32
+    return 'cpu', torch.float32
+
+device, dtype = get_device_and_dtype()
 
 # Load model and processor
 model_path = "deepseek-ai/Janus-1.3B"
@@ -15,22 +24,25 @@
 language_config._attn_implementation = 'eager'
 vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
                                               language_config=language_config,
                                               trust_remote_code=True)
-vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
+vl_gpt = vl_gpt.to(dtype).to(device)
 
 vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
 tokenizer = vl_chat_processor.tokenizer
-cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
 
 # Multimodal Understanding function
 @torch.inference_mode()
-# Multimodal Understanding function
 def multimodal_understanding(image, question, seed, top_p, temperature):
-    # Clear CUDA cache before generating
-    torch.cuda.empty_cache()
+    # Clear CUDA cache if using CUDA
+    if device == 'cuda':
+        torch.cuda.empty_cache()
 
     # set seed
     torch.manual_seed(seed)
     np.random.seed(seed)
-    torch.cuda.manual_seed(seed)
+    if device == 'cuda':
+        torch.cuda.manual_seed(seed)
+    elif device == 'mps':
+        torch.mps.manual_seed(seed)
 
     conversation = [
         {
@@ -44,8 +56,7 @@ def multimodal_understanding(image, question, seed, top_p, temperature):
     pil_images = [Image.fromarray(image)]
     prepare_inputs = vl_chat_processor(
         conversations=conversation, images=pil_images, force_batchify=True
-    ).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)
-
+    ).to(device, dtype=dtype)
 
     inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
 
@@ -74,16 +85,17 @@ def generate(input_ids,
              cfg_weight: float = 5,
              image_token_num_per_image: int = 576,
              patch_size: int = 16):
-    # Clear CUDA cache before generating
-    torch.cuda.empty_cache()
+    # Clear CUDA cache if using CUDA
+    if device == 'cuda':
+        torch.cuda.empty_cache()
 
-    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
+    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(device)
     for i in range(parallel_size * 2):
         tokens[i, :] = input_ids
         if i % 2 != 0:
             tokens[i, 1:-1] = vl_chat_processor.pad_id
     inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
-    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(cuda_device)
+    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(device)
 
     pkv = None
     for i in range(image_token_num_per_image):
@@ -123,11 +135,15 @@ def generate_image(prompt,
                    seed=None,
                    guidance=5):
     # Clear CUDA cache and avoid tracking gradients
-    torch.cuda.empty_cache()
+    if device == 'cuda':
+        torch.cuda.empty_cache()
     # Set the seed for reproducible results
     if seed is not None:
         torch.manual_seed(seed)
-        torch.cuda.manual_seed(seed)
+        if device == 'cuda':
+            torch.cuda.manual_seed(seed)
+        elif device == 'mps':
+            torch.mps.manual_seed(seed)
         np.random.seed(seed)
     width = 384
     height = 384
diff --git a/demo/app_janusflow.py b/demo/app_janusflow.py
index 4777196..2975c62 100644
--- a/demo/app_janusflow.py
+++ b/demo/app_janusflow.py
@@ -5,7 +5,16 @@
 from PIL import Image
 from diffusers.models import AutoencoderKL
 import numpy as np
-cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
+# Device selection logic
+if torch.cuda.is_available():
+    device = 'cuda'
+    dtype = torch.bfloat16
+elif torch.backends.mps.is_available():
+    device = 'mps'
+    dtype = torch.float32  # MPS devices use float32
+else:
+    device = 'cpu'
+    dtype = torch.float32
 
 # Load model and processor
 model_path = "deepseek-ai/JanusFlow-1.3B"
@@ -13,23 +22,24 @@
 vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
 tokenizer = vl_chat_processor.tokenizer
 vl_gpt = MultiModalityCausalLM.from_pretrained(model_path)
-vl_gpt = vl_gpt.to(torch.bfloat16).to(cuda_device).eval()
+vl_gpt = vl_gpt.to(dtype).to(device).eval()
 
 # remember to use bfloat16 dtype, this vae doesn't work with fp16
 vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")
-vae = vae.to(torch.bfloat16).to(cuda_device).eval()
+vae = vae.to(dtype).to(device).eval()
 
 # Multimodal Understanding function
 @torch.inference_mode()
-# Multimodal Understanding function
 def multimodal_understanding(image, question, seed, top_p, temperature):
     # Clear CUDA cache before generating
-    torch.cuda.empty_cache()
+    if device == 'cuda':
+        torch.cuda.empty_cache()
 
     # set seed
     torch.manual_seed(seed)
     np.random.seed(seed)
-    torch.cuda.manual_seed(seed)
+    if device == 'cuda':
+        torch.cuda.manual_seed(seed)
 
     conversation = [
         {
@@ -43,8 +53,7 @@ def multimodal_understanding(image, question, seed, top_p, temperature):
     pil_images = [Image.fromarray(image)]
     prepare_inputs = vl_chat_processor(
         conversations=conversation, images=pil_images, force_batchify=True
-    ).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)
-
+    ).to(device, dtype=dtype)
 
     inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
 
@@ -73,7 +82,7 @@ def generate(
     num_inference_steps: int = 30
 ):
     # we generate 5 images at a time, *2 for CFG
-    tokens = torch.stack([input_ids] * 10).cuda()
+    tokens = torch.stack([input_ids] * 10).to(device)
     tokens[5:, 1:] = vl_chat_processor.pad_id
     inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
     print(inputs_embeds.shape)
@@ -83,13 +92,13 @@ def generate(
 
     # generate with rectified flow ode
     # step 1: encode with vision_gen_enc
-    z = torch.randn((5, 4, 48, 48), dtype=torch.bfloat16).cuda()
+    z = torch.randn((5, 4, 48, 48), dtype=dtype).to(device)
 
     dt = 1.0 / num_inference_steps
-    dt = torch.zeros_like(z).cuda().to(torch.bfloat16) + dt
+    dt = torch.zeros_like(z).to(device).to(dtype) + dt
 
     # step 2: run ode
-    attention_mask = torch.ones((10, inputs_embeds.shape[1]+577)).to(vl_gpt.device)
+    attention_mask = torch.ones((10, inputs_embeds.shape[1]+577)).to(device)
     attention_mask[5:, 1:inputs_embeds.shape[1]] = 0
     attention_mask = attention_mask.int()
     for step in range(num_inference_steps):
@@ -108,8 +117,7 @@ def generate(
         if step == 0:
             outputs = vl_gpt.language_model.model(inputs_embeds=llm_emb,
                                                   use_cache=True,
-                                                  attention_mask=attention_mask,
-                                                  past_key_values=None)
+                                                  attention_mask=attention_mask)
             past_key_values = []
             for kv_cache in past_key_values:
                 k, v = kv_cache[0], kv_cache[1]
@@ -118,8 +126,7 @@ def generate(
         else:
             outputs = vl_gpt.language_model.model(inputs_embeds=llm_emb,
                                                   use_cache=True,
-                                                  attention_mask=attention_mask,
-                                                  past_key_values=past_key_values)
+                                                  attention_mask=attention_mask)
         hidden_states = outputs.last_hidden_state
 
         # transform hidden_states back to v
@@ -153,12 +160,14 @@ def generate_image(prompt,
                    seed=None,
                    guidance=5,
                    num_inference_steps=30):
-    # Clear CUDA cache and avoid tracking gradients
-    torch.cuda.empty_cache()
+    # Clear CUDA cache if using CUDA device
+    if device == 'cuda':
+        torch.cuda.empty_cache()
     # Set the seed for reproducible results
     if seed is not None:
         torch.manual_seed(seed)
-        torch.cuda.manual_seed(seed)
+        if device == 'cuda':
+            torch.cuda.manual_seed(seed)
         np.random.seed(seed)
 
     with torch.no_grad():
diff --git a/demo/app_januspro.py b/demo/app_januspro.py
index 702e58e..1e6c65f 100644
--- a/demo/app_januspro.py
+++ b/demo/app_januspro.py
@@ -8,8 +8,17 @@
 from PIL import Image
 import numpy as np
 import os
 import time
-# import spaces  # Import spaces for ZeroGPU compatibility
+# Device and dtype configuration
+if torch.cuda.is_available():
+    device = 'cuda'
+    dtype = torch.bfloat16
+elif torch.backends.mps.is_available():
+    device = 'mps'
+    dtype = torch.float32  # MPS devices use float32
+else:
+    device = 'cpu'
+    dtype = torch.float32  # CPU devices use float32
 
 # Load model and processor
 model_path = "deepseek-ai/Janus-Pro-7B"
@@ -19,26 +28,26 @@
 language_config._attn_implementation = 'eager'
 vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
                                               language_config=language_config,
                                               trust_remote_code=True)
-if torch.cuda.is_available():
-    vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
-else:
-    vl_gpt = vl_gpt.to(torch.float16)
+vl_gpt = vl_gpt.to(dtype).to(device)
 
 vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
 tokenizer = vl_chat_processor.tokenizer
-cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 @torch.inference_mode()
 # @spaces.GPU(duration=120)
 # Multimodal Understanding function
 def multimodal_understanding(image, question, seed, top_p, temperature):
-    # Clear CUDA cache before generating
-    torch.cuda.empty_cache()
+    # Clear device cache
+    if device == 'cuda':
+        torch.cuda.empty_cache()
+    elif device == 'mps':
+        torch.mps.empty_cache()
 
     # set seed
     torch.manual_seed(seed)
     np.random.seed(seed)
-    torch.cuda.manual_seed(seed)
+    if device == 'cuda':
+        torch.cuda.manual_seed(seed)
 
     conversation = [
         {
@@ -52,8 +61,7 @@ def multimodal_understanding(image, question, seed, top_p, temperature):
     pil_images = [Image.fromarray(image)]
     prepare_inputs = vl_chat_processor(
         conversations=conversation, images=pil_images, force_batchify=True
-    ).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)
-
+    ).to(device, dtype=dtype)
 
     inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
 
@@ -82,16 +90,19 @@ def generate(input_ids,
             cfg_weight: float = 5,
             image_token_num_per_image: int = 576,
             patch_size: int = 16):
-    # Clear CUDA cache before generating
-    torch.cuda.empty_cache()
+    # Clear device cache
+    if device == 'cuda':
+        torch.cuda.empty_cache()
+    elif device == 'mps':
+        torch.mps.empty_cache()
 
-    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
+    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(device)
     for i in range(parallel_size * 2):
         tokens[i, :] = input_ids
         if i % 2 != 0:
             tokens[i, 1:-1] = vl_chat_processor.pad_id
     inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
-    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(cuda_device)
+    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(device)
 
     pkv = None
     for i in range(image_token_num_per_image):
@@ -133,17 +144,23 @@ def unpack(dec, width, height, parallel_size=5):
 
 @torch.inference_mode()
 # @spaces.GPU(duration=120)  # Specify a duration to avoid timeout
 def generate_image(prompt,
                    seed=None,
                    guidance=5,
                    t2i_temperature=1.0):
-    # Clear CUDA cache and avoid tracking gradients
-    torch.cuda.empty_cache()
+    # Clear device cache
+    if device == 'cuda':
+        torch.cuda.empty_cache()
+    elif device == 'mps':
+        torch.mps.empty_cache()
+
     # Set the seed for reproducible results
     if seed is not None:
         torch.manual_seed(seed)
-        torch.cuda.manual_seed(seed)
         np.random.seed(seed)
+        if device == 'cuda':
+            torch.cuda.manual_seed(seed)
+
     width = 384
     height = 384
     parallel_size = 5
diff --git a/demo/fastapi_app.py b/demo/fastapi_app.py
index c2e5710..6833d85 100644
--- a/demo/fastapi_app.py
+++ b/demo/fastapi_app.py
@@ -9,6 +9,16 @@
 import io
 
 app = FastAPI()
 
+# Device and dtype configuration
+def get_device_and_dtype():
+    if torch.cuda.is_available():
+        return 'cuda', torch.bfloat16  # CUDA devices use bfloat16
+    elif torch.backends.mps.is_available():
+        return 'mps', torch.float32
+    return 'cpu', torch.float32
+
+device, dtype = get_device_and_dtype()
+
 # Load model and processor
 model_path = "deepseek-ai/Janus-1.3B"
 config = AutoConfig.from_pretrained(model_path)
@@ -17,19 +27,25 @@
 language_config._attn_implementation = 'eager'
 vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
                                               language_config=language_config,
                                               trust_remote_code=True)
-vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
+vl_gpt = vl_gpt.to(dtype).to(device)
 
 vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
 tokenizer = vl_chat_processor.tokenizer
-cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 @torch.inference_mode()
 def multimodal_understanding(image_data, question, seed, top_p, temperature):
-    torch.cuda.empty_cache()
+    # Clear CUDA cache if using CUDA
+    if device == 'cuda':
+        torch.cuda.empty_cache()
+
+    # set seed
     torch.manual_seed(seed)
     np.random.seed(seed)
-    torch.cuda.manual_seed(seed)
+    if device == 'cuda':
+        torch.cuda.manual_seed(seed)
+    elif device == 'mps':
+        torch.mps.manual_seed(seed)
 
     conversation = [
         {
@@ -43,7 +59,7 @@ def multimodal_understanding(image_data, question, seed, top_p, temperature):
     pil_images = [Image.open(io.BytesIO(image_data))]
     prepare_inputs = vl_chat_processor(
         conversations=conversation, images=pil_images, force_batchify=True
-    ).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)
+    ).to(device, dtype=dtype)
 
     inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
     outputs = vl_gpt.language_model.generate(
@@ -84,36 +100,39 @@ def generate(input_ids,
             cfg_weight: float = 5,
             image_token_num_per_image: int = 576,
             patch_size: int = 16):
-    torch.cuda.empty_cache()
-    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
-    for i in range(parallel_size * 2):
-        tokens[i, :] = input_ids
-        if i % 2 != 0:
-            tokens[i, 1:-1] = vl_chat_processor.pad_id
-    inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
-    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(cuda_device)
+    try:
+        torch.cuda.empty_cache() if device == 'cuda' else None
+        tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(device)
+        for i in range(parallel_size * 2):
+            tokens[i, :] = input_ids
+            if i % 2 != 0:
+                tokens[i, 1:-1] = vl_chat_processor.pad_id
+        inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
+        generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(device)
 
-    pkv = None
-    for i in range(image_token_num_per_image):
-        outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True, past_key_values=pkv)
-        pkv = outputs.past_key_values
-        hidden_states = outputs.last_hidden_state
-        logits = vl_gpt.gen_head(hidden_states[:, -1, :])
-        logit_cond = logits[0::2, :]
-        logit_uncond = logits[1::2, :]
-        logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
-        probs = torch.softmax(logits / temperature, dim=-1)
-        next_token = torch.multinomial(probs, num_samples=1)
-        generated_tokens[:, i] = next_token.squeeze(dim=-1)
-        next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
-        img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
-        inputs_embeds = img_embeds.unsqueeze(dim=1)
-    patches = vl_gpt.gen_vision_model.decode_code(
-        generated_tokens.to(dtype=torch.int),
-        shape=[parallel_size, 8, width // patch_size, height // patch_size]
-    )
+        pkv = None
+        for i in range(image_token_num_per_image):
+            outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True, past_key_values=pkv)
+            pkv = outputs.past_key_values
+            hidden_states = outputs.last_hidden_state
+            logits = vl_gpt.gen_head(hidden_states[:, -1, :])
+            logit_cond = logits[0::2, :]
+            logit_uncond = logits[1::2, :]
+            logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
+            probs = torch.softmax(logits / temperature, dim=-1)
+            next_token = torch.multinomial(probs, num_samples=1)
+            generated_tokens[:, i] = next_token.squeeze(dim=-1)
+            next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
+            img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
+            inputs_embeds = img_embeds.unsqueeze(dim=1)
+        patches = vl_gpt.gen_vision_model.decode_code(
+            generated_tokens.to(dtype=torch.int),
+            shape=[parallel_size, 8, width // patch_size, height // patch_size]
+        )
 
-    return generated_tokens.to(dtype=torch.int), patches
+        return generated_tokens.to(dtype=torch.int), patches
+    except Exception as e:
+        raise Exception(f"Error in generate function: {str(e)}")
 
 
 def unpack(dec, width, height, parallel_size=5):
@@ -128,28 +147,37 @@
 
 @torch.inference_mode()
 def generate_image(prompt, seed, guidance):
-    torch.cuda.empty_cache()
-    seed = seed if seed is not None else 12345
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
-    np.random.seed(seed)
-    width = 384
-    height = 384
-    parallel_size = 5
-
-    with torch.no_grad():
-        messages = [{'role': 'User', 'content': prompt}, {'role': 'Assistant', 'content': ''}]
-        text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
-            conversations=messages,
-            sft_format=vl_chat_processor.sft_format,
-            system_prompt=''
-        )
-        text = text + vl_chat_processor.image_start_tag
-        input_ids = torch.LongTensor(tokenizer.encode(text))
-        _, patches = generate(input_ids, width // 16 * 16, height // 16 * 16, cfg_weight=guidance, parallel_size=parallel_size)
-        images = unpack(patches, width // 16 * 16, height // 16 * 16)
+    try:
+        # Clear CUDA cache if using CUDA
+        if device == 'cuda':
+            torch.cuda.empty_cache()
+        # Set the seed for reproducible results
+        if seed is not None:
+            torch.manual_seed(seed)
+            if device == 'cuda':
+                torch.cuda.manual_seed(seed)
+            elif device == 'mps':
+                torch.mps.manual_seed(seed)
+            np.random.seed(seed)
+        width = 384
+        height = 384
+        parallel_size = 5
+
+        with torch.no_grad():
+            messages = [{'role': 'User', 'content': prompt}, {'role': 'Assistant', 'content': ''}]
+            text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
+                conversations=messages,
+                sft_format=vl_chat_processor.sft_format,
+                system_prompt=''
+            )
+            text = text + vl_chat_processor.image_start_tag
+            input_ids = torch.LongTensor(tokenizer.encode(text))
+            _, patches = generate(input_ids, width // 16 * 16, height // 16 * 16, cfg_weight=guidance, parallel_size=parallel_size)
+            images = unpack(patches, width // 16 * 16, height // 16 * 16)
 
-    return [Image.fromarray(images[i]).resize((1024, 1024), Image.LANCZOS) for i in range(parallel_size)]
+        return [Image.fromarray(images[i]).resize((1024, 1024), Image.LANCZOS) for i in range(parallel_size)]
+    except Exception as e:
+        raise Exception(f"Error in generate_image function: {str(e)}")
 
 
 @app.post("/generate_images/")