diff --git a/README.md b/README.md
index 6d24eaf..9bfaeeb 100644
--- a/README.md
+++ b/README.md
@@ -112,7 +112,7 @@
 import torch
 from transformers import AutoModelForCausalLM
 from deepseek_vl.models import VLChatProcessor, MultiModalityCausalLM
-from deepseek_vl.utils.io import load_pil_images
+from deepseek_vl.utils.io import load_pil_images, get_device_and_dtype
 
 
 # specify the path to the model
@@ -121,7 +121,9 @@
 vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
 tokenizer = vl_chat_processor.tokenizer
 vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
-vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()
+
+device, dtype = get_device_and_dtype()
+vl_gpt = vl_gpt.to(dtype).to(device).eval()
 
 conversation = [
     {