mirror of
https://github.com/deepseek-ai/Janus.git
synced 2025-02-22 13:48:57 -05:00
更新CUDA设备管理逻辑,支持MPS设备
This commit is contained in:
parent
581fdd1489
commit
eb441468d8
@ -19,7 +19,7 @@ vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
|
||||
|
||||
vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
|
||||
tokenizer = vl_chat_processor.tokenizer
|
||||
cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
cuda_device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
|
||||
# Multimodal Understanding function
|
||||
@torch.inference_mode()
|
||||
# Multimodal Understanding function
|
||||
|
@ -5,7 +5,13 @@ from PIL import Image
|
||||
from diffusers.models import AutoencoderKL
|
||||
import numpy as np
|
||||
|
||||
cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
cuda_device = 'cpu'
|
||||
if torch.cuda.is_available():
|
||||
cuda_device = 'cuda'
|
||||
elif torch.backends.mps.is_available():
|
||||
cuda_device = 'mps'
|
||||
else:
|
||||
cuda_device = 'cpu'
|
||||
|
||||
# Load model and processor
|
||||
model_path = "deepseek-ai/JanusFlow-1.3B"
|
||||
|
@ -21,7 +21,7 @@ vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
|
||||
|
||||
vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
|
||||
tokenizer = vl_chat_processor.tokenizer
|
||||
cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
cuda_device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
|
Loading…
Reference in New Issue
Block a user