import os
import numpy as np
import json

class MBPPDataset:
    """Few-shot prompts and evaluation samples built from an MBPP ``mbpp.jsonl`` file.

    Attributes:
        root: directory containing ``mbpp.jsonl``.
        data: raw text lines read from ``mbpp.jsonl``.
        clean_data: one dict per record with keys ``prompt``, ``test``, ``code``,
            ``task_id`` (see ``get_qa_only_data``).
        prompt: cumulative few-shot prompts; ``prompt[k]`` is the concatenation of the
            first ``k + 1`` formatted examples taken from ``FEW_SHOT_IDS``.
        testdata: the records at ``TEST_IDS``, each repeated ``samplenum`` times
            consecutively (the repeats are the same dict object).
    """

    # Positions (0-based, in file order) of the records used as few-shot demonstrations.
    FEW_SHOT_IDS = range(1, 4)
    # Positions of the records used as evaluation problems.
    TEST_IDS = range(10, 510)

    def __init__(self, root, samplenum=1):
        """
        root: root directory of the data files; must contain ``mbpp.jsonl``.
        samplenum: number of copies of each evaluation record placed in ``testdata``.

        Raises IndexError if the file has fewer records than ``TEST_IDS`` requires.
        """
        self.root = root
        # Context manager closes the handle; JSON Lines text is UTF-8, so do not depend
        # on the platform's locale default encoding.
        with open(os.path.join(root, "mbpp.jsonl"), encoding="utf-8") as f:
            self.data = f.readlines()

        self.clean_data = self.get_qa_only_data(self.data)

        # Build cumulative k-shot prompts: each entry extends the previous by one example.
        self.prompt = []
        shots = ""
        for i in self.FEW_SHOT_IDS:
            shots += self._format_example(self.clean_data[i])
            self.prompt.append(shots)

        # Repeat every evaluation record samplenum times, consecutively.
        self.testdata = []
        for i in self.TEST_IDS:
            self.testdata.extend([self.clean_data[i]] * samplenum)

        # Seeds NumPy's *global* RNG; presumably for reproducible sampling by callers —
        # NOTE(review): confirm this side effect is still needed.
        np.random.seed(1234)
        print(f"Read MBPP from {root}, number of samples {len(self.testdata)}")

    @staticmethod
    def _format_example(sample):
        """Render one cleaned record as a "[BEGIN] ... [DONE]" demonstration block."""
        prompt = sample["prompt"]
        tests = "\n".join(sample["test"])
        # Strip carriage returns and expand tabs in the reference solution.
        code = sample["code"].replace("\r", "").replace("\t", "    ")
        return f"You are an expert Python programmer, and here is your task: {prompt} Your code should pass these tests:\n\n{tests}\n[BEGIN]\n{code}\n[DONE]\n"

    def get_qa_only_data(self, data_json):
        """Parse MBPP JSON Lines records into simplified dicts.

        data_json: iterable of JSON strings, one record per line. Each record must
            have ``text``, ``test_list``, ``code`` and ``task_id`` fields.
            Whitespace-only lines (e.g. a trailing blank line) are skipped.

        Returns a list of dicts with keys ``prompt`` (from ``text``), ``test`` (from
        ``test_list``), ``code`` and ``task_id``, in input order.
        """
        ans = []
        for line in data_json:
            if not line.strip():
                continue
            record = json.loads(line)
            ans.append({
                "prompt": record["text"],
                "test": record["test_list"],
                "code": record["code"],
                "task_id": record["task_id"],
            })
        return ans

    def __len__(self):
        """Number of evaluation samples (including ``samplenum`` repeats)."""
        return len(self.testdata)

    def __getitem__(self, index):
        """Return the evaluation record dict at ``index`` in ``testdata``."""
        return self.testdata[index]