mirror of
https://github.com/deepseek-ai/Janus.git
synced 2025-04-20 02:28:58 -04:00
fix: image generating in demos
This commit is contained in:
parent
c1876efe83
commit
17bc41245a
@ -9,9 +9,9 @@ import numpy as np
|
|||||||
# Device and dtype configuration
|
# Device and dtype configuration
|
||||||
def get_device_and_dtype():
|
def get_device_and_dtype():
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
return 'cuda', torch.bfloat16
|
return 'cuda', torch.float32
|
||||||
elif torch.backends.mps.is_available():
|
elif torch.backends.mps.is_available():
|
||||||
return 'mps', torch.float16
|
return 'mps', torch.float32
|
||||||
return 'cpu', torch.float32
|
return 'cpu', torch.float32
|
||||||
|
|
||||||
device, dtype = get_device_and_dtype()
|
device, dtype = get_device_and_dtype()
|
||||||
|
@ -15,10 +15,10 @@ if torch.cuda.is_available():
|
|||||||
dtype = torch.bfloat16
|
dtype = torch.bfloat16
|
||||||
elif torch.backends.mps.is_available():
|
elif torch.backends.mps.is_available():
|
||||||
device = 'mps'
|
device = 'mps'
|
||||||
dtype = torch.float16
|
dtype = torch.float32 # MPS设备使用float32
|
||||||
else:
|
else:
|
||||||
device = 'cpu'
|
device = 'cpu'
|
||||||
dtype = torch.float16
|
dtype = torch.float32 # CPU设备使用float32
|
||||||
|
|
||||||
# Load model and processor
|
# Load model and processor
|
||||||
model_path = "deepseek-ai/Janus-Pro-7B"
|
model_path = "deepseek-ai/Janus-Pro-7B"
|
||||||
|
Loading…
Reference in New Issue
Block a user