mirror of
https://github.com/deepseek-ai/DeepSeek-VL.git
synced 2025-04-19 01:59:13 -04:00
add an util function to detect platflorm and suitable dtype
This commit is contained in:
parent
5db8156747
commit
494e622544
@ -27,6 +27,27 @@ from transformers import AutoModelForCausalLM
|
||||
from deepseek_vl.models import MultiModalityCausalLM, VLChatProcessor
|
||||
|
||||
|
||||
def get_device_and_dtype():
|
||||
"""
|
||||
Get the device and dtype for the model.
|
||||
"""
|
||||
|
||||
if torch.cuda.is_available():
|
||||
print("Using CUDA and BFloat16")
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.bfloat16
|
||||
elif torch.backends.mps.is_available():
|
||||
print("Using MPS and FP16")
|
||||
device = torch.device("mps")
|
||||
dtype = torch.float16
|
||||
else:
|
||||
print("Using CPU and FP32")
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32
|
||||
|
||||
return device, dtype
|
||||
|
||||
|
||||
def load_pretrained_model(model_path: str):
|
||||
vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
|
||||
tokenizer = vl_chat_processor.tokenizer
|
||||
|
Loading…
Reference in New Issue
Block a user