# -*- coding: utf-8 -*-

import argparse
import os
import sys
from PIL import Image
import readline
from threading import Thread
import torch
from transformers import TextIteratorStreamer

from deepseek_vl.utils.io import load_pretrained_model


def load_image(image_file):
    """Open ``image_file`` and return its pixels as an RGB ``PIL.Image.Image``.

    The file is opened in a ``with`` block so the underlying handle is closed
    as soon as the converted copy exists. Otherwise it would linger until the
    lazily-opened source image is garbage-collected.

    Args:
        image_file: path to an image file readable by Pillow.

    Returns:
        A new RGB image that no longer depends on the open file.

    Raises:
        FileNotFoundError: if the path does not exist.
        OSError: if the file cannot be opened or decoded as an image
            (``PIL.UnidentifiedImageError`` is an ``OSError`` subclass).
    """
    with Image.open(image_file) as image:
        # convert() forces the pixel data to load and returns a new image,
        # so the result stays valid after the source file is closed.
        return image.convert("RGB")


def get_help_message(image_token):
    """Build the usage banner printed at start-up and on the ``help`` command.

    Args:
        image_token: the placeholder string that marks an image position in a
            user message.

    Returns:
        The multi-line help text. Every mention of the placeholder, including
        the worked example, is rendered from ``image_token``, so the text stays
        correct whatever token the processor defines.
    """
    help_msg = (
        "\t\t DeepSeek-VL-Chat is a chatbot that can answer questions based on the given image. Enjoy it! \n"
        "Usage: \n"
        "    1. type `exit` to quit. \n"
        f"    2. type `{image_token}` to indicate there is an image. You can enter multiple images, "
        f"e.g. '{image_token} is a dog, {image_token} is a cat, and what is it in {image_token}?'. "
        f"When you type `{image_token}`, the chatbot will ask you to input image file path. \n"
        "    3. type `help` to get the help messages. \n"
        "    4. type `new` to start a new conversation. \n"
        f"    Here is an example, you can type: '{image_token}Describe the image.'\n"
    )

    return help_msg


