Mirror of https://github.com/deepseek-ai/Janus.git (synced 2025-02-23 06:08:59 -05:00)
Update CUDA device management logic to support MPS devices
This commit is contained in:
parent 581fdd1489
commit eb441468d8
@@ -19,7 +19,7 @@ vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
 
 vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
 tokenizer = vl_chat_processor.tokenizer
-cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
+cuda_device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
 # Multimodal Understanding function
 @torch.inference_mode()
 # Multimodal Understanding function
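The one-line change above extends the device fallback from CUDA-or-CPU to CUDA, then Apple-silicon MPS, then CPU. A minimal, self-contained sketch of the same chained selection (the tensor x is illustrative and not part of the demo; it only shows the selected string being consumed by .to()):

import torch

# Prefer CUDA, then the MPS backend (Apple silicon), then fall back to CPU.
# torch.backends.mps is only exposed on PyTorch >= 1.12.
cuda_device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'

# Illustrative use of the selected device string (not from the demo code).
x = torch.randn(2, 3).to(cuda_device)
print(f"tensor placed on: {x.device}")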
@@ -5,7 +5,13 @@ from PIL import Image
 from diffusers.models import AutoencoderKL
 import numpy as np
 
-cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
+cuda_device = 'cpu'
+if torch.cuda.is_available():
+    cuda_device = 'cuda'
+elif torch.backends.mps.is_available():
+    cuda_device = 'mps'
+else:
+    cuda_device = 'cpu'
 
 # Load model and processor
 model_path = "deepseek-ai/JanusFlow-1.3B"
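This second hunk makes the same selection in the JanusFlow demo, but as an explicit if/elif/else block (the final else branch is redundant with the 'cpu' initialiser, though harmless). A sketch of the equivalent logic wrapped in a helper; the function name pick_device is hypothetical and does not appear in the repository:

import torch

def pick_device() -> str:
    # Mirrors the added block: CUDA first, then MPS, otherwise CPU.
    if torch.cuda.is_available():
        return 'cuda'
    elif torch.backends.mps.is_available():
        return 'mps'  # Apple-silicon GPU via Metal Performance Shaders
    else:
        return 'cpu'

cuda_device = pick_device()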
@@ -21,7 +21,7 @@ vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
 
 vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
 tokenizer = vl_chat_processor.tokenizer
-cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
+cuda_device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
 
 
 @torch.inference_mode()
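Note that the hunk headers suggest the model is still moved with vl_gpt.to(torch.bfloat16).cuda() elsewhere in these files; if that call runs on a machine where only MPS is available, it fails, since this commit only changes the cuda_device string. A device-agnostic placement sketch under that assumption (the dtype switch and the Linear module are illustrative, not taken from the diff):

import torch

cuda_device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'

# Assumption for illustration: keep bfloat16 on CUDA, use float32 elsewhere.
dtype = torch.bfloat16 if cuda_device == 'cuda' else torch.float32
model = torch.nn.Linear(8, 8).to(dtype).to(cuda_device)
inputs = torch.randn(1, 8, dtype=dtype, device=cuda_device)
outputs = model(inputs)
print(outputs.shape, outputs.device)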