diff --git a/deepseek_vl/utils/io.py b/deepseek_vl/utils/io.py index 081f7a2..a160f00 100644 --- a/deepseek_vl/utils/io.py +++ b/deepseek_vl/utils/io.py @@ -27,6 +27,27 @@ from transformers import AutoModelForCausalLM from deepseek_vl.models import MultiModalityCausalLM, VLChatProcessor +def get_device_and_dtype(): + """ + Get the device and dtype for the model. + """ + + if torch.cuda.is_available(): + print("Using CUDA and BFloat16") + device = torch.device("cuda") + dtype = torch.bfloat16 + elif torch.backends.mps.is_available(): + print("Using MPS and FP16") + device = torch.device("mps") + dtype = torch.float16 + else: + print("Using CPU and FP32") + device = torch.device("cpu") + dtype = torch.float32 + + return device, dtype + + def load_pretrained_model(model_path: str): vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path) tokenizer = vl_chat_processor.tokenizer