This commit is contained in:
Florian Beier 2025-02-01 16:28:06 +02:00 committed by GitHub
commit d84dfc3d9c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -22,7 +22,7 @@ vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
if torch.cuda.is_available():
vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
else:
vl_gpt = vl_gpt.to(torch.float16)
vl_gpt = vl_gpt.to('cpu')
vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
@ -156,7 +156,7 @@ def generate_image(prompt,
system_prompt='')
text = text + vl_chat_processor.image_start_tag
input_ids = torch.LongTensor(tokenizer.encode(text))
input_ids = torch.LongTensor(tokenizer.encode(text)).to(cuda_device)
output, patches = generate(input_ids,
width // 16 * 16,
height // 16 * 16,