Mirror of https://github.com/deepseek-ai/DeepSeek-VL.git, synced 2025-04-19 18:19:03 -04:00
modify example in README to use the new device and dtype function
This commit is contained in:
parent 40ec6491d4
commit 4eac5b5998
@@ -112,7 +112,7 @@ import torch
 from transformers import AutoModelForCausalLM

 from deepseek_vl.models import VLChatProcessor, MultiModalityCausalLM
-from deepseek_vl.utils.io import load_pil_images
+from deepseek_vl.utils.io import load_pil_images, get_device_and_dtype


 # specify the path to the model
@@ -121,7 +121,9 @@ vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
 tokenizer = vl_chat_processor.tokenizer

 vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
-vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()
+
+device, dtype = get_device_and_dtype()
+vl_gpt = vl_gpt.to(dtype).to(device).eval()

 conversation = [
     {
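
Note: the diff shows only the README change, not the body of get_device_and_dtype, which the new import pulls from deepseek_vl.utils.io. A minimal sketch of such a helper, assuming it simply selects the best available accelerator and a matching low-precision dtype (the actual implementation in the repository may differ), could look like this:

import torch

def get_device_and_dtype():
    # Hypothetical sketch, not the repository's actual implementation:
    # prefer CUDA with bfloat16, fall back to Apple MPS with float16,
    # and finally CPU with float32.
    if torch.cuda.is_available():
        return torch.device("cuda"), torch.bfloat16
    if torch.backends.mps.is_available():
        return torch.device("mps"), torch.float16
    return torch.device("cpu"), torch.float32

With a helper along these lines, the example's vl_gpt.to(dtype).to(device).eval() runs on CUDA, Apple MPS, or CPU alike, instead of the removed hard-coded .to(torch.bfloat16).cuda(), which only works on NVIDIA GPUs.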