Add a utility function to detect the platform and a suitable dtype

This commit is contained in:
Nicola Dall'Asen 2024-03-13 12:31:28 +01:00
parent 5db8156747
commit 494e622544

View File

@ -27,6 +27,27 @@ from transformers import AutoModelForCausalLM
from deepseek_vl.models import MultiModalityCausalLM, VLChatProcessor
def get_device_and_dtype() -> "tuple[torch.device, torch.dtype]":
    """Pick the compute device and the dtype paired with it.

    Backends are probed in a fixed preference order and the first one
    available wins:

    1. CUDA -> ``torch.bfloat16``
    2. MPS  -> ``torch.float16``
    3. CPU  -> ``torch.float32``

    The selection is also printed to stdout.

    Returns:
        A ``(device, dtype)`` pair.
    """
    if torch.cuda.is_available():
        print("Using CUDA and BFloat16")
        return torch.device("cuda"), torch.bfloat16

    # Older PyTorch builds do not ship ``torch.backends.mps`` at all; look it
    # up defensively so those installs fall through to the CPU branch instead
    # of raising AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        print("Using MPS and FP16")
        return torch.device("mps"), torch.float16

    print("Using CPU and FP32")
    return torch.device("cpu"), torch.float32
def load_pretrained_model(model_path: str):
vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer