Mirror of https://github.com/deepseek-ai/Janus.git, synced 2025-02-22 13:48:57 -05:00
Fix type mismatch in app_januspro.py for CPU mode
The input type is still BFloat16 when cuda_device = 'cpu'; fix the resulting type mismatch.
parent a74a59f8a9
commit dec287ca1d
app_januspro.py

@@ -22,7 +22,7 @@ vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
 if torch.cuda.is_available():
     vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
 else:
-    vl_gpt = vl_gpt.to(torch.float16)
+    vl_gpt = vl_gpt.to('cpu')
 
 vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
 tokenizer = vl_chat_processor.tokenizer
@@ -156,7 +156,7 @@ def generate_image(prompt,
                                                                            system_prompt='')
         text = text + vl_chat_processor.image_start_tag
 
-        input_ids = torch.LongTensor(tokenizer.encode(text))
+        input_ids = torch.LongTensor(tokenizer.encode(text)).to(cuda_device)
         output, patches = generate(input_ids,
                                    width // 16 * 16,
                                    height // 16 * 16,
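Taken together, the two hunks make the CPU path consistent: the model is no longer cast to float16 when CUDA is unavailable, and the tokenized prompt is placed on the same device as the weights via cuda_device. The sketch below is a minimal reconstruction of the resulting setup, not the full demo: the checkpoint path, the trust_remote_code flag, and the example prompt are assumptions, and the extra from_pretrained kwargs used by app_januspro.py are omitted.

import torch
from transformers import AutoModelForCausalLM
from janus.models import VLChatProcessor

# Checkpoint path is an assumption; the demo may point elsewhere.
model_path = "deepseek-ai/Janus-Pro-7B"

# Extra kwargs from the demo's from_pretrained call are omitted here.
vl_gpt = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)

if torch.cuda.is_available():
    # GPU path (unchanged by this commit): cast to bfloat16 and move to the GPU.
    vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
else:
    # CPU path after the fix: move to CPU without forcing float16, so the
    # weights keep the dtype the BFloat16 inputs are created in.
    vl_gpt = vl_gpt.to('cpu')

vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Second hunk: the encoded prompt is explicitly placed on cuda_device so it
# lives wherever the model does ('cuda' or 'cpu').
text = "An example prompt"  # illustrative only
input_ids = torch.LongTensor(tokenizer.encode(text)).to(cuda_device)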