Mirror of https://github.com/deepseek-ai/DeepSeek-VL.git, synced 2025-04-20 02:29:06 -04:00
Add output streaming
parent 53c540ec9a
commit 86b6e851f8

predict.py | 43
--- a/predict.py
+++ b/predict.py
@@ -1,12 +1,12 @@
 # Prediction interface for Cog ⚙️
 # https://github.com/replicate/cog/blob/main/docs/python.md
 
-from cog import BasePredictor, Input, Path
+from cog import BasePredictor, Input, Path, ConcatenateIterator
 import os
 import torch
 from threading import Thread
-from transformers import AutoModelForCausalLM
 from deepseek_vl.utils.io import load_pil_images
+from transformers import AutoModelForCausalLM, TextIteratorStreamer
 from deepseek_vl.models import VLChatProcessor, MultiModalityCausalLM
 
 # Enable faster download speed
@@ -34,9 +34,9 @@ class Predictor(BasePredictor):
     def predict(
         self,
         image: Path = Input(description="Input image"),
-        prompt: str = Input(description="Input prompt", default="Describe the image"),
+        prompt: str = Input(description="Input prompt", default="Describe this image"),
         max_new_tokens: int = Input(description="Maximum number of tokens to generate", default=512)
-    ) -> str:
+    ) -> ConcatenateIterator[str]:
         """Run a single prediction on the model"""
         conversation = [
             {
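Note on the signature change above: in Cog, a predictor annotated to return ConcatenateIterator[str] that yields strings is streamed to the client piece by piece instead of being returned as one final string. A minimal sketch of that mechanism in isolation (the Echo predictor and its input are illustrative, not part of this repo):

from cog import BasePredictor, ConcatenateIterator, Input


class Echo(BasePredictor):
    # Illustrative only. Yielding strings from a predict() typed as
    # ConcatenateIterator[str] makes Cog stream each piece to the
    # caller as soon as it is produced.
    def predict(self, text: str = Input(description="Text to echo")) -> ConcatenateIterator[str]:
        for word in text.split():
            yield word + " "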
@@ -58,20 +58,25 @@ class Predictor(BasePredictor):
             force_batchify=True
         ).to('cuda')
 
-        # run image encoder to get the image embeddings
-        inputs_embeds = self.vl_gpt.prepare_inputs_embeds(**prepare_inputs)
-
-        # run the model to get the response
-        outputs = self.vl_gpt.language_model.generate(
-            inputs_embeds=inputs_embeds,
-            attention_mask=prepare_inputs.attention_mask,
-            pad_token_id=self.tokenizer.eos_token_id,
-            bos_token_id=self.tokenizer.bos_token_id,
-            eos_token_id=self.tokenizer.eos_token_id,
-            max_new_tokens=max_new_tokens,
-            do_sample=False,
-            use_cache=True
+        streamer = TextIteratorStreamer(
+            self.tokenizer, skip_prompt=True, skip_special_tokens=True
         )
 
-        answer = self.tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
-        return answer
+        thread = Thread(
+            target=self.vl_gpt.language_model.generate,
+            kwargs={
+                "inputs_embeds": self.vl_gpt.prepare_inputs_embeds(**prepare_inputs),
+                "attention_mask": prepare_inputs.attention_mask,
+                "pad_token_id": self.tokenizer.eos_token_id,
+                "bos_token_id": self.tokenizer.bos_token_id,
+                "eos_token_id": self.tokenizer.eos_token_id,
+                "max_new_tokens": max_new_tokens,
+                "do_sample": False,
+                "use_cache": True,
+                "streamer": streamer,
+            },
+        )
+        thread.start()
+        for new_token in streamer:
+            yield new_token
+        thread.join()
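For reference, the generate-on-a-worker-thread pattern the new code adopts is the standard way to stream from Hugging Face transformers: generate() blocks until decoding finishes, so it runs on a background Thread while TextIteratorStreamer yields decoded text chunks to the consuming loop. A self-contained sketch of the same pattern, using the small gpt2 checkpoint purely for illustration (the commit itself feeds DeepSeek-VL's multimodal inputs_embeds into language_model.generate):

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("A quick demo of token streaming:", return_tensors="pt")

# skip_prompt drops the echoed input tokens; skip_special_tokens is
# forwarded to decode() so EOS never appears in the stream.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until completion, so run it on a worker thread and
# drain the streamer on the main thread as chunks arrive.
thread = Thread(
    target=model.generate,
    kwargs={**inputs, "max_new_tokens": 32, "do_sample": False, "streamer": streamer},
)
thread.start()
for chunk in streamer:
    print(chunk, end="", flush=True)
thread.join()

The trailing thread.join() matters in the commit's version as well: it ensures the generation thread has fully finished before the predictor returns.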