Mirror of https://github.com/deepseek-ai/Janus.git (synced 2025-02-22 21:58:59 -05:00)
Add Replicate Demos

commit 12856067bf (parent 1daa72fa40)
README.md

@@ -21,6 +21,9 @@
   <a href="https://huggingface.co/deepseek-ai" target="_blank">
     <img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-DeepSeek%20AI-ffc107?color=ffc107&logoColor=white" />
   </a>
+  <a href="https://replicate.com/deepseek-ai" target="_blank">
+    <img src="https://replicate.com/deepseek-ai/janus-pro-7b/badge" alt="Replicate"/>
+  </a>
 
 </div>
demo/cog.yaml (new file, 25 lines)

@@ -0,0 +1,25 @@
# Configuration for Cog ⚙️
# Reference: https://cog.run/yaml

build:
  gpu: true
  cuda: "12.1"
  python_version: "3.11"
  python_packages:
    - "torch==2.2"
    - "torchvision"
    - "transformers==4.36.2"
    - "accelerate==1.3.0"
    - "diffusers==0.32.2"
    - "opencv-python==4.10.0.84"
    - "attrdict==2.0.1"
    - "timm==1.0.14"
    - "sentencepiece==0.2.0"
    - "einops==0.8.0"
    - "pillow==10.2.0"
    - "numpy==1.24.3"

  run:
    - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.8.2/pget_linux_x86_64" && chmod +x /usr/local/bin/pget

predict: "predict.py:Predictor"
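
With this configuration in place, the demo can be invoked through Cog's standard CLI. The line below is a minimal sketch, not part of this commit; it assumes Cog is installed locally, and the image filename and question are purely illustrative:

    cog predict -i image=@example.jpg -i question="What is in this image?"

`cog predict` builds the container from cog.yaml if needed, runs `setup()` once, then calls `predict()` with the given inputs; file inputs are passed with the `@` prefix.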

demo/predict.py (new file, 98 lines)

@@ -0,0 +1,98 @@
from cog import BasePredictor, Input, Path
import os
import time
import torch
import subprocess
import numpy as np
from PIL import Image
from janus.models import VLChatProcessor
from transformers import AutoConfig, AutoModelForCausalLM

MODEL_CACHE = "checkpoints"
# MODEL_URL = "https://weights.replicate.delivery/default/deepseek-ai/Janus-Pro-1B/model.tar"
MODEL_URL = "https://weights.replicate.delivery/default/deepseek-ai/Janus-Pro-7B/model.tar"


def download_weights(url, dest):
    start = time.time()
    print("downloading url: ", url)
    print("downloading to: ", dest)
    # pget downloads the tar archive and (-x) extracts it into dest
    subprocess.check_call(["pget", "-xf", url, dest], close_fds=False)
    print("downloading took: ", time.time() - start)


class Predictor(BasePredictor):
    def setup(self):
        """Load the model into memory to make running multiple predictions efficient"""
        if not os.path.exists(MODEL_CACHE):
            download_weights(MODEL_URL, MODEL_CACHE)

        config = AutoConfig.from_pretrained(MODEL_CACHE)
        language_config = config.language_config
        # fall back to eager attention (flash-attn is not installed in this image)
        language_config._attn_implementation = 'eager'

        self.vl_gpt = AutoModelForCausalLM.from_pretrained(
            MODEL_CACHE,
            language_config=language_config,
            trust_remote_code=True
        )

        if torch.cuda.is_available():
            self.vl_gpt = self.vl_gpt.to(torch.bfloat16).cuda()
        else:
            self.vl_gpt = self.vl_gpt.to(torch.float16)

        self.vl_chat_processor = VLChatProcessor.from_pretrained(MODEL_CACHE)
        self.tokenizer = self.vl_chat_processor.tokenizer
        self.cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'

    @torch.inference_mode()
    def predict(
        self,
        image: Path = Input(description="Input image for multimodal understanding"),
        question: str = Input(description="Question about the image"),
        seed: int = Input(description="Random seed for reproducibility", default=42),
        top_p: float = Input(description="Top-p sampling value", default=0.95, ge=0, le=1),
        temperature: float = Input(description="Temperature for text generation", default=0.1, ge=0, le=1),
    ) -> str:
        """Run a single prediction on the model"""
        # Set seeds for reproducibility
        torch.manual_seed(seed)
        np.random.seed(seed)
        torch.cuda.manual_seed(seed)

        # Load and process image
        pil_image = Image.open(image)
        image_array = np.array(pil_image)

        conversation = [
            {
                "role": "<|User|>",
                "content": f"<image_placeholder>\n{question}",
                "images": [image_array],
            },
            {"role": "<|Assistant|>", "content": ""},
        ]

        pil_images = [pil_image]
        prepare_inputs = self.vl_chat_processor(
            conversations=conversation,
            images=pil_images,
            force_batchify=True
        ).to(self.cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)

        inputs_embeds = self.vl_gpt.prepare_inputs_embeds(**prepare_inputs)

        outputs = self.vl_gpt.language_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=prepare_inputs.attention_mask,
            pad_token_id=self.tokenizer.eos_token_id,
            bos_token_id=self.tokenizer.bos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            max_new_tokens=512,
            do_sample=False if temperature == 0 else True,
            use_cache=True,
            temperature=temperature,
            top_p=top_p,
        )

        answer = self.tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
        return answer
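
For quick iteration outside the container, the predictor can also be exercised directly from Python. The snippet below is a minimal sketch, not part of this commit; it assumes the janus package is importable, the weights are fetched into ./checkpoints on the first setup() call, and example.jpg is a hypothetical test image:

from cog import Path

from predict import Predictor

predictor = Predictor()
predictor.setup()  # downloads and extracts the weights on first run, then loads the model
answer = predictor.predict(
    image=Path("example.jpg"),          # hypothetical test image
    question="What is in this image?",  # illustrative prompt
    seed=42,
    top_p=0.95,
    temperature=0.1,
)
print(answer)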