mirror of
https://github.com/deepseek-ai/DeepSeek-VL.git
synced 2025-04-19 10:09:09 -04:00
add a utility function to detect platform and suitable dtype
This commit is contained in:
parent
5db8156747
commit
494e622544
@ -27,6 +27,27 @@ from transformers import AutoModelForCausalLM
|
|||||||
from deepseek_vl.models import MultiModalityCausalLM, VLChatProcessor
|
from deepseek_vl.models import MultiModalityCausalLM, VLChatProcessor
|
||||||
|
|
||||||
|
|
||||||
|
def get_device_and_dtype():
    """
    Pick the best available accelerator and a matching dtype.

    Preference order: CUDA (bfloat16) > Apple MPS (float16) > CPU (float32).
    Prints which backend was selected as a side effect.

    Returns:
        tuple[torch.device, torch.dtype]: the selected device and dtype.
    """
    if torch.cuda.is_available():
        print("Using CUDA and BFloat16")
        return torch.device("cuda"), torch.bfloat16
    if torch.backends.mps.is_available():
        print("Using MPS and FP16")
        return torch.device("mps"), torch.float16
    print("Using CPU and FP32")
    return torch.device("cpu"), torch.float32
|
||||||
|
|
||||||
|
|
||||||
def load_pretrained_model(model_path: str):
|
def load_pretrained_model(model_path: str):
|
||||||
vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
|
vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
|
||||||
tokenizer = vl_chat_processor.tokenizer
|
tokenizer = vl_chat_processor.tokenizer
|
||||||
|
Loading…
Reference in New Issue
Block a user