mirror of
https://github.com/deepseek-ai/DeepSeek-Coder.git
synced 2025-04-19 10:09:14 -04:00
update eval_instruct.py
This commit is contained in:
parent
4f0b860d30
commit
f911009816
@@ -7,22 +7,20 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
data_abs_dir = Path(__file__).parent / "data"
|
data_abs_dir = Path(__file__).parent / "data"
|
||||||
|
|
||||||
from utils.utils import extract_generation_code
|
from utils.utils import extract_generation_code, languge_settings
|
||||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||||
from human_eval.evaluation import evaluate_functional_correctness
|
from human_eval.evaluation import evaluate_functional_correctness
|
||||||
|
|
||||||
def build_deepseekcoder_instruction(languge: str, question: str):
|
def build_deepseekcoder_instruction(languge: str, question: str):
|
||||||
return '''
|
return '''
|
||||||
Please help me to complete the function. Use the given packages only and DO NOT refer any new package. Please return all completed function in a codeblock.
|
Please continue to complete the function. You are not allowed to modify the given code and do the completion only. Please return all completed function in a codeblock. Here is the given code to do completion:
|
||||||
Here is the given code to do completion:
|
|
||||||
```{}
|
```{}
|
||||||
{}
|
{}
|
||||||
```
|
```
|
||||||
'''.strip().format(languge.lower(), question)
|
'''.strip().format(languge.lower(), question.strip())
|
||||||
|
|
||||||
|
|
||||||
def generate_one(example, lang, tokenizer, model):
|
def generate_one(example, lang, tokenizer, model):
|
||||||
prompt = build_deepseekcoder_instruction(lang, example['prompt'])
|
prompt = build_deepseekcoder_instruction(languge_settings[lang]['full_name'], example['prompt'])
|
||||||
inputs = tokenizer.apply_chat_template(
|
inputs = tokenizer.apply_chat_template(
|
||||||
[{'role': 'user', 'content': prompt }],
|
[{'role': 'user', 'content': prompt }],
|
||||||
return_tensors="pt"
|
return_tensors="pt"
|
||||||
@@ -33,11 +31,14 @@ def generate_one(example, lang, tokenizer, model):
|
|||||||
|
|
||||||
outputs = model.generate(
|
outputs = model.generate(
|
||||||
inputs,
|
inputs,
|
||||||
max_new_tokens=512,
|
max_new_tokens=1024,
|
||||||
do_sample=False,
|
do_sample=False,
|
||||||
top_p=0.95,
|
# top_p=0.95,
|
||||||
|
# temperature=temperature,
|
||||||
|
pad_token_id=stop_id,
|
||||||
eos_token_id=stop_id
|
eos_token_id=stop_id
|
||||||
)
|
)
|
||||||
|
|
||||||
output = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
|
output = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
|
||||||
example['output'] = output
|
example['output'] = output
|
||||||
|
|
||||||
@@ -49,17 +50,18 @@ def generate_main(args):
|
|||||||
saved_path = args.output_path
|
saved_path = args.output_path
|
||||||
temp_dir = args.temp_dir
|
temp_dir = args.temp_dir
|
||||||
os.makedirs(temp_dir, exist_ok=True)
|
os.makedirs(temp_dir, exist_ok=True)
|
||||||
|
problem_file = os.path.join(data_abs_dir, f"humaneval-{lang}.jsonl")
|
||||||
|
|
||||||
|
print("model", model_name_or_path)
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
|
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
|
||||||
print("load tokenizer {} from {} over.".format(tokenizer.__class__, model_name_or_path))
|
print("load tokenizer {} from {} over.".format(tokenizer.__class__, model_name_or_path))
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
model_name_or_path,
|
model_name_or_path,
|
||||||
torch_dtype=torch.bfloat16,
|
torch_dtype=torch.bfloat16,
|
||||||
device_map="cuda"
|
device_map="auto",
|
||||||
|
#use_flash_attention_2=True
|
||||||
)
|
)
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
problem_file = os.path.join(data_abs_dir, f"humaneval-{lang}.jsonl")
|
|
||||||
examples = [json.loads(x) for x in open(problem_file) if x.strip()]
|
examples = [json.loads(x) for x in open(problem_file) if x.strip()]
|
||||||
print("Read {} examples for evaluation over.".format(len(examples)))
|
print("Read {} examples for evaluation over.".format(len(examples)))
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user