From 494e622544b7aefd5634fdc72878a0289d4e548f Mon Sep 17 00:00:00 2001
From: Nicola Dall'Asen <fodark@pm.me>
Date: Wed, 13 Mar 2024 12:31:28 +0100
Subject: [PATCH] add an util function to detect platflorm and suitable dtype

---
 deepseek_vl/utils/io.py | 21 +++++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/deepseek_vl/utils/io.py b/deepseek_vl/utils/io.py
index 081f7a2..a160f00 100644
--- a/deepseek_vl/utils/io.py
+++ b/deepseek_vl/utils/io.py
@@ -27,6 +27,27 @@ from transformers import AutoModelForCausalLM
 from deepseek_vl.models import MultiModalityCausalLM, VLChatProcessor
 
 
+def get_device_and_dtype():
+    """
+    Get the device and dtype for the model.
+    """
+
+    if torch.cuda.is_available():
+        print("Using CUDA and BFloat16")
+        device = torch.device("cuda")
+        dtype = torch.bfloat16
+    elif torch.backends.mps.is_available():
+        print("Using MPS and FP16")
+        device = torch.device("mps")
+        dtype = torch.float16
+    else:
+        print("Using CPU and FP32")
+        device = torch.device("cpu")
+        dtype = torch.float32
+
+    return device, dtype
+
+
 def load_pretrained_model(model_path: str):
     vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
     tokenizer = vl_chat_processor.tokenizer