mirror of
https://github.com/deepseek-ai/Janus.git
synced 2025-02-23 06:08:59 -05:00
Fix type mismatch in app_januspro.py for CPU mode
The input type is still BFloat16 when cuda_device = 'cpu', fix it
This commit is contained in:
parent
a74a59f8a9
commit
dec287ca1d
@ -22,7 +22,7 @@ vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
|
|||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
|
vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
|
||||||
else:
|
else:
|
||||||
vl_gpt = vl_gpt.to(torch.float16)
|
vl_gpt = vl_gpt.to('cpu')
|
||||||
|
|
||||||
vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
|
vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
|
||||||
tokenizer = vl_chat_processor.tokenizer
|
tokenizer = vl_chat_processor.tokenizer
|
||||||
@ -156,7 +156,7 @@ def generate_image(prompt,
|
|||||||
system_prompt='')
|
system_prompt='')
|
||||||
text = text + vl_chat_processor.image_start_tag
|
text = text + vl_chat_processor.image_start_tag
|
||||||
|
|
||||||
input_ids = torch.LongTensor(tokenizer.encode(text))
|
input_ids = torch.LongTensor(tokenizer.encode(text)).to(cuda_device)
|
||||||
output, patches = generate(input_ids,
|
output, patches = generate(input_ids,
|
||||||
width // 16 * 16,
|
width // 16 * 16,
|
||||||
height // 16 * 16,
|
height // 16 * 16,
|
||||||
|
Loading…
Reference in New Issue
Block a user