modify example in README to use the new device and dtype function

Author: Nicola Dall'Asen
Date: 2024-03-13 12:33:52 +01:00
Parent: 40ec6491d4
Commit: 4eac5b5998


@@ -112,7 +112,7 @@ import torch
 from transformers import AutoModelForCausalLM
 from deepseek_vl.models import VLChatProcessor, MultiModalityCausalLM
-from deepseek_vl.utils.io import load_pil_images
+from deepseek_vl.utils.io import load_pil_images, get_device_and_dtype

 # specify the path to the model
@@ -121,7 +121,9 @@ vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
 tokenizer = vl_chat_processor.tokenizer
 vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
-vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()
+device, dtype = get_device_and_dtype()
+vl_gpt = vl_gpt.to(dtype).to(device).eval()
 conversation = [
     {
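
The helper imported above, deepseek_vl.utils.io.get_device_and_dtype, is not shown in this diff. Below is a minimal sketch of what such a helper could look like, assuming it prefers CUDA with bfloat16 and falls back to MPS or CPU; the fallback order and dtypes are assumptions, not the repository's actual implementation.

import torch

def get_device_and_dtype():
    # Assumed behavior: prefer a CUDA GPU with bfloat16, matching the old
    # `vl_gpt.to(torch.bfloat16).cuda()` call this commit replaces.
    if torch.cuda.is_available():
        return torch.device("cuda"), torch.bfloat16
    # Assumed Apple Silicon fallback: MPS with float16, since bfloat16
    # support on MPS is limited.
    if torch.backends.mps.is_available():
        return torch.device("mps"), torch.float16
    # Assumed CPU fallback: float32 for broad compatibility.
    return torch.device("cpu"), torch.float32

# Usage as in the updated README example:
# device, dtype = get_device_and_dtype()
# vl_gpt = vl_gpt.to(dtype).to(device).eval()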