From 581fdd1489c1518889afad3abb543c419993d633 Mon Sep 17 00:00:00 2001
From: censujiang
Date: Fri, 31 Jan 2025 03:51:33 +0800
Subject: [PATCH] Optimize CUDA device management; only clear the CUDA cache
 when it is available
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .gitignore           |  5 ++++-
 demo/app_januspro.py | 21 +++++++++++++++------
 2 files changed, 19 insertions(+), 7 deletions(-)

diff --git a/.gitignore b/.gitignore
index 229e98c..1adfe8b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -418,4 +418,7 @@ tags
 [._]*.un~
 .vscode
 .github
-generated_samples/
\ No newline at end of file
+generated_samples/
+
+# gradio
+.gradio/
\ No newline at end of file

diff --git a/demo/app_januspro.py b/demo/app_januspro.py
index 702e58e..ef748d3 100644
--- a/demo/app_januspro.py
+++ b/demo/app_januspro.py
@@ -21,24 +21,30 @@ vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
                                               trust_remote_code=True)
 if torch.cuda.is_available():
     vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
+    cuda_device = 'cuda'
+elif torch.backends.mps.is_available():
+    vl_gpt = vl_gpt.to(torch.float16).to('mps')
+    cuda_device = 'mps'
 else:
     vl_gpt = vl_gpt.to(torch.float16)
+    cuda_device = 'cpu'
 
 vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
 tokenizer = vl_chat_processor.tokenizer
-cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 @torch.inference_mode()
 # @spaces.GPU(duration=120)
 # Multimodal Understanding function
 def multimodal_understanding(image, question, seed, top_p, temperature):
     # Clear CUDA cache before generating
-    torch.cuda.empty_cache()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
 
     # set seed
     torch.manual_seed(seed)
     np.random.seed(seed)
-    torch.cuda.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
 
     conversation = [
         {
@@ -83,7 +89,8 @@ def generate(input_ids,
              image_token_num_per_image: int = 576,
              patch_size: int = 16):
     # Clear CUDA cache before generating
-    torch.cuda.empty_cache()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
 
     tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
     for i in range(parallel_size * 2):
@@ -138,11 +145,13 @@ def generate_image(prompt,
                    guidance=5,
                    t2i_temperature=1.0):
     # Clear CUDA cache and avoid tracking gradients
-    torch.cuda.empty_cache()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
     # Set the seed for reproducible results
     if seed is not None:
         torch.manual_seed(seed)
-        torch.cuda.manual_seed(seed)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed(seed)
     np.random.seed(seed)
     width = 384
     height = 384
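
Reviewer note: the patch repeats the same "if torch.cuda.is_available():" guard in
three places and hand-picks the backend once at startup. Below is a minimal sketch
of how that pattern could be factored into helpers, assuming PyTorch >= 2.0 (for
torch.mps.empty_cache()); the names pick_device, empty_device_cache, and seed_all
are illustrative and not part of this patch:

    import numpy as np
    import torch

    def pick_device() -> str:
        # Same selection order as the patch: CUDA, then Apple MPS, then CPU.
        if torch.cuda.is_available():
            return 'cuda'
        if torch.backends.mps.is_available():
            return 'mps'
        return 'cpu'

    def empty_device_cache(device: str) -> None:
        # Release cached allocator memory only on backends that keep a cache.
        if device == 'cuda':
            torch.cuda.empty_cache()
        elif device == 'mps':
            torch.mps.empty_cache()  # assumption: requires PyTorch >= 2.0

    def seed_all(seed: int, device: str) -> None:
        # Seed the CPU, NumPy, and (when present) CUDA generators, as the patch does.
        torch.manual_seed(seed)
        np.random.seed(seed)
        if device == 'cuda':
            torch.cuda.manual_seed(seed)

With helpers like these, each generation function reduces to
empty_device_cache(cuda_device) followed by seed_all(seed, cuda_device), and the
MPS path additionally gains a cache clear, which the patch as written skips.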