Mirror of https://github.com/deepseek-ai/DeepSeek-VL.git (synced 2025-04-19 10:09:09 -04:00)

commit 40ec6491d4 (parent 494e622544)

    load model on detected device and use correct dtype
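This commit replaces the hardcoded `.to(torch.bfloat16).cuda()` calls with a new helper, `get_device_and_dtype()`, exported from `deepseek_vl.utils.io`. The hunks below touch three places: the chat CLI's imports and its `response()` function, `load_pretrained_model()` in `deepseek_vl/utils/io.py` (where the helper lives, per the import), and what appears to be the standalone inference example script. The helper's own definition is not shown in these hunks; what follows is a minimal sketch of what such a detection routine could look like, assuming it probes CUDA first, then Apple's MPS backend, and falls back to CPU:

    import torch

    def get_device_and_dtype():
        # Sketch only: the helper added by this commit may choose differently.
        if torch.cuda.is_available():
            # bfloat16 matches the old hardcoded CUDA behavior
            return torch.device("cuda"), torch.bfloat16
        if torch.backends.mps.is_available():
            # half precision is the common choice on Apple silicon
            return torch.device("mps"), torch.float16
        # CPU fallback: full precision
        return torch.device("cpu"), torch.float32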
@@ -8,7 +8,7 @@ from threading import Thread
 import torch
 from transformers import TextIteratorStreamer
 
-from deepseek_vl.utils.io import load_pretrained_model
+from deepseek_vl.utils.io import load_pretrained_model, get_device_and_dtype
 
 
 def load_image(image_file):
@@ -34,13 +34,13 @@ def get_help_message(image_token):
 
 @torch.inference_mode()
 def response(args, conv, pil_images, tokenizer, vl_chat_processor, vl_gpt, generation_config):
-
+    _, dtype = get_device_and_dtype()
     prompt = conv.get_prompt()
     prepare_inputs = vl_chat_processor.__call__(
         prompt=prompt,
         images=pil_images,
         force_batchify=True
-    ).to(vl_gpt.device)
+    ).to(vl_gpt.device, dtype=dtype)
 
     # run image encoder to get the image embeddings
     inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
@@ -52,10 +52,12 @@ def load_pretrained_model(model_path: str):
     vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
     tokenizer = vl_chat_processor.tokenizer
 
+    device, dtype = get_device_and_dtype()
+
     vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
         model_path, trust_remote_code=True
     )
-    vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()
+    vl_gpt = vl_gpt.to(device, dtype=dtype).eval()
 
     return tokenizer, vl_chat_processor, vl_gpt
 
@@ -2,7 +2,7 @@ import torch
 from transformers import AutoModelForCausalLM
 
 from deepseek_vl.models import VLChatProcessor, MultiModalityCausalLM
-from deepseek_vl.utils.io import load_pil_images
+from deepseek_vl.utils.io import load_pil_images, get_device_and_dtype
 
 
 # specify the path to the model
@@ -11,7 +11,9 @@ vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
 tokenizer = vl_chat_processor.tokenizer
 
 vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
-vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()
+
+device, dtype = get_device_and_dtype()
+vl_gpt = vl_gpt.to(dtype).to(device).eval()
 
 conversation = [
     {
@@ -32,7 +34,7 @@ prepare_inputs = vl_chat_processor(
     conversations=conversation,
     images=pil_images,
     force_batchify=True
-).to(vl_gpt.device)
+).to(vl_gpt.device, dtype=dtype)
 
 # run image encoder to get the image embeddings
 inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
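With these changes, the model and the processor outputs land on whatever hardware is detected instead of assuming a CUDA device, and the inputs are cast to the matching dtype via `.to(vl_gpt.device, dtype=dtype)`. Note that `vl_gpt.to(device, dtype=dtype)` and `vl_gpt.to(dtype).to(device)` are equivalent spellings; the diff uses one in each file. A usage sketch (the model path is illustrative):

    from deepseek_vl.utils.io import load_pretrained_model

    # Loads the processor, tokenizer, and model on the detected device
    # with a matching dtype; no CUDA assumption remains.
    tokenizer, vl_chat_processor, vl_gpt = load_pretrained_model(
        "deepseek-ai/deepseek-vl-7b-chat"  # illustrative model path
    )
    print(vl_gpt.device, next(vl_gpt.parameters()).dtype)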