diff --git a/README.md b/README.md
index 5ff7cf9..ac67499 100644
--- a/README.md
+++ b/README.md
@@ -43,6 +43,9 @@
   <a href="LICENSE-MODEL">
     <img alt="Model License" src="https://img.shields.io/badge/Model_License-Model_Agreement-f5de53?&color=f5de53">
   </a>
+  <a href="https://replicate.com/chenxwh/deepseek-vl2" target="_blank">
+    <img alt="Replicate" src="https://replicate.com/chenxwh/deepseek-vl2/badge">
+  </a>
 </div>
 
 
diff --git a/cog.yaml b/cog.yaml
new file mode 100644
index 0000000..ce30ad5
--- /dev/null
+++ b/cog.yaml
@@ -0,0 +1,41 @@
+# Configuration for Cog ⚙️
+# Reference: https://cog.run/yaml
+
+build:
+  # set to true if your model requires a GPU
+  gpu: true
+
+  # a list of ubuntu apt packages to install
+  system_packages:
+    - "libgl1-mesa-glx"
+    - "libglib2.0-0"
+
+  # python version in the form '3.11' or '3.11.4'
+  python_version: "3.11"
+
+  # a list of packages in the format <package-name>==<version>
+  python_packages:
+    - torch==2.4.0
+    - transformers<4.42
+    - numpy
+    - gradio==3.48.0
+    - timm>=0.9.16
+    - accelerate
+    - sentencepiece
+    - attrdict
+    - einops
+    - xformers
+    - ipython
+    - joblib
+    - mdtex2html
+
+  # commands run after the environment is set up
+  run:
+    - pip install -U flash-attn --no-build-isolation
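+    # install pget, the parallel downloader predict.py uses to fetch model weights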
+    - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.6.0/pget_linux_x86_64" && chmod +x /usr/local/bin/pget
+predict: "predict.py:Predictor"
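+
+# With the image built, the endpoint can be sanity-checked locally with the
+# Cog CLI, e.g. (assuming a test image demo.jpg in the working directory):
+#   cog predict -i image1=@demo.jpg -i text="Describe this image."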
diff --git a/predict.py b/predict.py
new file mode 100644
index 0000000..570ed80
--- /dev/null
+++ b/predict.py
@@ -0,0 +1,158 @@
+# Prediction interface for Cog ⚙️
+# https://cog.run/python
+
+import os
+import subprocess
+import time
+from typing import Optional
+from cog import BasePredictor, Input, Path, BaseModel
+import torch
+from PIL import Image
+from deepseek_vl2.serve.app_modules.utils import parse_ref_bbox
+from deepseek_vl2.serve.inference import (
+    convert_conversation_to_prompts,
+    load_model,
+)
+from web_demo import generate_prompt_with_history
+
+
+MODEL_CACHE = "model_cache"
+MODEL_URL = "https://weights.replicate.delivery/default/deepseek-ai/deepseek-vl2-small/model_cache.tar"
+
+
+def download_weights(url, dest):
+    start = time.time()
+    print("downloading url: ", url)
+    print("downloading to: ", dest)
+    subprocess.check_call(["pget", "-x", url, dest], close_fds=False)
+    print("downloading took: ", time.time() - start)
+
+
+class ModelOutput(BaseModel):
+    img_out: Optional[Path] = None
+    text_out: str
+
+
+class Predictor(BasePredictor):
+    def setup(self) -> None:
+        """Load the model into memory to make running multiple predictions efficient"""
+
+        if not os.path.exists(MODEL_CACHE):
+            print("downloading")
+            download_weights(MODEL_URL, MODEL_CACHE)
+
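+        # load the tokenizer, model, and chat processor in bfloat16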
+        self.dtype = torch.bfloat16
+        self.tokenizer, self.vl_gpt, self.vl_chat_processor = load_model(
+            f"{MODEL_CACHE}/deepseek-ai/deepseek-vl2-small", dtype=self.dtype
+        )
+
+    def predict(
+        self,
+        text: str = Input(
+            description="Input text.",
+            default="Describe this image.",
+        ),
+        image1: Path = Input(description="First image"),
+        image2: Path = Input(
+            description="Optional, second image for multiple images image2text",
+            default=None,
+        ),
+        image3: Path = Input(
+            description="Optional, third image for multiple images image2text",
+            default=None,
+        ),
+        max_new_tokens: int = Input(
+            description="The maximum numbers of tokens to generate",
+            le=4096,
+            ge=0,
+            default=2048,
+        ),
+        temperature: float = Input(
+            description="The value used to modulate the probabilities of the next token. Set the temperature to 0 for deterministic generation",
+            default=0.1,
+        ),
+        top_p: float = Input(
+            description="If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.",
+            default=0.9,
+        ),
+        repetition_penalty: float = Input(
+            description="Repetition penalty", le=2, ge=0, default=1.1
+        ),
+    ) -> ModelOutput:
+        """Run a single prediction on the model"""
+
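+        # gather whichever of the up-to-three images were provided, in order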
+        pil_images = [
+            Image.open(str(img)).convert("RGB")
+            for img in [image1, image2, image3]
+            if img
+        ]
+
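+        # wrap the text and images in the DeepSeek-VL2 chat template (single turn, no history)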
+        conversation = generate_prompt_with_history(
+            text,
+            pil_images,
+            None,
+            self.vl_chat_processor,
+            self.tokenizer,
+            max_length=4096,
+        )
+
+        all_conv, _ = convert_conversation_to_prompts(conversation)
+        print(all_conv)
+
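+        # tokenize the conversation and preprocess the images into a batched input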
+        prepare_inputs = self.vl_chat_processor(
+            conversations=all_conv,
+            images=pil_images,
+            force_batchify=True,
+        ).to(self.vl_gpt.device, dtype=self.dtype)
+
+        with torch.no_grad():
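+            # prefill the prompt in chunks to bound peak GPU memory; this yields
+            # the input embeddings and the KV cache that generate() resumes from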
+            inputs_embeds, past_key_values = self.vl_gpt.incremental_prefilling(
+                input_ids=prepare_inputs.input_ids,
+                images=prepare_inputs.images,
+                images_seq_mask=prepare_inputs.images_seq_mask,
+                images_spatial_crop=prepare_inputs.images_spatial_crop,
+                attention_mask=prepare_inputs.attention_mask,
+            )
+
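+            # generate the answer on top of the prefilled KV cache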
+            outputs = self.vl_gpt.generate(
+                inputs_embeds=inputs_embeds,
+                input_ids=prepare_inputs.input_ids,
+                images=prepare_inputs.images,
+                images_seq_mask=prepare_inputs.images_seq_mask,
+                images_spatial_crop=prepare_inputs.images_spatial_crop,
+                attention_mask=prepare_inputs.attention_mask,
+                past_key_values=past_key_values,
+                pad_token_id=self.tokenizer.eos_token_id,
+                bos_token_id=self.tokenizer.bos_token_id,
+                eos_token_id=self.tokenizer.eos_token_id,
+                max_new_tokens=max_new_tokens,
+                use_cache=True,
+                do_sample=temperature > 0,
+                temperature=temperature,
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+            )
+
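+            # decode only the newly generated tokens; special tokens are kept so
+            # parse_ref_bbox can pick up any visual-grounding tags and draw boxes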
+            answer = self.tokenizer.decode(
+                outputs[0][len(prepare_inputs.input_ids[0]) :].cpu().tolist(),
+                skip_special_tokens=False,
+            )
+            vg_image = parse_ref_bbox(answer, image=pil_images[-1])
+
+            out_img = "out.png"
+            if vg_image is not None:
+                vg_image.save(out_img, format="JPEG", quality=85)
+
+        return ModelOutput(
+            text_out=answer, img_out=Path(out_img) if vg_image is not None else None
+        )