@torch.inference_mode()
def response(args, conv, pil_images, tokenizer, vl_chat_processor, vl_gpt, generation_config):
    """Generate the assistant's reply for ``conv`` and stream it piece by piece.

    Args:
        args: parsed CLI arguments (unused here; kept for call compatibility).
        conv: conversation template whose ``get_prompt()`` gives the prompt text.
        pil_images: images for the image tokens contained in the prompt.
        tokenizer: tokenizer the streamer uses to decode generated ids.
        vl_chat_processor: callable that turns (prompt, images) into model inputs.
        vl_gpt: the vision-language model; ``vl_gpt.language_model.generate``
            performs the decoding.
        generation_config: base keyword arguments for ``generate``. This dict is
            NOT modified; per-call tensors go into a private copy.

    Yields:
        Decoded text pieces as the background generation thread produces them.
    """
    prompt = conv.get_prompt()
    prepare_inputs = vl_chat_processor(
        prompt=prompt,
        images=pil_images,
        force_batchify=True,
    ).to(vl_gpt.device)

    # run image encoder to get the image embeddings
    inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)

    streamer = TextIteratorStreamer(
        tokenizer=tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    # Build per-call kwargs instead of writing into the caller's dict. Storing
    # inputs_embeds / attention_mask there would keep this turn's tensors
    # referenced until the next call, so freeing cached memory between
    # conversations could not release them.
    generate_kwargs = dict(
        generation_config,
        inputs_embeds=inputs_embeds,
        attention_mask=prepare_inputs.attention_mask,
        streamer=streamer,
    )

    # NOTE(review): grad mode is thread-local, so inference_mode above does not
    # cover this worker thread; presumably generate() disables grad itself — confirm.
    thread = Thread(target=vl_gpt.language_model.generate, kwargs=generate_kwargs)
    thread.start()

    yield from streamer

    # The streamer is exhausted once generate() signals the end of generation,
    # so this join returns promptly and reaps the worker before the next turn.
    thread.join()


def get_user_input(hint: str):
    """Read one non-empty line from the terminal.

    Blank lines are ignored and the prompt is shown again. Ctrl+C at the prompt
    only starts a fresh line instead of killing the program. End-of-input
    (Ctrl+D or a closed stdin) is reported as the ``"exit"`` command.

    Args:
        hint: prompt text displayed before the cursor.

    Returns:
        The text typed by the user (without the trailing newline), or
        ``"exit"`` when input has ended.
    """
    while True:
        try:
            line = input(f"{hint}")
        except EOFError:
            return "exit"
        except KeyboardInterrupt:
            print()
            continue
        if line:
            return line


def _collect_images(num_images):
    """Prompt for ``num_images`` image paths and return the loaded RGB images.

    Re-prompts until each path names a readable image. Typing ``exit`` quits
    the program. So does closing stdin, because ``get_user_input`` reports EOF
    as ``"exit"``; the previous raw ``input()`` loop spun forever on EOF.
    """
    images = []
    while len(images) < num_images:
        image_file = get_user_input(f"({len(images) + 1}/{num_images}) Input the image file path: ")

        if image_file == "exit":
            print("Chat program exited.")
            sys.exit(0)

        if not os.path.exists(image_file):
            print(f"File error, `{image_file}` does not exist. Please input the correct file path.")
            continue

        try:
            images.append(load_image(image_file))
        except OSError as err:
            # The path exists but is not a decodable image (or is a directory).
            # PIL.UnidentifiedImageError subclasses OSError. Ask again instead
            # of crashing the whole session.
            print(f"File error, `{image_file}` cannot be read as an image ({err}). Please input the correct file path.")

    return images


def chat(args, tokenizer, vl_chat_processor, vl_gpt, generation_config):
    """Run the interactive chat loop until the user exits.

    Commands:
        ``exit`` (or EOF) terminates the process via ``sys.exit(0)``;
        ``help`` reprints the usage text;
        ``new`` clears the screen and discards the conversation and its images.

    Any other input becomes a user turn. For every ``image_token`` it contains,
    the user is first asked for an image path. The model's answer is then
    streamed to stdout and stored as the assistant turn.

    This function never returns normally.
    """
    image_token = vl_chat_processor.image_token
    help_msg = get_help_message(image_token)
    print(help_msg)

    # Images supplied so far in this conversation. They are all passed back to
    # the processor every turn; presumably conv.get_prompt() renders the whole
    # history, including earlier image tokens — verify against the template.
    pil_images = []
    conv = vl_chat_processor.new_chat_template()
    roles = conv.roles

    while True:
        user_input = get_user_input(f"{roles[0]} [{image_token} indicates an image]: ")

        if user_input == "exit":
            print("Chat program exited.")
            sys.exit(0)

        if user_input == "help":
            print(help_msg)
            continue

        if user_input == "new":
            os.system("cls" if os.name == "nt" else "clear")
            pil_images = []
            conv = vl_chat_processor.new_chat_template()
            torch.cuda.empty_cache()
            print("New conversation started.")
            continue

        conv.append_message(conv.roles[0], user_input)
        # Empty assistant turn; filled in below once the answer has streamed.
        conv.append_message(conv.roles[1], None)

        # one image path per image token in this message
        pil_images.extend(_collect_images(user_input.count(image_token)))

        # get the answer by the model's prediction
        answer_iter = response(args, conv, pil_images, tokenizer, vl_chat_processor, vl_gpt, generation_config)
        sys.stdout.write(f"{conv.roles[1]}: ")
        pieces = []
        for piece in answer_iter:
            pieces.append(piece)
            sys.stdout.write(piece)
            sys.stdout.flush()

        sys.stdout.write("\n")
        sys.stdout.flush()
        conv.messages[-1][-1] = "".join(pieces)


def main(args):
    """Load the model named by ``args.model_path`` and start the chat loop.

    With ``args.temperature > 0`` decoding samples using ``top_p``,
    ``temperature`` and ``repetition_penalty``; otherwise it is greedy
    (``do_sample=False``). ``args.max_gen_len`` caps new tokens per answer.
    """
    tokenizer, vl_chat_processor, vl_gpt = load_pretrained_model(args.model_path)
    proc_tokenizer = vl_chat_processor.tokenizer

    base_options = {
        "pad_token_id": proc_tokenizer.eos_token_id,
        "bos_token_id": proc_tokenizer.bos_token_id,
        "eos_token_id": proc_tokenizer.eos_token_id,
        "max_new_tokens": args.max_gen_len,
        "use_cache": True,
    }

    if args.temperature <= 0:
        decoding_options = {"do_sample": False}
    else:
        decoding_options = {
            "do_sample": True,
            "top_p": args.top_p,
            "temperature": args.temperature,
            "repetition_penalty": args.repetition_penalty,
        }

    generation_config = {**base_options, **decoding_options}
    chat(args, tokenizer, vl_chat_processor, vl_gpt, generation_config)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Interactive command-line chat with a DeepSeek-VL model."
    )
    parser.add_argument("--model_path", type=str, default="deepseek-ai/deepseek-vl-7b-chat",
                        help="the huggingface model name or the local path of the downloaded huggingface model.")
    parser.add_argument("--temperature", type=float, default=0.2,
                        help="sampling temperature; a value <= 0 switches to greedy decoding.")
    parser.add_argument("--top_p", type=float, default=0.95,
                        help="nucleus-sampling probability mass (only used when temperature > 0).")
    parser.add_argument("--repetition_penalty", type=float, default=1.1,
                        help="repetition penalty (only used when temperature > 0).")
    parser.add_argument("--max_gen_len", type=int, default=512,
                        help="maximum number of new tokens generated per answer.")
    args = parser.parse_args()
    main(args)

"""
Example usage:

CUDA_VISIBLE_DEVICES=4 python cli.py --model_path "/path/to/local/deepseek-vl-7b-chat"
CUDA_VISIBLE_DEVICES=4 python cli.py --model_path "deepseek-ai/deepseek-vl-7b-chat"

"""