更新CUDA设备管理逻辑,支持MPS设备

This commit is contained in:
censujiang 2025-01-31 04:01:37 +08:00
parent 581fdd1489
commit eb441468d8
3 changed files with 9 additions and 3 deletions

View File

@ -19,7 +19,7 @@ vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
cuda_device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
# Multimodal Understanding function
@torch.inference_mode()
# Multimodal Understanding function

View File

@ -5,7 +5,13 @@ from PIL import Image
from diffusers.models import AutoencoderKL
import numpy as np
cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
cuda_device = 'cpu'
if torch.cuda.is_available():
cuda_device = 'cuda'
elif torch.backends.mps.is_available():
cuda_device = 'mps'
else:
cuda_device = 'cpu'
# Load model and processor
model_path = "deepseek-ai/JanusFlow-1.3B"

View File

@ -21,7 +21,7 @@ vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
cuda_device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
@torch.inference_mode()