Add MPS compatibility and dtype fixes

- Added app_januspro_mps.py: optimized for Apple MPS, with automatic device selection
- Fixed dtype mismatch issues in vq_model.py to ensure stability on MPS
- Updated README.md to document MPS improvements and call for community testing
- Contribution based on community testing and AI-assisted debugging
YJT 2025-01-28 19:20:01 -08:00
parent a74a59f8a9
commit 0cbf6f146b
3 changed files with 412 additions and 12 deletions

README.md

@@ -287,10 +287,12 @@ generate(
```
### Gradio Demo
We have deployed online demo in [Huggingface](https://huggingface.co/spaces/deepseek-ai/Janus-Pro-7B).
We have deployed an online demo in [Huggingface](https://huggingface.co/spaces/deepseek-ai/Janus-Pro-7B).
For the local Gradio demo, you can run one of the following commands:
For the local gradio demo, you can run with the following command:
**For standard CUDA-based inference:**
```
pip install -e .[gradio]
@@ -298,6 +300,20 @@ pip install -e .[gradio]
python demo/app_januspro.py
```
**For Apple Silicon (MPS) users (experimental):**
```
pip install -e .[gradio]
python demo/app_januspro_mps.py
```
This version includes optimizations for Apple Silicon (MPS), using `torch.float16` instead of `torch.bfloat16`.
*Note:* This is an experimental script contributed by the community and has not been officially tested by the DeepSeek team. Please share feedback if you encounter issues!
Have Fun!
</details>
@@ -710,11 +726,76 @@ Have Fun!
</details>
## 4. License
## 4. Community Contributions
This repository welcomes community contributions that improve the models' usability across different platforms.
---
### **🔹 Apple Silicon (MPS) Compatibility & Performance Fixes**
**Issue:** The original `app_januspro.py` script ran inference **on the CPU instead of utilizing the MPS (Metal Performance Shaders) backend**, leading to slow performance. Additionally, dtype mismatches between **bfloat16 (input) and float16 (bias)** caused runtime errors.
**Solution:**
- A new script, **`app_januspro_mps.py`**, has been added, optimized for **Apple MPS**.
- The script **prioritizes MPS acceleration when available**, significantly improving performance.
- It **remains compatible with CUDA and CPU**, though **further community testing is encouraged**.
**Key Improvements:**
- **Automatic device selection** (`cuda`, `mps`, or `cpu`).
- **Ensures dtype consistency** (see the sketch after this list):
- **MPS:** `float16` (to prevent dtype mismatches).
- **CUDA:** `bfloat16` (or `float16`, if preferred).
- **CPU:** `float32` (fallback for compatibility).
- **Fixes dtype mismatches** that previously caused crashes on MPS.
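For reference, the device and dtype selection in `app_januspro_mps.py` condenses to the following (taken from the full script further down this page):
```python
import torch

# Pick the best available backend, then a dtype known to be stable on it.
if torch.cuda.is_available():
    device, dtype = "cuda", torch.bfloat16   # or torch.float16 if preferred
elif torch.backends.mps.is_available():
    device, dtype = "mps", torch.float16     # float16 sidesteps MPS's partial bfloat16 support
else:
    device, dtype = "cpu", torch.float32     # compatibility fallback

print(f"Using device = {device}, dtype = {dtype}")
```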
**Usage:**
For Apple Silicon (MPS) users, try the new script:
```sh
pip install -e .[gradio]
python demo/app_januspro_mps.py
```
💡 This script is an experimental addition. If it performs well across all platforms, the community can consider merging improvements into the main `app_januspro.py`.
---
### **🔹 Fix for RuntimeError: Mismatched DType (`bfloat16` vs `half`)**
**Issue:** Running Janus-Pro-7B on **Apple's MPS backend** previously resulted in:
```
RuntimeError: Input type (c10::BFloat16) and bias type (c10::Half) should be the same
```
This was caused by **dtype mismatches**:
- The **Upsample block** converted inputs to `float32`, applied interpolation, then cast them to `bfloat16`.
- Meanwhile, **convolution layers used `float16`**, leading to an error.
- **Apple's MPS has partial support for `bfloat16`**, contributing to instability.
**Solution:**
- We **standardized all tensor dtypes to `torch.float16` on MPS** to prevent mismatches.
- This change ensures **stable execution across MPS and CUDA**.
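To make the failure and the fix concrete, here is a minimal, self-contained check (a hedged sketch, not code from the repository): a convolution whose parameters are `float16` rejects a `bfloat16` input, while a dtype-consistent `float16` input passes.
```python
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 3, kernel_size=1).to(torch.float16)   # conv weights/bias in float16

x_bad = torch.randn(1, 3, 16, 16, dtype=torch.bfloat16)   # what the old Upsample cast produced
try:
    conv(x_bad)
except RuntimeError as err:
    print(err)  # on MPS this is the "Input type ... and bias type ..." error quoted above

x_good = torch.randn(1, 3, 16, 16, dtype=torch.float16)   # matches the conv parameters
assert x_good.dtype == conv.weight.dtype == conv.bias.dtype  # consistent dtypes, no mismatch
```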
### **🔹 Fixes Applied to `vq_model.py`**
**Issue:** Additional dtype mismatches were identified in **`vq_model.py`**, specifically in the **Upsample module**, where tensor operations introduced unnecessary conversions.
**Solution:**
- The **Upsample module now preserves the original input tensor dtype**, ensuring dtype consistency throughout the pipeline.
- This prevents unexpected dtype mismatches across **MPS, CUDA, and CPU environments**.
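A condensed view of the resulting module is sketched below; the constructor here is a simplified stand-in, while the `forward` body mirrors the `vq_model.py` diff at the bottom of this page.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Upsample(nn.Module):
    """2x nearest-neighbour upsampling that preserves the input dtype."""

    def __init__(self, in_channels: int, with_conv: bool = True):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        # No forced float32/bfloat16 round-trip: interpolate in x's current dtype
        # so the following convolution sees an input that matches its parameters.
        x = F.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x

up = Upsample(4)
x = torch.randn(1, 4, 8, 8)      # float32 here; float16 end to end on MPS in the demo
print(up(x).shape, up(x).dtype)  # torch.Size([1, 4, 16, 16]) torch.float32
```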
### **🔹 Status & Call for Community Testing**
The **new `app_januspro_mps.py` script and dtype fixes in `vq_model.py` have been tested successfully on Apple Silicon**. While initial results indicate improved performance, **further validation from the DeepSeek community is encouraged**.
🚀 If you have expertise in **PyTorch, Apple MPS, or GPU optimization**, your feedback and improvements are welcome!
If you encounter issues, please **open an Issue or submit a Pull Request**.
## 5. License
This code repository is licensed under [the MIT License](https://github.com/deepseek-ai/DeepSeek-LLM/blob/HEAD/LICENSE-CODE). The use of Janus models is subject to [DeepSeek Model License](https://github.com/deepseek-ai/DeepSeek-LLM/blob/HEAD/LICENSE-MODEL).
## 5. Citation
## 6. Citation
```bibtex
@misc{chen2025januspro,
@@ -738,6 +819,7 @@ This code repository is licensed under [the MIT License](https://github.com/deep
}
```
## 6. Contact
## 7. Contact
If you have any questions, please raise an issue or contact us at [service@deepseek.com](mailto:service@deepseek.com).

demo/app_januspro_mps.py (new file, 323 lines)

@@ -0,0 +1,323 @@
"""
app_januspro_mps.py
An updated version of the Janus Pro demo script, forcing float16 on MPS and
ensuring the main model and the vision submodule share the same dtype.
"""
import gradio as gr
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor
from janus.utils.io import load_pil_images
from PIL import Image
import numpy as np
import os
import time
# 1. Detect device (cuda vs. mps vs. cpu)
if torch.cuda.is_available():
device = "cuda"
elif torch.backends.mps.is_available():
device = "mps"
else:
device = "cpu"
# 2. Choose dtype
#   - We use bfloat16 on CUDA by default, but float16 works too if preferred.
#   - We force float16 on MPS to avoid any mismatch. CPU -> float32 fallback.
if device == "cuda":
dtype = torch.bfloat16 # or torch.float16 if you want half on CUDA
elif device == "mps":
dtype = torch.float16 # definitely float16 on Apple MPS to avoid mismatch
else:
dtype = torch.float32
print(f"Using device = {device}, dtype = {dtype}")
# 3. Load model config & model
model_path = "deepseek-ai/Janus-Pro-7B"
config = AutoConfig.from_pretrained(model_path)
# If needed, force some config changes:
language_config = config.language_config
language_config._attn_implementation = 'eager'
vl_gpt = AutoModelForCausalLM.from_pretrained(
model_path,
language_config=language_config,
trust_remote_code=True
)
# 4. Move entire model to the chosen device & dtype
vl_gpt = vl_gpt.to(device, dtype=dtype)
# 4a. Explicitly recast the vision submodule in case it didn't propagate
# This helps if the vision submodel is loaded or stored differently.
if hasattr(vl_gpt, "gen_vision_model"):
vl_gpt.gen_vision_model = vl_gpt.gen_vision_model.to(device, dtype=dtype)
# Debug prints: just to confirm
print(">>> Top-level param dtype:", next(vl_gpt.parameters()).dtype)
print(">>> Vision model param dtype:",
next(vl_gpt.gen_vision_model.parameters()).dtype if hasattr(vl_gpt, "gen_vision_model") else "N/A")
# 5. Load processor
vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
# 6. Utility to clear device cache (no-op for MPS/CPU)
def clear_device_cache():
if device == "cuda":
torch.cuda.empty_cache()
# 7. Unified seed setting
def set_seed(seed: int):
torch.manual_seed(seed)
np.random.seed(seed)
if device == "cuda":
torch.cuda.manual_seed(seed)
# 8. Multimodal Understanding
@torch.inference_mode()
def multimodal_understanding(image, question, seed, top_p, temperature):
clear_device_cache()
set_seed(int(seed))
conversation = [
{
"role": "<|User|>",
"content": f"<image_placeholder>\n{question}",
"images": [image],
},
{"role": "<|Assistant|>", "content": ""},
]
pil_images = [Image.fromarray(image)]
prepare_inputs = vl_chat_processor(
conversations=conversation,
images=pil_images,
force_batchify=True
).to(device=device, dtype=dtype)
inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
outputs = vl_gpt.language_model.generate(
inputs_embeds=inputs_embeds,
attention_mask=prepare_inputs.attention_mask,
pad_token_id=tokenizer.eos_token_id,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
max_new_tokens=512,
do_sample=(False if temperature == 0 else True),
use_cache=True,
temperature=temperature,
top_p=top_p,
)
answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
return answer
# 9. Low-level image generation logic
def generate(input_ids,
width,
height,
temperature: float = 1,
parallel_size: int = 2,
cfg_weight: float = 5,
image_token_num_per_image: int = 576,
patch_size: int = 16):
clear_device_cache()
tokens = torch.zeros(
(parallel_size * 2, len(input_ids)),
dtype=torch.int,
device=device
)
for i in range(parallel_size * 2):
tokens[i, :] = input_ids
if i % 2 != 0:
tokens[i, 1:-1] = vl_chat_processor.pad_id
inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
generated_tokens = torch.zeros(
(parallel_size, image_token_num_per_image),
dtype=torch.int,
device=device
)
pkv = None
for i in range(image_token_num_per_image):
with torch.no_grad():
outputs = vl_gpt.language_model.model(
inputs_embeds=inputs_embeds,
use_cache=True,
past_key_values=pkv
)
pkv = outputs.past_key_values
hidden_states = outputs.last_hidden_state
logits = vl_gpt.gen_head(hidden_states[:, -1, :])
# Conditioned vs. Unconditioned
logit_cond = logits[0::2, :]
logit_uncond = logits[1::2, :]
# Classifier-free guidance
logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
# Sample
probs = torch.softmax(logits / temperature, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
generated_tokens[:, i] = next_token.squeeze(dim=-1)
# Next token also goes to uncond
next_token = torch.cat(
[next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)],
dim=1
).view(-1)
img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
# Force the correct dtype if needed
if img_embeds.dtype != dtype:
img_embeds = img_embeds.to(dtype)
inputs_embeds = img_embeds.unsqueeze(dim=1)
patches = vl_gpt.gen_vision_model.decode_code(
generated_tokens.to(dtype=torch.int),
shape=[parallel_size, 8, width // patch_size, height // patch_size]
)
return generated_tokens.to(dtype=torch.int), patches
def unpack(dec, width, height, parallel_size=5):
dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
dec = np.clip((dec + 1) / 2 * 255, 0, 255).astype(np.uint8)
visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
visual_img[:, :, :] = dec
return visual_img
# 10. Text-to-Image Generation
@torch.inference_mode()
def generate_image(prompt,
seed=None,
guidance=5,
t2i_temperature=1.0):
clear_device_cache()
if seed is not None:
set_seed(int(seed))
width = 384
height = 384
parallel_size = 2
with torch.no_grad():
messages = [
{'role': '<|User|>', 'content': prompt},
{'role': '<|Assistant|>', 'content': ''}
]
text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
conversations=messages,
sft_format=vl_chat_processor.sft_format,
system_prompt=''
)
text = text + vl_chat_processor.image_start_tag
input_ids = torch.LongTensor(tokenizer.encode(text)).to(device)
output, patches = generate(
input_ids,
width=(width // 16 * 16),
height=(height // 16 * 16),
cfg_weight=guidance,
parallel_size=parallel_size,
temperature=t2i_temperature
)
images = unpack(
patches,
width=(width // 16 * 16),
height=(height // 16 * 16),
parallel_size=parallel_size
)
pil_images = [
Image.fromarray(images[i]).resize((768, 768), Image.LANCZOS)
for i in range(parallel_size)
]
return pil_images
# 11. Gradio Interface
with gr.Blocks() as demo:
gr.Markdown(value="# Multimodal Understanding")
with gr.Row():
image_input = gr.Image()
with gr.Column():
question_input = gr.Textbox(label="Question")
und_seed_input = gr.Number(label="Seed", precision=0, value=42)
top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="top_p")
temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="temperature")
understanding_button = gr.Button("Chat")
understanding_output = gr.Textbox(label="Response")
examples_inpainting = gr.Examples(
label="Multimodal Understanding examples",
examples=[
[
"explain this meme",
"images/doge.png",
],
[
"Convert the formula into latex code.",
"images/equation.png",
],
],
inputs=[question_input, image_input],
)
gr.Markdown(value="# Text-to-Image Generation")
with gr.Row():
cfg_weight_input = gr.Slider(minimum=1, maximum=10, value=5, step=0.5, label="CFG Weight")
t2i_temperature = gr.Slider(minimum=0, maximum=1, value=1.0, step=0.05, label="temperature")
prompt_input = gr.Textbox(label="Prompt. (More detail => better images!)")
seed_input = gr.Number(label="Seed (Optional)", precision=0, value=12345)
generation_button = gr.Button("Generate Images")
image_output = gr.Gallery(label="Generated Images", columns=2, rows=2, height=300)
examples_t2i = gr.Examples(
label="Text to image generation examples.",
examples=[
"Master shifu racoon wearing drip attire as a street gangster.",
"The face of a beautiful girl",
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
"A glass of red wine on a reflective surface.",
"A cute and adorable baby fox with big brown eyes...",
"The image features an intricately designed eye set against a circular backdrop...",
],
inputs=prompt_input,
)
understanding_button.click(
multimodal_understanding,
inputs=[image_input, question_input, und_seed_input, top_p, temperature],
outputs=understanding_output
)
generation_button.click(
fn=generate_image,
inputs=[prompt_input, seed_input, cfg_weight_input, t2i_temperature],
outputs=image_output
)
demo.launch(share=True)

vq_model.py

@@ -415,13 +415,8 @@ class Upsample(nn.Module):
)
def forward(self, x):
if x.dtype != torch.float32:
x = F.interpolate(x.to(torch.float), scale_factor=2.0, mode="nearest").to(
torch.bfloat16
)
else:
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
# Remove forced casting. Just interpolate in x's current dtype.
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
if self.with_conv:
x = self.conv(x)
return x