From eb441468d8bd1d5a37017041464ddf57fba4ddd2 Mon Sep 17 00:00:00 2001
From: censujiang
Date: Fri, 31 Jan 2025 04:01:37 +0800
Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0CUDA=E8=AE=BE=E5=A4=87?=
 =?UTF-8?q?=E7=AE=A1=E7=90=86=E9=80=BB=E8=BE=91=EF=BC=8C=E6=94=AF=E6=8C=81?=
 =?UTF-8?q?MPS=E8=AE=BE=E5=A4=87?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 demo/app.py           | 2 +-
 demo/app_janusflow.py | 8 +++++++-
 demo/fastapi_app.py   | 2 +-
 3 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/demo/app.py b/demo/app.py
index 7dbc59f..b0918f1 100644
--- a/demo/app.py
+++ b/demo/app.py
@@ -19,7 +19,7 @@ vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
 vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
 tokenizer = vl_chat_processor.tokenizer
 
-cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
+cuda_device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
 # Multimodal Understanding function
 @torch.inference_mode()
 # Multimodal Understanding function
diff --git a/demo/app_janusflow.py b/demo/app_janusflow.py
index 4777196..01f3b70 100644
--- a/demo/app_janusflow.py
+++ b/demo/app_janusflow.py
@@ -5,7 +5,13 @@ from PIL import Image
 from diffusers.models import AutoencoderKL
 import numpy as np
 
-cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
+cuda_device = 'cpu'
+if torch.cuda.is_available():
+    cuda_device = 'cuda'
+elif torch.backends.mps.is_available():
+    cuda_device = 'mps'
+else:
+    cuda_device = 'cpu'
 
 # Load model and processor
 model_path = "deepseek-ai/JanusFlow-1.3B"
diff --git a/demo/fastapi_app.py b/demo/fastapi_app.py
index c2e5710..a789b61 100644
--- a/demo/fastapi_app.py
+++ b/demo/fastapi_app.py
@@ -21,7 +21,7 @@ vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
 vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
 tokenizer = vl_chat_processor.tokenizer
 
-cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
+cuda_device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
 
 @torch.inference_mode()