Luis Catacora 2025-02-12 15:08:23 -05:00 committed by GitHub
commit ba8997f076
3 changed files with 126 additions and 0 deletions


@@ -21,6 +21,9 @@
<a href="https://huggingface.co/deepseek-ai" target="_blank">
<img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-DeepSeek%20AI-ffc107?color=ffc107&logoColor=white" />
</a>
<a href="https://replicate.com/deepseek-ai" target="_blank_">
<img src="https://replicate.com/deepseek-ai/janus-pro-7b/badge" alt="Replicate"/>
</a>
</div>

demo/cog.yaml (new file, 25 lines)

@@ -0,0 +1,25 @@
# Configuration for Cog ⚙️
# Reference: https://cog.run/yaml
build:
  gpu: true
  cuda: "12.1"
  python_version: "3.11"
  python_packages:
    - "torch==2.2"
    - "torchvision"
    - "transformers==4.36.2"
    - "accelerate==1.3.0"
    - "diffusers==0.32.2"
    - "opencv-python==4.10.0.84"
    - "attrdict==2.0.1"
    - "timm==1.0.14"
    - "sentencepiece==0.2.0"
    - "einops==0.8.0"
    - "pillow==10.2.0"
    - "numpy==1.24.3"
  run:
    - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.8.2/pget_linux_x86_64" && chmod +x /usr/local/bin/pget
predict: "predict.py:Predictor"
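
The pinned torch 2.2 / CUDA 12.1 combination matters because predict.py (below) casts the model to bfloat16 on GPU and falls back to float16 on CPU. As a quick sanity check that the built image actually exposes a bfloat16-capable GPU to that code path, a minimal Python sketch (not part of this change) could be run inside the container:

import torch

# Report the torch build and the CUDA toolkit it was compiled against.
print("torch", torch.__version__, "cuda", torch.version.cuda)

# predict.py uses bfloat16 on GPU and float16 on CPU, so both checks are relevant.
print("cuda available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("bf16 supported:", torch.cuda.is_bf16_supported())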

demo/predict.py (new file, 98 lines)

@@ -0,0 +1,98 @@
from cog import BasePredictor, Input, Path
import os
import time
import torch
import subprocess
import numpy as np
from PIL import Image
from janus.models import VLChatProcessor
from transformers import AutoConfig, AutoModelForCausalLM
MODEL_CACHE = "checkpoints"
# MODEL_URL = "https://weights.replicate.delivery/default/deepseek-ai/Janus-Pro-1B/model.tar"
MODEL_URL = "https://weights.replicate.delivery/default/deepseek-ai/Janus-Pro-7B/model.tar"


def download_weights(url, dest):
    start = time.time()
    print("downloading url: ", url)
    print("downloading to: ", dest)
    subprocess.check_call(["pget", "-xf", url, dest], close_fds=False)
    print("downloading took: ", time.time() - start)


class Predictor(BasePredictor):
    def setup(self):
        """Load the model into memory to make running multiple predictions efficient"""
        if not os.path.exists(MODEL_CACHE):
            download_weights(MODEL_URL, MODEL_CACHE)

        config = AutoConfig.from_pretrained(MODEL_CACHE)
        language_config = config.language_config
        language_config._attn_implementation = 'eager'
        self.vl_gpt = AutoModelForCausalLM.from_pretrained(
            MODEL_CACHE,
            language_config=language_config,
            trust_remote_code=True
        )
        if torch.cuda.is_available():
            self.vl_gpt = self.vl_gpt.to(torch.bfloat16).cuda()
        else:
            self.vl_gpt = self.vl_gpt.to(torch.float16)

        self.vl_chat_processor = VLChatProcessor.from_pretrained(MODEL_CACHE)
        self.tokenizer = self.vl_chat_processor.tokenizer
        self.cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'

    @torch.inference_mode()
    def predict(
        self,
        image: Path = Input(description="Input image for multimodal understanding"),
        question: str = Input(description="Question about the image"),
        seed: int = Input(description="Random seed for reproducibility", default=42),
        top_p: float = Input(description="Top-p sampling value", default=0.95, ge=0, le=1),
        temperature: float = Input(description="Temperature for text generation", default=0.1, ge=0, le=1),
    ) -> str:
        """Run a single prediction on the model"""
        # Set seed
        torch.manual_seed(seed)
        np.random.seed(seed)
        torch.cuda.manual_seed(seed)

        # Load and process image
        pil_image = Image.open(image)
        image_array = np.array(pil_image)

        conversation = [
            {
                "role": "<|User|>",
                "content": f"<image_placeholder>\n{question}",
                "images": [image_array],
            },
            {"role": "<|Assistant|>", "content": ""},
        ]

        pil_images = [pil_image]
        prepare_inputs = self.vl_chat_processor(
            conversations=conversation,
            images=pil_images,
            force_batchify=True
        ).to(self.cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)

        inputs_embeds = self.vl_gpt.prepare_inputs_embeds(**prepare_inputs)

        outputs = self.vl_gpt.language_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=prepare_inputs.attention_mask,
            pad_token_id=self.tokenizer.eos_token_id,
            bos_token_id=self.tokenizer.bos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            max_new_tokens=512,
            do_sample=False if temperature == 0 else True,
            use_cache=True,
            temperature=temperature,
            top_p=top_p,
        )

        answer = self.tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
        return answer
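
To exercise the predictor outside the Replicate runtime, a minimal local-usage sketch could look like the following; it assumes the janus package and the cog Python library are installed, and "demo.jpg" is a hypothetical local image path:

from cog import Path

from predict import Predictor

predictor = Predictor()
predictor.setup()  # downloads Janus-Pro-7B into ./checkpoints on first run, then loads it

answer = predictor.predict(
    image=Path("demo.jpg"),  # hypothetical input image
    question="What is shown in this image?",
    seed=42,
    top_p=0.95,
    temperature=0.1,
)
print(answer)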