mirror of
https://github.com/deepseek-ai/Janus.git
synced 2025-04-19 18:18:57 -04:00
fix: demo/fastapi_app.py with mps device.
This commit is contained in:
parent
877c778c0e
commit
1b69d7f99b
@ -12,9 +12,9 @@ app = FastAPI()
|
|||||||
# Device and dtype configuration
|
# Device and dtype configuration
|
||||||
def get_device_and_dtype():
|
def get_device_and_dtype():
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
return 'cuda', torch.bfloat16
|
return 'cuda', torch.float32
|
||||||
elif torch.backends.mps.is_available():
|
elif torch.backends.mps.is_available():
|
||||||
return 'mps', torch.float16
|
return 'mps', torch.float32
|
||||||
return 'cpu', torch.float32
|
return 'cpu', torch.float32
|
||||||
|
|
||||||
device, dtype = get_device_and_dtype()
|
device, dtype = get_device_and_dtype()
|
||||||
@ -35,11 +35,17 @@ tokenizer = vl_chat_processor.tokenizer
|
|||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def multimodal_understanding(image_data, question, seed, top_p, temperature):
|
def multimodal_understanding(image_data, question, seed, top_p, temperature):
|
||||||
torch.cuda.empty_cache() if device == 'cuda' else None
|
# Clear CUDA cache if using CUDA
|
||||||
|
if device == 'cuda':
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
# set seed
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
if device == 'cuda':
|
if device == 'cuda':
|
||||||
torch.cuda.manual_seed(seed)
|
torch.cuda.manual_seed(seed)
|
||||||
|
elif device == 'mps':
|
||||||
|
torch.mps.manual_seed(seed)
|
||||||
|
|
||||||
conversation = [
|
conversation = [
|
||||||
{
|
{
|
||||||
@ -94,36 +100,39 @@ def generate(input_ids,
|
|||||||
cfg_weight: float = 5,
|
cfg_weight: float = 5,
|
||||||
image_token_num_per_image: int = 576,
|
image_token_num_per_image: int = 576,
|
||||||
patch_size: int = 16):
|
patch_size: int = 16):
|
||||||
torch.cuda.empty_cache() if device == 'cuda' else None
|
try:
|
||||||
tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(device)
|
torch.cuda.empty_cache() if device == 'cuda' else None
|
||||||
for i in range(parallel_size * 2):
|
tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(device)
|
||||||
tokens[i, :] = input_ids
|
for i in range(parallel_size * 2):
|
||||||
if i % 2 != 0:
|
tokens[i, :] = input_ids
|
||||||
tokens[i, 1:-1] = vl_chat_processor.pad_id
|
if i % 2 != 0:
|
||||||
inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
|
tokens[i, 1:-1] = vl_chat_processor.pad_id
|
||||||
generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(device)
|
inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
|
||||||
|
generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(device)
|
||||||
|
|
||||||
pkv = None
|
pkv = None
|
||||||
for i in range(image_token_num_per_image):
|
for i in range(image_token_num_per_image):
|
||||||
outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True, past_key_values=pkv)
|
outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True, past_key_values=pkv)
|
||||||
pkv = outputs.past_key_values
|
pkv = outputs.past_key_values
|
||||||
hidden_states = outputs.last_hidden_state
|
hidden_states = outputs.last_hidden_state
|
||||||
logits = vl_gpt.gen_head(hidden_states[:, -1, :])
|
logits = vl_gpt.gen_head(hidden_states[:, -1, :])
|
||||||
logit_cond = logits[0::2, :]
|
logit_cond = logits[0::2, :]
|
||||||
logit_uncond = logits[1::2, :]
|
logit_uncond = logits[1::2, :]
|
||||||
logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
|
logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
|
||||||
probs = torch.softmax(logits / temperature, dim=-1)
|
probs = torch.softmax(logits / temperature, dim=-1)
|
||||||
next_token = torch.multinomial(probs, num_samples=1)
|
next_token = torch.multinomial(probs, num_samples=1)
|
||||||
generated_tokens[:, i] = next_token.squeeze(dim=-1)
|
generated_tokens[:, i] = next_token.squeeze(dim=-1)
|
||||||
next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
|
next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
|
||||||
img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
|
img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
|
||||||
inputs_embeds = img_embeds.unsqueeze(dim=1)
|
inputs_embeds = img_embeds.unsqueeze(dim=1)
|
||||||
patches = vl_gpt.gen_vision_model.decode_code(
|
patches = vl_gpt.gen_vision_model.decode_code(
|
||||||
generated_tokens.to(dtype=torch.int),
|
generated_tokens.to(dtype=torch.int),
|
||||||
shape=[parallel_size, 8, width // patch_size, height // patch_size]
|
shape=[parallel_size, 8, width // patch_size, height // patch_size]
|
||||||
)
|
)
|
||||||
|
|
||||||
return generated_tokens.to(dtype=torch.int), patches
|
return generated_tokens.to(dtype=torch.int), patches
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"Error in generate function: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
def unpack(dec, width, height, parallel_size=5):
|
def unpack(dec, width, height, parallel_size=5):
|
||||||
@ -138,29 +147,37 @@ def unpack(dec, width, height, parallel_size=5):
|
|||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def generate_image(prompt, seed, guidance):
|
def generate_image(prompt, seed, guidance):
|
||||||
torch.cuda.empty_cache() if device == 'cuda' else None
|
try:
|
||||||
seed = seed if seed is not None else 12345
|
# Clear CUDA cache if using CUDA
|
||||||
torch.manual_seed(seed)
|
if device == 'cuda':
|
||||||
if device == 'cuda':
|
torch.cuda.empty_cache()
|
||||||
torch.cuda.manual_seed(seed)
|
# Set the seed for reproducible results
|
||||||
np.random.seed(seed)
|
if seed is not None:
|
||||||
width = 384
|
torch.manual_seed(seed)
|
||||||
height = 384
|
if device == 'cuda':
|
||||||
parallel_size = 5
|
torch.cuda.manual_seed(seed)
|
||||||
|
elif device == 'mps':
|
||||||
|
torch.mps.manual_seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
width = 384
|
||||||
|
height = 384
|
||||||
|
parallel_size = 5
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
messages = [{'role': 'User', 'content': prompt}, {'role': 'Assistant', 'content': ''}]
|
messages = [{'role': 'User', 'content': prompt}, {'role': 'Assistant', 'content': ''}]
|
||||||
text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
|
text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
|
||||||
conversations=messages,
|
conversations=messages,
|
||||||
sft_format=vl_chat_processor.sft_format,
|
sft_format=vl_chat_processor.sft_format,
|
||||||
system_prompt=''
|
system_prompt=''
|
||||||
)
|
)
|
||||||
text = text + vl_chat_processor.image_start_tag
|
text = text + vl_chat_processor.image_start_tag
|
||||||
input_ids = torch.LongTensor(tokenizer.encode(text))
|
input_ids = torch.LongTensor(tokenizer.encode(text))
|
||||||
_, patches = generate(input_ids, width // 16 * 16, height // 16 * 16, cfg_weight=guidance, parallel_size=parallel_size)
|
_, patches = generate(input_ids, width // 16 * 16, height // 16 * 16, cfg_weight=guidance, parallel_size=parallel_size)
|
||||||
images = unpack(patches, width // 16 * 16, height // 16 * 16)
|
images = unpack(patches, width // 16 * 16, height // 16 * 16)
|
||||||
|
|
||||||
return [Image.fromarray(images[i]).resize((1024, 1024), Image.LANCZOS) for i in range(parallel_size)]
|
return [Image.fromarray(images[i]).resize((1024, 1024), Image.LANCZOS) for i in range(parallel_size)]
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"Error in generate_image function: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
@app.post("/generate_images/")
|
@app.post("/generate_images/")
|
||||||
|
Loading…
Reference in New Issue
Block a user