"""
From https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
"""

import dataclasses
from enum import IntEnum, auto
from typing import Any, Dict, List


class SeparatorStyle(IntEnum):
    """Separator styles."""

    DeepSeek = auto()
    DeepSeekV2 = auto()
    PLAIN = auto()
    ALIGNMENT = auto()


@dataclasses.dataclass
class Conversation:
    """A class that manages prompt templates and keeps all conversation history."""

    # The name of this template
    name: str
    # The template of the system prompt
    system_template: str = "{system_message}"
    # The system message
    system_message: str = ""
    # The names of two roles
    roles: List[str] = (("USER", "ASSISTANT"),)
    # All messages. Each item is (role, message).
    messages: List[List[str]] = ()
    # The number of few shot examples
    offset: int = 0
    # The separator style and configurations
    sep_style: SeparatorStyle = SeparatorStyle.DeepSeek
    sep: str = "\n"
    sep2: str = None
    # Stop criteria (the default one is EOS token)
    stop_str: str = None
    # Stops generation if meeting any token in this list
    stop_token_ids: List[int] = None

    def get_prompt(self) -> str:
        """Get the prompt for generation."""
        system_prompt = self.system_template.format(system_message=self.system_message)
        if self.sep_style == SeparatorStyle.DeepSeek:
            seps = [self.sep, self.sep2]
            if system_prompt == "" or system_prompt is None:
                ret = ""
            else:
                ret = system_prompt + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.DeepSeekV2:
            seps = [self.sep, self.sep2]
            if system_prompt == "" or system_prompt is None:
                ret = ""
            else:
                ret = system_prompt + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    if role == "User":
                        ret += "<|sft▁begin|>\n" + message + self.sep #<|sft▁begin|>User Input<|sft▁end|>\nResponse<|end▁of▁sentence|>
                    else:
                        ret += message + self.sep2
                else:
                    ret = ret
            return ret

        elif self.sep_style == SeparatorStyle.PLAIN:
            seps = [self.sep, self.sep2]
            ret = ""
            for i, (role, message) in enumerate(self.messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    if i % 2 == 0:
                        ret += message + seps[i % 2]
                    else:
                        ret += message + seps[i % 2]
                else:
                    ret += ""
            return ret
        elif self.sep_style == SeparatorStyle.ALIGNMENT:
            seps = [self.sep, self.sep2]
            ret = ""
            for i, (role, message) in enumerate(self.messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    if i % 2 == 0:
                        ret += '<image>\n' + seps[i % 2]
                    else:
                        ret += message + seps[i % 2]
                else:
                    ret += ""
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def set_system_message(self, system_message: str):
        """Set the system message."""
        self.system_message = system_message

    def append_message(self, role: str, message: str):
        """Append a new message."""
        self.messages.append([role, message])

    def update_last_message(self, message: str):
        """Update the last output.

        The last message is typically set to be None when constructing the prompt,
        so we need to update it in-place after getting the response from a model.
        """
        self.messages[-1][1] = message

    def reset_message(self):
        """Reset a new message."""
        self.messages = []

    def to_gradio_chatbot(self):
        """Convert the conversation to gradio chatbot format."""
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def to_openai_api_messages(self):
        """Convert the conversation to OpenAI chat completion format."""
        system_prompt = self.system_template.format(system_message=self.system_message)
        ret = [{"role": "system", "content": system_prompt}]

        for i, (_, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                ret.append({"role": "user", "content": msg})
            else:
                if msg is not None:
                    ret.append({"role": "assistant", "content": msg})
        return ret

    def copy(self):
        return Conversation(
            name=self.name,
            system_template=self.system_template,
            system_message=self.system_message,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            stop_str=self.stop_str,
            stop_token_ids=self.stop_token_ids,
        )

    def dict(self):
        return {
            "template_name": self.name,
            "system_message": self.system_message,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
        }


# A global registry for all conversation templates
conv_templates: Dict[str, Conversation] = {}


def register_conv_template(template: Conversation, override: bool = False):
    """Register a new conversation template."""
    if not override:
        assert template.name not in conv_templates, f"{template.name} has been registered."

    conv_templates[template.name] = template


def get_conv_template(name: str) -> Conversation:
    """Get a conversation template."""
    return conv_templates[name].copy()


# register_conv_template(
#     Conversation(
#         name="deepseek",
#         system_template="{system_message}",
#         # system_message="You are a helpful assistant. Please answer truthfully and write out your "
#         # "thinking step by step to be sure you get the right answer.",
#         system_message="",
#         roles=("User", "Assistant"),
#         messages=(),
#         offset=0,
#         sep_style=SeparatorStyle.DeepSeek,
#         sep="\n\n",
#         sep2="<|end▁of▁sentence|>",
#         stop_token_ids=[100001],
#         stop_str=["User:", "<|end▁of▁sentence|>"]
#     )
# )
register_conv_template(
    Conversation(
        name="deepseek",
        system_template="{system_message}",
        # system_message="You are a helpful assistant. Please answer truthfully and write out your "
        # "thinking step by step to be sure you get the right answer.",
        system_message="",
        roles=("<|User|>", "<|Assistant|>"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.DeepSeek,
        sep="\n\n",
        sep2="<|end▁of▁sentence|>",
        stop_token_ids=[100001],
        stop_str=["User:", "<|end▁of▁sentence|>"]
    )
)
# register_conv_template(
#     Conversation(
#         name="deepseekv2",
#         system_template="{system_message}",
#         system_message="",
#         roles=("User", "Assistant"),
#         messages=(),
#         offset=0,
#         sep_style=SeparatorStyle.DeepSeekV2,
#         sep="\n<|sft▁end|>",
#         sep2="<|end▁of▁sentence|>",
#         stop_token_ids=[100001],
#         stop_str=["User:", "<|end▁of▁sentence|>"]
#     )
# )
register_conv_template(
    Conversation(
        name="deepseekv2",
        system_template="{system_message}",
        system_message="",
        roles=("|<User>|", "|<Assistant>|"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.DeepSeekV2,
        sep="\n<|sft▁end|>",
        sep2="<|end▁of▁sentence|>",
        stop_token_ids=[100001],
        stop_str=["User:", "<|end▁of▁sentence|>"]
    )
)


register_conv_template(
    Conversation(
        name="plain",
        system_template="",
        system_message="",
        roles=("", ""),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.PLAIN,
        sep="",
        sep2="",
        stop_token_ids=[100001],
        stop_str=['</s>'],
    )
)


register_conv_template(
    Conversation(
        name="alignment",
        system_template="",
        system_message="",
        roles=("", ""),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.ALIGNMENT,
        sep="",
        sep2="",
        stop_token_ids=[100001],
        stop_str=['</s>'],
    )
)


if __name__ == "__main__":
    print("deepseek template:")
    conv = get_conv_template("deepseek")
    conv.append_message(conv.roles[0], "Hello!")
    conv.append_message(conv.roles[1], "Hi! This is Tony.")
    conv.append_message(conv.roles[0], "Who are you?")
    conv.append_message(conv.roles[1], "I am a helpful assistant.")
    conv.append_message(conv.roles[0], "How are you?")
    conv.append_message(conv.roles[1], None)
    print(conv.get_prompt())

    print("deepseekv2 template:")
    conv = get_conv_template("deepseekv2")
    conv.append_message(conv.roles[0], "Hello!")
    conv.append_message(conv.roles[1], "Hi! This is Tony.")
    conv.append_message(conv.roles[0], "Who are you?")
    conv.append_message(conv.roles[1], "I am a helpful assistant.")
    conv.append_message(conv.roles[0], "How are you?")
    conv.append_message(conv.roles[1], None)
    print(conv.get_prompt())