mirror of
https://github.com/deepseek-ai/DeepSeek-Coder.git
synced 2025-06-20 00:14:03 -04:00
Update run.py
This commit is contained in:
parent
c6ebdafb13
commit
e0bbc1b808
@ -29,9 +29,7 @@ def load_data(args):
|
|||||||
:param args: Arguments.
|
:param args: Arguments.
|
||||||
:return: A list of examples.
|
:return: A list of examples.
|
||||||
"""
|
"""
|
||||||
if args.chat_model:
|
if args.data_name != "math":
|
||||||
prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
|
|
||||||
elif args.data_name != "math":
|
|
||||||
prompt = open("prompts/gsm8k.md").read()
|
prompt = open("prompts/gsm8k.md").read()
|
||||||
else:
|
else:
|
||||||
prompt = open("prompts/math.md").read()
|
prompt = open("prompts/math.md").read()
|
||||||
@ -48,10 +46,7 @@ def load_data(args):
|
|||||||
idx = example['idx']
|
idx = example['idx']
|
||||||
example['question'] = parse_question(example, args.data_name)
|
example['question'] = parse_question(example, args.data_name)
|
||||||
gt_cot, gt_ans = parse_ground_truth(example, args.data_name)
|
gt_cot, gt_ans = parse_ground_truth(example, args.data_name)
|
||||||
if args.chat_model:
|
example["input"] = f"{prompt}\n\nQuestion: {example['question']}\n"
|
||||||
example["input"] = f"{prompt}\n### Instruction:\n{example['question']}\nWrite a python function called 'def solution():' without any arguments to return the answer without any words. In the function, write the solution steps at first and then write the code.\n### Response:\n"
|
|
||||||
else:
|
|
||||||
example["input"] = f"{prompt}\n\nQuestion: {example['question']}\n"
|
|
||||||
example = {'idx': idx, 'question': example['question'], 'gt_cot': gt_cot, 'gt': gt_ans, 'prompt': example["input"]}
|
example = {'idx': idx, 'question': example['question'], 'gt_cot': gt_cot, 'gt': gt_ans, 'prompt': example["input"]}
|
||||||
samples.append(example)
|
samples.append(example)
|
||||||
|
|
||||||
@ -91,7 +86,7 @@ def inference(args):
|
|||||||
print("=" * 50)
|
print("=" * 50)
|
||||||
|
|
||||||
stop_ids = []
|
stop_ids = []
|
||||||
stop_words = ["Question","----------------","<|EOT|>"]
|
stop_words = ["Question","----------------"]
|
||||||
for x in stop_words:
|
for x in stop_words:
|
||||||
ids = tokenizer.encode(x)
|
ids = tokenizer.encode(x)
|
||||||
if tokenizer.decode(ids[-1:]) == x:
|
if tokenizer.decode(ids[-1:]) == x:
|
||||||
@ -185,7 +180,6 @@ if __name__ == "__main__":
|
|||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--data_name", default="math", type=str)
|
parser.add_argument("--data_name", default="math", type=str)
|
||||||
parser.add_argument("--model_name_or_path", default="deepseek/deepseek-coder-1b-python", type=str)
|
parser.add_argument("--model_name_or_path", default="deepseek/deepseek-coder-1b-python", type=str)
|
||||||
parser.add_argument("--chat_model", action="store_true")
|
|
||||||
parser.add_argument("--batch_size", default=16, type=int)
|
parser.add_argument("--batch_size", default=16, type=int)
|
||||||
parser.add_argument("--max_context_length", default=2048, type=int)
|
parser.add_argument("--max_context_length", default=2048, type=int)
|
||||||
parser.add_argument("--max_output_length", default=512, type=int)
|
parser.add_argument("--max_output_length", default=512, type=int)
|
||||||
|
Loading…
Reference in New Issue
Block a user