diff --git a/.gitignore b/.gitignore index 229e98c..1adfe8b 100644 --- a/.gitignore +++ b/.gitignore @@ -418,4 +418,7 @@ tags [._]*.un~ .vscode .github -generated_samples/ \ No newline at end of file +generated_samples/ + +# gradio +.gradio/ \ No newline at end of file diff --git a/demo/app_januspro.py b/demo/app_januspro.py index 702e58e..ef748d3 100644 --- a/demo/app_januspro.py +++ b/demo/app_januspro.py @@ -21,24 +21,30 @@ vl_gpt = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True) if torch.cuda.is_available(): vl_gpt = vl_gpt.to(torch.bfloat16).cuda() + cuda_device = 'cuda' +elif torch.backends.mps.is_available(): + vl_gpt = vl_gpt.to(torch.float16).to('mps') + cuda_device = 'mps' else: vl_gpt = vl_gpt.to(torch.float16) + cuda_device = 'cpu' vl_chat_processor = VLChatProcessor.from_pretrained(model_path) tokenizer = vl_chat_processor.tokenizer -cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu' @torch.inference_mode() # @spaces.GPU(duration=120) # Multimodal Understanding function def multimodal_understanding(image, question, seed, top_p, temperature): # Clear CUDA cache before generating - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() # set seed torch.manual_seed(seed) np.random.seed(seed) - torch.cuda.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) conversation = [ { @@ -83,7 +89,8 @@ def generate(input_ids, image_token_num_per_image: int = 576, patch_size: int = 16): # Clear CUDA cache before generating - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device) for i in range(parallel_size * 2): @@ -138,11 +145,13 @@ def generate_image(prompt, guidance=5, t2i_temperature=1.0): # Clear CUDA cache and avoid tracking gradients - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() # Set the seed for reproducible results if seed is not None: torch.manual_seed(seed) - torch.cuda.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) np.random.seed(seed) width = 384 height = 384