def get_device_and_dtype():
    """Pick the torch device to run on and a dtype suited to it.

    Backends are probed in order of preference: CUDA, then Apple MPS,
    then CPU. The choice is announced on stdout.

    Returns:
        A ``(device, dtype)`` tuple of ``torch.device`` and ``torch.dtype``:

        * CUDA -> ``torch.bfloat16`` when the current device reports
          bfloat16 support, otherwise ``torch.float16``.
        * MPS  -> ``torch.float16``.
        * CPU  -> ``torch.float32``.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
        # Not every CUDA device supports bfloat16; ask the runtime instead
        # of assuming, and fall back to FP16 when it does not.
        if torch.cuda.is_bf16_supported():
            print("Using CUDA and BFloat16")
            dtype = torch.bfloat16
        else:
            print("Using CUDA and FP16")
            dtype = torch.float16
        return device, dtype

    # ``torch.backends.mps`` is absent on older PyTorch builds; look it up
    # defensively so the CPU fallback is reached instead of AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        print("Using MPS and FP16")
        return torch.device("mps"), torch.float16

    print("Using CPU and FP32")
    return torch.device("cpu"), torch.float32