Add output streaming

Luis committed 2024-03-11 19:52:55 +00:00
parent 53c540ec9a
commit 86b6e851f8


@@ -1,12 +1,12 @@
 # Prediction interface for Cog ⚙️
 # https://github.com/replicate/cog/blob/main/docs/python.md
-from cog import BasePredictor, Input, Path
+from cog import BasePredictor, Input, Path, ConcatenateIterator
 import os
 import torch
 from threading import Thread
-from transformers import AutoModelForCausalLM
 from deepseek_vl.utils.io import load_pil_images
+from transformers import AutoModelForCausalLM, TextIteratorStreamer
 from deepseek_vl.models import VLChatProcessor, MultiModalityCausalLM
 
 
 # Enable faster download speed
@@ -34,9 +34,9 @@ class Predictor(BasePredictor):
     def predict(
         self,
         image: Path = Input(description="Input image"),
-        prompt: str = Input(description="Input prompt", default="Describe the image"),
+        prompt: str = Input(description="Input prompt", default="Describe this image"),
         max_new_tokens: int = Input(description="Maximum number of tokens to generate", default=512)
-    ) -> str:
+    ) -> ConcatenateIterator[str]:
         """Run a single prediction on the model"""
         conversation = [
             {
@@ -58,20 +58,25 @@ class Predictor(BasePredictor):
             force_batchify=True
         ).to('cuda')
 
-        # run image encoder to get the image embeddings
-        inputs_embeds = self.vl_gpt.prepare_inputs_embeds(**prepare_inputs)
-
-        # run the model to get the response
-        outputs = self.vl_gpt.language_model.generate(
-            inputs_embeds=inputs_embeds,
-            attention_mask=prepare_inputs.attention_mask,
-            pad_token_id=self.tokenizer.eos_token_id,
-            bos_token_id=self.tokenizer.bos_token_id,
-            eos_token_id=self.tokenizer.eos_token_id,
-            max_new_tokens=max_new_tokens,
-            do_sample=False,
-            use_cache=True
+        streamer = TextIteratorStreamer(
+            self.tokenizer, skip_prompt=True, skip_special_tokens=True
         )
-
-        answer = self.tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
-        return answer
+        thread = Thread(
+            target=self.vl_gpt.language_model.generate,
+            kwargs={
+                "inputs_embeds": self.vl_gpt.prepare_inputs_embeds(**prepare_inputs),
+                "attention_mask": prepare_inputs.attention_mask,
+                "pad_token_id": self.tokenizer.eos_token_id,
+                "bos_token_id": self.tokenizer.bos_token_id,
+                "eos_token_id": self.tokenizer.eos_token_id,
+                "max_new_tokens": max_new_tokens,
+                "do_sample": False,
+                "use_cache": True,
+                "streamer": streamer,
+            },
+        )
+        thread.start()
+
+        for new_token in streamer:
+            yield new_token
+        thread.join()
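
For context, this change follows the standard Hugging Face transformers streaming pattern: generate() blocks until decoding finishes, so the commit moves it onto a worker thread and lets a TextIteratorStreamer hand decoded text back chunk by chunk, while the ConcatenateIterator[str] return annotation is what tells Cog to stream those yielded chunks to the client. Below is a minimal, self-contained sketch of the same pattern; "gpt2" and the prompt are placeholders for illustration, not the model or inputs this predictor uses.

# Minimal sketch of the streaming pattern adopted by this commit,
# outside of Cog. "gpt2" is a placeholder model, not deepseek-vl.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until all tokens are produced, so it runs on a worker
# thread; the streamer yields decoded text chunks on the main thread as
# soon as they become available.
thread = Thread(
    target=model.generate,
    kwargs={**inputs, "streamer": streamer, "max_new_tokens": 64},
)
thread.start()
for text_chunk in streamer:
    print(text_chunk, end="", flush=True)
thread.join()

The final thread.join() mirrors the diff above: it ensures the generation thread has exited before the prediction is considered complete.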