Update run.py

This commit is contained in:
Daya Guo 2023-11-01 10:41:43 +08:00 committed by GitHub
parent c6ebdafb13
commit e0bbc1b808
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -29,9 +29,7 @@ def load_data(args):
:param args: Arguments.
:return: A list of examples.
"""
-if args.chat_model:
-prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
-elif args.data_name != "math":
+if args.data_name != "math":
prompt = open("prompts/gsm8k.md").read()
else:
prompt = open("prompts/math.md").read()
@@ -48,10 +46,7 @@ def load_data(args):
idx = example['idx']
example['question'] = parse_question(example, args.data_name)
gt_cot, gt_ans = parse_ground_truth(example, args.data_name)
-if args.chat_model:
-example["input"] = f"{prompt}\n### Instruction:\n{example['question']}\nWrite a python function called 'def solution():' without any arguments to return the answer without any words. In the function, write the solution steps at first and then write the code.\n### Response:\n"
-else:
-example["input"] = f"{prompt}\n\nQuestion: {example['question']}\n"
+example["input"] = f"{prompt}\n\nQuestion: {example['question']}\n"
example = {'idx': idx, 'question': example['question'], 'gt_cot': gt_cot, 'gt': gt_ans, 'prompt': example["input"]}
samples.append(example)
@@ -91,7 +86,7 @@ def inference(args):
print("=" * 50)
stop_ids = []
-stop_words = ["Question","----------------","<|EOT|>"]
+stop_words = ["Question","----------------"]
for x in stop_words:
ids = tokenizer.encode(x)
if tokenizer.decode(ids[-1:]) == x:
@@ -185,7 +180,6 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data_name", default="math", type=str)
parser.add_argument("--model_name_or_path", default="deepseek/deepseek-coder-1b-python", type=str)
-parser.add_argument("--chat_model", action="store_true")
parser.add_argument("--batch_size", default=16, type=int)
parser.add_argument("--max_context_length", default=2048, type=int)
parser.add_argument("--max_output_length", default=512, type=int)