load model on detected device and use correct dtype

Nicola Dall'Asen 2024-03-13 12:32:14 +01:00
parent 494e622544
commit 40ec6491d4
3 changed files with 11 additions and 7 deletions

cli_chat.py

@@ -8,7 +8,7 @@ from threading import Thread
 import torch
 from transformers import TextIteratorStreamer
-from deepseek_vl.utils.io import load_pretrained_model
+from deepseek_vl.utils.io import load_pretrained_model, get_device_and_dtype
 def load_image(image_file):
@@ -34,13 +34,13 @@ def get_help_message(image_token):
 @torch.inference_mode()
 def response(args, conv, pil_images, tokenizer, vl_chat_processor, vl_gpt, generation_config):
+    _, dtype = get_device_and_dtype()
     prompt = conv.get_prompt()
     prepare_inputs = vl_chat_processor.__call__(
         prompt=prompt,
         images=pil_images,
         force_batchify=True
-    ).to(vl_gpt.device)
+    ).to(vl_gpt.device, dtype=dtype)
     # run image encoder to get the image embeddings
     inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)

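Why the added dtype argument on the processor output matters: once the weights are in half precision, feeding them float32 pixel values fails at the first matmul. A minimal standalone repro of that failure mode (hypothetical illustration, not part of this commit):

import torch

# half-precision weights fed a float32 input: PyTorch does not silently
# upcast here, so the forward pass raises a dtype-mismatch RuntimeError
layer = torch.nn.Linear(4, 4).to(torch.float16)
x = torch.randn(1, 4, dtype=torch.float32)
layer(x)  # RuntimeError: dtypes of input and weight do not match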
deepseek_vl/utils/io.py

@@ -52,10 +52,12 @@ def load_pretrained_model(model_path: str):
     vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
     tokenizer = vl_chat_processor.tokenizer
+    device, dtype = get_device_and_dtype()
     vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
         model_path, trust_remote_code=True
     )
-    vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()
+    vl_gpt = vl_gpt.to(device, dtype=dtype).eval()
     return tokenizer, vl_chat_processor, vl_gpt

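The hunks rely on a get_device_and_dtype helper whose body is not shown in this diff. A minimal sketch of what such a helper could look like, assuming CUDA keeps the old bfloat16 behaviour, Apple's MPS backend (which lacks bfloat16) falls back to float16, and CPU stays in full precision:

import torch

def get_device_and_dtype():
    # CUDA: keep the bfloat16 the old hard-coded .to(torch.bfloat16).cuda() used
    if torch.cuda.is_available():
        return torch.device("cuda"), torch.bfloat16
    # MPS has no bfloat16 support, so fall back to float16
    if torch.backends.mps.is_available():
        return torch.device("mps"), torch.float16
    # CPU: full precision
    return torch.device("cpu"), torch.float32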
inference.py

@@ -2,7 +2,7 @@ import torch
 from transformers import AutoModelForCausalLM
 from deepseek_vl.models import VLChatProcessor, MultiModalityCausalLM
-from deepseek_vl.utils.io import load_pil_images
+from deepseek_vl.utils.io import load_pil_images, get_device_and_dtype
 # specify the path to the model
@@ -11,7 +11,9 @@ vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
 tokenizer = vl_chat_processor.tokenizer
 vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
-vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()
+device, dtype = get_device_and_dtype()
+vl_gpt = vl_gpt.to(dtype).to(device).eval()
 conversation = [
     {
@@ -32,7 +34,7 @@ prepare_inputs = vl_chat_processor(
     conversations=conversation,
     images=pil_images,
     force_batchify=True
-).to(vl_gpt.device)
+).to(vl_gpt.device, dtype=dtype)
 # run image encoder to get the image embeddings
 inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
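For context, inputs_embeds then goes into the language model's generate call; that step is untouched by this commit, and in the stock inference script it looks roughly like this (parameter values assumed from the upstream example):

outputs = vl_gpt.language_model.generate(
    inputs_embeds=inputs_embeds,
    attention_mask=prepare_inputs.attention_mask,
    pad_token_id=tokenizer.eos_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    max_new_tokens=512,
    do_sample=False,
    use_cache=True,
)
answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
print(answer)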