diff --git a/demo/app_januspro.py b/demo/app_januspro.py index 702e58e..82cf6b0 100644 --- a/demo/app_januspro.py +++ b/demo/app_januspro.py @@ -22,7 +22,7 @@ vl_gpt = AutoModelForCausalLM.from_pretrained(model_path, if torch.cuda.is_available(): vl_gpt = vl_gpt.to(torch.bfloat16).cuda() else: - vl_gpt = vl_gpt.to(torch.float16) + vl_gpt = vl_gpt.to('cpu') vl_chat_processor = VLChatProcessor.from_pretrained(model_path) tokenizer = vl_chat_processor.tokenizer @@ -156,7 +156,7 @@ def generate_image(prompt, system_prompt='') text = text + vl_chat_processor.image_start_tag - input_ids = torch.LongTensor(tokenizer.encode(text)) + input_ids = torch.LongTensor(tokenizer.encode(text)).to(cuda_device) output, patches = generate(input_ids, width // 16 * 16, height // 16 * 16,