Mirror of https://github.com/deepseek-ai/Janus.git (synced 2025-02-22 13:48:57 -05:00)
feat: support mps device
This commit is contained in: parent a74a59f8a9, commit 04678b6d53
demo/app.py: 44 changed lines
@@ -6,6 +6,15 @@ from PIL import Image
 import numpy as np
 
+# Device and dtype configuration
+def get_device_and_dtype():
+    if torch.cuda.is_available():
+        return 'cuda', torch.bfloat16  # use bfloat16 on CUDA devices
+    elif torch.backends.mps.is_available():
+        return 'mps', torch.float32
+    return 'cpu', torch.float32
+
+device, dtype = get_device_and_dtype()
 
 # Load model and processor
 model_path = "deepseek-ai/Janus-1.3B"
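Note on the hunk above: CUDA keeps bfloat16 while MPS and CPU fall back to float32, since bfloat16 coverage on the MPS backend is still incomplete in common PyTorch builds. A minimal standalone sketch of the same selection logic (plain PyTorch, nothing Janus-specific assumed):

import torch

def get_device_and_dtype():
    # Prefer CUDA with bfloat16, then Apple MPS, then CPU; non-CUDA backends use float32.
    if torch.cuda.is_available():
        return 'cuda', torch.bfloat16
    if torch.backends.mps.is_available():
        return 'mps', torch.float32
    return 'cpu', torch.float32

device, dtype = get_device_and_dtype()
print(device, dtype)  # e.g. "mps torch.float32" on Apple Silicon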
@@ -15,22 +24,25 @@ language_config._attn_implementation = 'eager'
 vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
                                               language_config=language_config,
                                               trust_remote_code=True)
-vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
+vl_gpt = vl_gpt.to(dtype).to(device)
 
 vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
 tokenizer = vl_chat_processor.tokenizer
-cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
-# Multimodal Understanding function
 @torch.inference_mode()
+# Multimodal Understanding function
 def multimodal_understanding(image, question, seed, top_p, temperature):
-    # Clear CUDA cache before generating
-    torch.cuda.empty_cache()
+    # Clear CUDA cache if using CUDA
+    if device == 'cuda':
+        torch.cuda.empty_cache()
 
     # set seed
     torch.manual_seed(seed)
     np.random.seed(seed)
-    torch.cuda.manual_seed(seed)
+    if device == 'cuda':
+        torch.cuda.manual_seed(seed)
+    elif device == 'mps':
+        torch.mps.manual_seed(seed)
 
     conversation = [
         {
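The seed handling above is repeated in several functions throughout this commit; a hypothetical helper that captures the same behaviour (the name set_seed and the hasattr guard are assumptions, not part of the commit):

import numpy as np
import torch

def set_seed(seed: int, device: str) -> None:
    # Seed the CPU torch RNG, NumPy, and the accelerator RNG the demo relies on.
    torch.manual_seed(seed)
    np.random.seed(seed)
    if device == 'cuda':
        torch.cuda.manual_seed(seed)
    elif device == 'mps' and hasattr(torch, 'mps'):
        torch.mps.manual_seed(seed)  # available in recent PyTorch releases

set_seed(42, 'cpu')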
@@ -44,8 +56,7 @@ def multimodal_understanding(image, question, seed, top_p, temperature):
     pil_images = [Image.fromarray(image)]
     prepare_inputs = vl_chat_processor(
         conversations=conversation, images=pil_images, force_batchify=True
-    ).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)
-
+    ).to(device, dtype=dtype)
 
     inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
@@ -74,16 +85,17 @@ def generate(input_ids,
              cfg_weight: float = 5,
              image_token_num_per_image: int = 576,
              patch_size: int = 16):
-    # Clear CUDA cache before generating
-    torch.cuda.empty_cache()
+    # Clear CUDA cache if using CUDA
+    if device == 'cuda':
+        torch.cuda.empty_cache()
 
-    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
+    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(device)
     for i in range(parallel_size * 2):
         tokens[i, :] = input_ids
         if i % 2 != 0:
             tokens[i, 1:-1] = vl_chat_processor.pad_id
     inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
-    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(cuda_device)
+    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(device)
 
     pkv = None
     for i in range(image_token_num_per_image):
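In generate() above, the tokens tensor holds two rows per requested image for classifier-free guidance: even rows keep the full prompt (conditional branch), odd rows have everything between the first and last token replaced by pad_id (unconditional branch). A toy sketch of that layout (the ids and pad value are placeholders, not real Janus vocabulary):

import torch

parallel_size = 2                                # toy; the demo uses 5
pad_id = 0                                       # placeholder for vl_chat_processor.pad_id
input_ids = torch.tensor([101, 7, 8, 9, 102], dtype=torch.int)

tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int)
for i in range(parallel_size * 2):
    tokens[i, :] = input_ids
    if i % 2 != 0:                               # odd rows become the unconditional branch
        tokens[i, 1:-1] = pad_id
print(tokens)  # rows 0 and 2 keep the prompt; rows 1 and 3 keep only the first and last id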
@@ -123,11 +135,15 @@ def generate_image(prompt,
                    seed=None,
                    guidance=5):
     # Clear CUDA cache and avoid tracking gradients
-    torch.cuda.empty_cache()
+    if device == 'cuda':
+        torch.cuda.empty_cache()
     # Set the seed for reproducible results
     if seed is not None:
         torch.manual_seed(seed)
-        torch.cuda.manual_seed(seed)
+        if device == 'cuda':
+            torch.cuda.manual_seed(seed)
+        elif device == 'mps':
+            torch.mps.manual_seed(seed)
         np.random.seed(seed)
     width = 384
     height = 384
@@ -5,7 +5,16 @@ from PIL import Image
 from diffusers.models import AutoencoderKL
 import numpy as np
 
-cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
+# Device selection logic
+if torch.cuda.is_available():
+    device = 'cuda'
+    dtype = torch.bfloat16
+elif torch.backends.mps.is_available():
+    device = 'mps'
+    dtype = torch.float32  # use float32 on MPS devices
+else:
+    device = 'cpu'
+    dtype = torch.float32
 
 # Load model and processor
 model_path = "deepseek-ai/JanusFlow-1.3B"
@@ -13,23 +22,24 @@ vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
 tokenizer = vl_chat_processor.tokenizer
 
 vl_gpt = MultiModalityCausalLM.from_pretrained(model_path)
-vl_gpt = vl_gpt.to(torch.bfloat16).to(cuda_device).eval()
+vl_gpt = vl_gpt.to(dtype).to(device).eval()
 
 # remember to use bfloat16 dtype, this vae doesn't work with fp16
 vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")
-vae = vae.to(torch.bfloat16).to(cuda_device).eval()
+vae = vae.to(dtype).to(device).eval()
 
-# Multimodal Understanding function
 @torch.inference_mode()
+# Multimodal Understanding function
 def multimodal_understanding(image, question, seed, top_p, temperature):
-    # Clear CUDA cache before generating
-    torch.cuda.empty_cache()
+    if device == 'cuda':
+        torch.cuda.empty_cache()
 
     # set seed
     torch.manual_seed(seed)
     np.random.seed(seed)
-    torch.cuda.manual_seed(seed)
+    if device == 'cuda':
+        torch.cuda.manual_seed(seed)
 
     conversation = [
         {
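The unchanged comment above notes that the SDXL VAE does not work in fp16. With this change the VAE simply follows the global dtype, so it stays in bfloat16 on CUDA and runs in float32 on MPS and CPU, which likewise avoids fp16.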
@@ -43,8 +53,7 @@ def multimodal_understanding(image, question, seed, top_p, temperature):
     pil_images = [Image.fromarray(image)]
     prepare_inputs = vl_chat_processor(
         conversations=conversation, images=pil_images, force_batchify=True
-    ).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)
-
+    ).to(device, dtype=dtype)
 
     inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
@@ -73,7 +82,7 @@ def generate(
     num_inference_steps: int = 30
 ):
     # we generate 5 images at a time, *2 for CFG
-    tokens = torch.stack([input_ids] * 10).cuda()
+    tokens = torch.stack([input_ids] * 10).to(device)
     tokens[5:, 1:] = vl_chat_processor.pad_id
     inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
     print(inputs_embeds.shape)
@@ -83,13 +92,13 @@ def generate(
 
     # generate with rectified flow ode
     # step 1: encode with vision_gen_enc
-    z = torch.randn((5, 4, 48, 48), dtype=torch.bfloat16).cuda()
+    z = torch.randn((5, 4, 48, 48), dtype=dtype).to(device)
 
     dt = 1.0 / num_inference_steps
-    dt = torch.zeros_like(z).cuda().to(torch.bfloat16) + dt
+    dt = torch.zeros_like(z).to(device).to(dtype) + dt
 
     # step 2: run ode
-    attention_mask = torch.ones((10, inputs_embeds.shape[1]+577)).to(vl_gpt.device)
+    attention_mask = torch.ones((10, inputs_embeds.shape[1]+577)).to(device)
     attention_mask[5:, 1:inputs_embeds.shape[1]] = 0
     attention_mask = attention_mask.int()
     for step in range(num_inference_steps):
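For context, the "run ode" loop that follows integrates the rectified-flow velocity field with fixed Euler steps of size dt = 1/num_inference_steps, typically combined with classifier-free guidance on the velocity (hence the *2 batch above). A schematic sketch of that update rule (predict_velocity and cfg_weight are stand-ins, not the JanusFlow API):

import torch

def euler_rectified_flow(z, predict_velocity, num_inference_steps=30, cfg_weight=2.0):
    # Fixed-step Euler integration of dz/dt = v(z, t) from t = 0 to t = 1.
    dt = 1.0 / num_inference_steps
    for step in range(num_inference_steps):
        t = step * dt
        v_cond, v_uncond = predict_velocity(z, t)        # conditional / unconditional velocity
        v = v_uncond + cfg_weight * (v_cond - v_uncond)  # classifier-free guidance on v
        z = z + v * dt                                   # Euler update
    return z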
@@ -108,8 +117,7 @@ def generate(
         if step == 0:
             outputs = vl_gpt.language_model.model(inputs_embeds=llm_emb,
                                                   use_cache=True,
-                                                  attention_mask=attention_mask,
-                                                  past_key_values=None)
+                                                  attention_mask=attention_mask)
             past_key_values = []
             for kv_cache in past_key_values:
                 k, v = kv_cache[0], kv_cache[1]
@@ -118,8 +126,7 @@ def generate(
         else:
             outputs = vl_gpt.language_model.model(inputs_embeds=llm_emb,
                                                   use_cache=True,
-                                                  attention_mask=attention_mask,
-                                                  past_key_values=past_key_values)
+                                                  attention_mask=attention_mask)
         hidden_states = outputs.last_hidden_state
 
         # transform hidden_states back to v
@@ -153,12 +160,14 @@ def generate_image(prompt,
                    seed=None,
                    guidance=5,
                    num_inference_steps=30):
-    # Clear CUDA cache and avoid tracking gradients
-    torch.cuda.empty_cache()
+    # Clear CUDA cache if using CUDA device
+    if device == 'cuda':
+        torch.cuda.empty_cache()
     # Set the seed for reproducible results
     if seed is not None:
         torch.manual_seed(seed)
-        torch.cuda.manual_seed(seed)
+        if device == 'cuda':
+            torch.cuda.manual_seed(seed)
         np.random.seed(seed)
 
     with torch.no_grad():
@@ -8,8 +8,17 @@ from PIL import Image
 import numpy as np
 import os
 import time
 # import spaces  # Import spaces for ZeroGPU compatibility
 
+# Device and dtype configuration
+if torch.cuda.is_available():
+    device = 'cuda'
+    dtype = torch.bfloat16
+elif torch.backends.mps.is_available():
+    device = 'mps'
+    dtype = torch.float32  # use float32 on MPS devices
+else:
+    device = 'cpu'
+    dtype = torch.float32  # use float32 on CPU devices
+
 # Load model and processor
 model_path = "deepseek-ai/Janus-Pro-7B"
@@ -19,26 +28,26 @@ language_config._attn_implementation = 'eager'
 vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
                                               language_config=language_config,
                                               trust_remote_code=True)
-if torch.cuda.is_available():
-    vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
-else:
-    vl_gpt = vl_gpt.to(torch.float16)
+vl_gpt = vl_gpt.to(dtype).to(device)
 
 vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
 tokenizer = vl_chat_processor.tokenizer
-cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 @torch.inference_mode()
 # @spaces.GPU(duration=120)
 # Multimodal Understanding function
 def multimodal_understanding(image, question, seed, top_p, temperature):
-    # Clear CUDA cache before generating
-    torch.cuda.empty_cache()
+    # Clear device cache
+    if device == 'cuda':
+        torch.cuda.empty_cache()
+    elif device == 'mps':
+        torch.mps.empty_cache()
 
     # set seed
     torch.manual_seed(seed)
     np.random.seed(seed)
-    torch.cuda.manual_seed(seed)
+    if device == 'cuda':
+        torch.cuda.manual_seed(seed)
 
     conversation = [
         {
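The cache clearing above now branches per backend; a hypothetical helper that factors out the same pattern (clear_device_cache is not part of the commit; torch.mps.empty_cache requires a reasonably recent PyTorch):

import torch

def clear_device_cache(device: str) -> None:
    # Release memory held by the caching allocator on the active accelerator, if any.
    if device == 'cuda':
        torch.cuda.empty_cache()
    elif device == 'mps' and hasattr(torch, 'mps'):
        torch.mps.empty_cache()

clear_device_cache('cpu')  # no-op on CPU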
@@ -52,8 +61,7 @@ def multimodal_understanding(image, question, seed, top_p, temperature):
     pil_images = [Image.fromarray(image)]
     prepare_inputs = vl_chat_processor(
         conversations=conversation, images=pil_images, force_batchify=True
-    ).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)
-
+    ).to(device, dtype=dtype)
 
     inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
@@ -82,16 +90,19 @@ def generate(input_ids,
              cfg_weight: float = 5,
              image_token_num_per_image: int = 576,
              patch_size: int = 16):
-    # Clear CUDA cache before generating
-    torch.cuda.empty_cache()
+    # Clear device cache
+    if device == 'cuda':
+        torch.cuda.empty_cache()
+    elif device == 'mps':
+        torch.mps.empty_cache()
 
-    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
+    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(device)
     for i in range(parallel_size * 2):
         tokens[i, :] = input_ids
         if i % 2 != 0:
             tokens[i, 1:-1] = vl_chat_processor.pad_id
     inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
-    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(cuda_device)
+    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(device)
 
     pkv = None
     for i in range(image_token_num_per_image):
@@ -133,17 +144,24 @@ def unpack(dec, width, height, parallel_size=5):
 
-@torch.inference_mode()
 # @spaces.GPU(duration=120)  # Specify a duration to avoid timeout
+@torch.inference_mode()
 def generate_image(prompt,
                    seed=None,
                    guidance=5,
                    t2i_temperature=1.0):
-    # Clear CUDA cache and avoid tracking gradients
-    torch.cuda.empty_cache()
+    # Clear device cache
+    if device == 'cuda':
+        torch.cuda.empty_cache()
+    elif device == 'mps':
+        torch.mps.empty_cache()
+
     # Set the seed for reproducible results
     if seed is not None:
         torch.manual_seed(seed)
-        torch.cuda.manual_seed(seed)
         np.random.seed(seed)
+        if device == 'cuda':
+            torch.cuda.manual_seed(seed)
 
     width = 384
     height = 384
     parallel_size = 5
@@ -9,6 +9,16 @@ import io
 
 app = FastAPI()
 
+# Device and dtype configuration
+def get_device_and_dtype():
+    if torch.cuda.is_available():
+        return 'cuda', torch.bfloat16  # use bfloat16 on CUDA devices
+    elif torch.backends.mps.is_available():
+        return 'mps', torch.float32
+    return 'cpu', torch.float32
+
+device, dtype = get_device_and_dtype()
+
 # Load model and processor
 model_path = "deepseek-ai/Janus-1.3B"
 config = AutoConfig.from_pretrained(model_path)
@@ -17,19 +27,25 @@ language_config._attn_implementation = 'eager'
 vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
                                               language_config=language_config,
                                               trust_remote_code=True)
-vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
+vl_gpt = vl_gpt.to(dtype).to(device)
 
 vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
 tokenizer = vl_chat_processor.tokenizer
-cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 
 @torch.inference_mode()
 def multimodal_understanding(image_data, question, seed, top_p, temperature):
-    torch.cuda.empty_cache()
+    # Clear CUDA cache if using CUDA
+    if device == 'cuda':
+        torch.cuda.empty_cache()
 
     # set seed
     torch.manual_seed(seed)
     np.random.seed(seed)
-    torch.cuda.manual_seed(seed)
+    if device == 'cuda':
+        torch.cuda.manual_seed(seed)
+    elif device == 'mps':
+        torch.mps.manual_seed(seed)
 
     conversation = [
         {
@@ -43,7 +59,7 @@ def multimodal_understanding(image_data, question, seed, top_p, temperature):
     pil_images = [Image.open(io.BytesIO(image_data))]
     prepare_inputs = vl_chat_processor(
         conversations=conversation, images=pil_images, force_batchify=True
-    ).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)
+    ).to(device, dtype=dtype)
 
     inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
     outputs = vl_gpt.language_model.generate(
@@ -84,36 +100,39 @@ def generate(input_ids,
              cfg_weight: float = 5,
              image_token_num_per_image: int = 576,
              patch_size: int = 16):
-    torch.cuda.empty_cache()
-    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
-    for i in range(parallel_size * 2):
-        tokens[i, :] = input_ids
-        if i % 2 != 0:
-            tokens[i, 1:-1] = vl_chat_processor.pad_id
-    inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
-    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(cuda_device)
+    try:
+        torch.cuda.empty_cache() if device == 'cuda' else None
+        tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(device)
+        for i in range(parallel_size * 2):
+            tokens[i, :] = input_ids
+            if i % 2 != 0:
+                tokens[i, 1:-1] = vl_chat_processor.pad_id
+        inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
+        generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(device)
 
-    pkv = None
-    for i in range(image_token_num_per_image):
-        outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True, past_key_values=pkv)
-        pkv = outputs.past_key_values
-        hidden_states = outputs.last_hidden_state
-        logits = vl_gpt.gen_head(hidden_states[:, -1, :])
-        logit_cond = logits[0::2, :]
-        logit_uncond = logits[1::2, :]
-        logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
-        probs = torch.softmax(logits / temperature, dim=-1)
-        next_token = torch.multinomial(probs, num_samples=1)
-        generated_tokens[:, i] = next_token.squeeze(dim=-1)
-        next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
-        img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
-        inputs_embeds = img_embeds.unsqueeze(dim=1)
-    patches = vl_gpt.gen_vision_model.decode_code(
-        generated_tokens.to(dtype=torch.int),
-        shape=[parallel_size, 8, width // patch_size, height // patch_size]
-    )
+        pkv = None
+        for i in range(image_token_num_per_image):
+            outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True, past_key_values=pkv)
+            pkv = outputs.past_key_values
+            hidden_states = outputs.last_hidden_state
+            logits = vl_gpt.gen_head(hidden_states[:, -1, :])
+            logit_cond = logits[0::2, :]
+            logit_uncond = logits[1::2, :]
+            logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
+            probs = torch.softmax(logits / temperature, dim=-1)
+            next_token = torch.multinomial(probs, num_samples=1)
+            generated_tokens[:, i] = next_token.squeeze(dim=-1)
+            next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
+            img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
+            inputs_embeds = img_embeds.unsqueeze(dim=1)
+        patches = vl_gpt.gen_vision_model.decode_code(
+            generated_tokens.to(dtype=torch.int),
+            shape=[parallel_size, 8, width // patch_size, height // patch_size]
+        )
 
-    return generated_tokens.to(dtype=torch.int), patches
+        return generated_tokens.to(dtype=torch.int), patches
+    except Exception as e:
+        raise Exception(f"Error in generate function: {str(e)}")
 
 
 def unpack(dec, width, height, parallel_size=5):
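The token loop in the hunk above applies classifier-free guidance on the logits: outputs from the even (conditional) and odd (unconditional) rows are mixed with cfg_weight, then one image token per sample is drawn. A self-contained sketch of just that step with toy logits (sizes and weights are illustrative, not the real vocabulary):

import torch

cfg_weight, temperature, parallel_size = 5.0, 1.0, 5
logits = torch.randn(parallel_size * 2, 16384)     # toy: (2 rows per image) x vocab
logit_cond = logits[0::2, :]                       # even rows: prompt-conditioned
logit_uncond = logits[1::2, :]                     # odd rows: unconditional
guided = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
probs = torch.softmax(guided / temperature, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)  # one image token per sample
print(next_token.shape)                            # torch.Size([5, 1])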
@@ -128,28 +147,37 @@ def unpack(dec, width, height, parallel_size=5):
 
 @torch.inference_mode()
 def generate_image(prompt, seed, guidance):
-    torch.cuda.empty_cache()
-    seed = seed if seed is not None else 12345
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
-    np.random.seed(seed)
-    width = 384
-    height = 384
-    parallel_size = 5
-
-    with torch.no_grad():
-        messages = [{'role': 'User', 'content': prompt}, {'role': 'Assistant', 'content': ''}]
-        text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
-            conversations=messages,
-            sft_format=vl_chat_processor.sft_format,
-            system_prompt=''
-        )
-        text = text + vl_chat_processor.image_start_tag
-        input_ids = torch.LongTensor(tokenizer.encode(text))
-        _, patches = generate(input_ids, width // 16 * 16, height // 16 * 16, cfg_weight=guidance, parallel_size=parallel_size)
-        images = unpack(patches, width // 16 * 16, height // 16 * 16)
-
-        return [Image.fromarray(images[i]).resize((1024, 1024), Image.LANCZOS) for i in range(parallel_size)]
+    try:
+        # Clear CUDA cache if using CUDA
+        if device == 'cuda':
+            torch.cuda.empty_cache()
+        # Set the seed for reproducible results
+        if seed is not None:
+            torch.manual_seed(seed)
+            if device == 'cuda':
+                torch.cuda.manual_seed(seed)
+            elif device == 'mps':
+                torch.mps.manual_seed(seed)
+            np.random.seed(seed)
+        width = 384
+        height = 384
+        parallel_size = 5
+
+        with torch.no_grad():
+            messages = [{'role': 'User', 'content': prompt}, {'role': 'Assistant', 'content': ''}]
+            text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
+                conversations=messages,
+                sft_format=vl_chat_processor.sft_format,
+                system_prompt=''
+            )
+            text = text + vl_chat_processor.image_start_tag
+            input_ids = torch.LongTensor(tokenizer.encode(text))
+            _, patches = generate(input_ids, width // 16 * 16, height // 16 * 16, cfg_weight=guidance, parallel_size=parallel_size)
+            images = unpack(patches, width // 16 * 16, height // 16 * 16)
+
+            return [Image.fromarray(images[i]).resize((1024, 1024), Image.LANCZOS) for i in range(parallel_size)]
+    except Exception as e:
+        raise Exception(f"Error in generate_image function: {str(e)}")
 
 
 @app.post("/generate_images/")
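Across all four scripts, the invariant this commit establishes is that the model, the processor outputs, and newly created tensors share one device and dtype. A small hedged sanity check one could add after model loading (not part of the commit; it reuses the vl_gpt, device, and dtype names from the scripts above):

p = next(vl_gpt.parameters())
assert p.device.type == device, f"model is on {p.device}, expected {device}"
assert p.dtype == dtype, f"model dtype is {p.dtype}, expected {dtype}"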