diff --git a/predict.py b/predict.py
index ab3115e..4954973 100644
--- a/predict.py
+++ b/predict.py
@@ -1,12 +1,12 @@
 # Prediction interface for Cog ⚙️
 # https://github.com/replicate/cog/blob/main/docs/python.md
 
-from cog import BasePredictor, Input, Path
+from cog import BasePredictor, Input, Path, ConcatenateIterator
 import os
 import torch
 from threading import Thread
-from transformers import AutoModelForCausalLM
 from deepseek_vl.utils.io import load_pil_images
+from transformers import AutoModelForCausalLM, TextIteratorStreamer
 from deepseek_vl.models import VLChatProcessor, MultiModalityCausalLM
 
 # Enable faster download speed
@@ -34,9 +34,9 @@ class Predictor(BasePredictor):
     def predict(
         self,
         image: Path = Input(description="Input image"),
-        prompt: str = Input(description="Input prompt", default="Describe the image"),
+        prompt: str = Input(description="Input prompt", default="Describe this image"),
         max_new_tokens: int = Input(description="Maximum number of tokens to generate", default=512)
-    ) -> str:
+    ) -> ConcatenateIterator[str]:
         """Run a single prediction on the model"""
         conversation = [
             {
@@ -57,21 +57,30 @@
             images=pil_images,
             force_batchify=True
         ).to('cuda')
-
-        # run image encoder to get the image embeddings
-        inputs_embeds = self.vl_gpt.prepare_inputs_embeds(**prepare_inputs)
-
-        # run the model to get the response
-        outputs = self.vl_gpt.language_model.generate(
-            inputs_embeds=inputs_embeds,
-            attention_mask=prepare_inputs.attention_mask,
-            pad_token_id=self.tokenizer.eos_token_id,
-            bos_token_id=self.tokenizer.bos_token_id,
-            eos_token_id=self.tokenizer.eos_token_id,
-            max_new_tokens=max_new_tokens,
-            do_sample=False,
-            use_cache=True
+
+        # stream decoded tokens as they are generated, skipping the prompt
+        # and any special tokens
+        streamer = TextIteratorStreamer(
+            self.tokenizer, skip_prompt=True, skip_special_tokens=True
         )
-
-        answer = self.tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
-        return answer
+
+        # run generation in a background thread so tokens can be yielded
+        # from the streamer while the model is still decoding
+        thread = Thread(
+            target=self.vl_gpt.language_model.generate,
+            kwargs={
+                "inputs_embeds": self.vl_gpt.prepare_inputs_embeds(**prepare_inputs),
+                "attention_mask": prepare_inputs.attention_mask,
+                "pad_token_id": self.tokenizer.eos_token_id,
+                "bos_token_id": self.tokenizer.bos_token_id,
+                "eos_token_id": self.tokenizer.eos_token_id,
+                "max_new_tokens": max_new_tokens,
+                "do_sample": False,
+                "use_cache": True,
+                "streamer": streamer,
+            },
+        )
+        thread.start()
+        for new_token in streamer:
+            yield new_token
+        thread.join()