diff --git a/Evaluation/HumanEval/__pycache__/humaneval.cpython-38.pyc b/Evaluation/HumanEval/__pycache__/humaneval.cpython-38.pyc new file mode 100644 index 0000000..fd9d164 Binary files /dev/null and b/Evaluation/HumanEval/__pycache__/humaneval.cpython-38.pyc differ diff --git a/Evaluation/HumanEval/eval_instruct.py b/Evaluation/HumanEval/eval_instruct.py new file mode 100644 index 0000000..71874fe --- /dev/null +++ b/Evaluation/HumanEval/eval_instruct.py @@ -0,0 +1,128 @@ +import argparse +import json +import os +import torch +from pathlib import Path +from tqdm import tqdm + +data_abs_dir = Path(__file__).parent / "data" + +from utils.utils import extract_generation_code, languge_settings +from transformers import AutoTokenizer, AutoModelForCausalLM +from human_eval.evaluation import evaluate_functional_correctness + +def build_deepseekcoder_instruction(languge: str, question: str): + return ''' +Please continue to complete the function. You are not allowed to modify the given code and do the completion only. Please return all completed function in a codeblock. Here is the given code to do completion: +```{} +{} +``` +'''.strip().format(languge.lower(), question.strip()) + +def generate_one(example, lang, tokenizer, model): + prompt = build_deepseekcoder_instruction(languge_settings[lang]['full_name'], example['prompt']) + inputs = tokenizer.apply_chat_template( + [{'role': 'user', 'content': prompt }], + return_tensors="pt" + ).to(model.device) + + stop_id = tokenizer.convert_tokens_to_ids("<|EOT|>") + assert isinstance(stop_id, int), "Invalid tokenizer, EOT id not found" + + outputs = model.generate( + inputs, + max_new_tokens=1024, + do_sample=False, + # top_p=0.95, + # temperature=temperature, + pad_token_id=stop_id, + eos_token_id=stop_id + ) + + output = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True) + example['output'] = output + + return extract_generation_code(example, lang_code=lang) + +def generate_main(args): + model_name_or_path = args.model + lang = args.language + saved_path = args.output_path + temp_dir = args.temp_dir + os.makedirs(temp_dir, exist_ok=True) + problem_file = os.path.join(data_abs_dir, f"humaneval-{lang}.jsonl") + + print("model", model_name_or_path) + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) + print("load tokenizer {} from {} over.".format(tokenizer.__class__, model_name_or_path)) + model = AutoModelForCausalLM.from_pretrained( + model_name_or_path, + torch_dtype=torch.bfloat16, + device_map="auto", + #use_flash_attention_2=True + ) + model.eval() + examples = [json.loads(x) for x in open(problem_file) if x.strip()] + print("Read {} examples for evaluation over.".format(len(examples))) + + generated_examples = [] + for ex in tqdm(examples, desc='Generating'): + gen_example = generate_one(ex, lang, tokenizer, model) + generated_examples.append(gen_example) + + print("Generate all over!!!") + with open(saved_path, 'w', encoding='utf-8') as fw: + for ex in generated_examples: + fw.write(json.dumps(ex) + '\n') + print("Save {} processed examples into {} over!".format(len(generated_examples), saved_path)) + + result = evaluate_functional_correctness( + input_file=saved_path, + tmp_dir=temp_dir, + n_workers=8, + timeout=3.0, + problem_file=problem_file, + language=lang + ) + print(lang, result, model_name_or_path) + pass + +def evaluation_only(args): + lang = args.language + temp_dir = args.temp_dir + assert os.path.exists(args.output_path), "Not fond output file: {}".format(args.output_path) + os.makedirs(temp_dir, exist_ok=True) + + output_name = os.path.basename(args.output_path) + output_examples = [json.loads(x) for x in open(args.output_path) if x.strip()] + + processed_examples = [extract_generation_code(ex, lang) for ex in tqdm(output_examples, "Processing")] + processed_path = os.path.join(temp_dir, output_name) + with open(processed_path, 'w', encoding='utf-8') as fw: + for ex in processed_examples: + fw.write(json.dumps(ex) + '\n') + print("Save {} processed examples into {} over!".format(len(processed_examples), processed_path)) + + problem_file = os.path.join(data_abs_dir, f"humaneval-{lang}.jsonl") + from human_eval.evaluation import evaluate_functional_correctness + result = evaluate_functional_correctness( + input_file=processed_path, + tmp_dir=temp_dir, + n_workers=8, + timeout=3.0, + problem_file=problem_file, + language=lang + ) + print(lang, result) + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--model', type=str, help="model name or path") + parser.add_argument('--output_path', type=str, help="output path of your generation") + parser.add_argument('--language', type=str, help="langauge") + parser.add_argument('--temp_dir', type=str, help="temp dir for evaluation", default="tmp") + args = parser.parse_args() + + os.environ["TOKENIZERS_PARALLELISM"] = "false" + generate_main(args) + pass \ No newline at end of file diff --git a/Evaluation/HumanEval/human_eval/__pycache__/__init__.cpython-37.pyc b/Evaluation/HumanEval/human_eval/__pycache__/__init__.cpython-37.pyc deleted file mode 100644 index 55b30d7..0000000 Binary files a/Evaluation/HumanEval/human_eval/__pycache__/__init__.cpython-37.pyc and /dev/null differ diff --git a/Evaluation/HumanEval/human_eval/__pycache__/__init__.cpython-38.pyc b/Evaluation/HumanEval/human_eval/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index 6208508..0000000 Binary files a/Evaluation/HumanEval/human_eval/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/Evaluation/HumanEval/human_eval/__pycache__/data.cpython-37.pyc b/Evaluation/HumanEval/human_eval/__pycache__/data.cpython-37.pyc deleted file mode 100644 index b6f6ac4..0000000 Binary files a/Evaluation/HumanEval/human_eval/__pycache__/data.cpython-37.pyc and /dev/null differ diff --git a/Evaluation/HumanEval/human_eval/__pycache__/data.cpython-38.pyc b/Evaluation/HumanEval/human_eval/__pycache__/data.cpython-38.pyc deleted file mode 100644 index 94ed78c..0000000 Binary files a/Evaluation/HumanEval/human_eval/__pycache__/data.cpython-38.pyc and /dev/null differ diff --git a/Evaluation/HumanEval/human_eval/__pycache__/evaluate_functional_correctness.cpython-38.pyc b/Evaluation/HumanEval/human_eval/__pycache__/evaluate_functional_correctness.cpython-38.pyc deleted file mode 100644 index cbcc04e..0000000 Binary files a/Evaluation/HumanEval/human_eval/__pycache__/evaluate_functional_correctness.cpython-38.pyc and /dev/null differ diff --git a/Evaluation/HumanEval/human_eval/__pycache__/evaluation.cpython-37.pyc b/Evaluation/HumanEval/human_eval/__pycache__/evaluation.cpython-37.pyc deleted file mode 100644 index aa74193..0000000 Binary files a/Evaluation/HumanEval/human_eval/__pycache__/evaluation.cpython-37.pyc and /dev/null differ diff --git a/Evaluation/HumanEval/human_eval/__pycache__/evaluation.cpython-38.pyc b/Evaluation/HumanEval/human_eval/__pycache__/evaluation.cpython-38.pyc deleted file mode 100644 index c9bac93..0000000 Binary files a/Evaluation/HumanEval/human_eval/__pycache__/evaluation.cpython-38.pyc and /dev/null differ diff --git a/Evaluation/HumanEval/human_eval/__pycache__/execution.cpython-38.pyc b/Evaluation/HumanEval/human_eval/__pycache__/execution.cpython-38.pyc deleted file mode 100644 index c9543a6..0000000 Binary files a/Evaluation/HumanEval/human_eval/__pycache__/execution.cpython-38.pyc and /dev/null differ diff --git a/Evaluation/HumanEval/utils/__pycache__/instruct.cpython-38.pyc b/Evaluation/HumanEval/utils/__pycache__/instruct.cpython-38.pyc new file mode 100644 index 0000000..51b5490 Binary files /dev/null and b/Evaluation/HumanEval/utils/__pycache__/instruct.cpython-38.pyc differ diff --git a/Evaluation/HumanEval/utils/__pycache__/utils.cpython-39.pyc b/Evaluation/HumanEval/utils/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000..b016ee2 Binary files /dev/null and b/Evaluation/HumanEval/utils/__pycache__/utils.cpython-39.pyc differ diff --git a/Evaluation/HumanEval/utils/utils.py b/Evaluation/HumanEval/utils/utils.py index 21b9c1b..41fcb3d 100644 --- a/Evaluation/HumanEval/utils/utils.py +++ b/Evaluation/HumanEval/utils/utils.py @@ -1,3 +1,109 @@ +import re + +languge_settings = { + 'python': { + 'full_name': 'Python', + 'indent': 4, + }, + 'cpp': { + 'full_name': 'cpp', + 'indent': 0, + 'main': "int main()", + }, + 'java': { + 'full_name': 'Java', + 'indent': 4, + 'main': "public static void main", + }, + 'cs': { + 'full_name': "csharp", + 'indent': 0, + 'main': "public static void Main", + }, + 'php': { + 'full_name': "PHP", + 'indent': 0, + }, + 'ts': { + 'full_name': "TypeScript", + 'indent': 0, + }, + 'js': { + 'full_name': "JavaScript", + 'indent': 0 + }, + 'sh': { + 'full_name': "Bash", + 'indent': 0 + } +} + +def get_function_name(question: str, lang: str): + func_lines = [x for x in question.strip().split('\n') if x.strip()] + + if lang.lower() == 'python': + func_idx = [i for i in range(len(func_lines)) if func_lines[i].startswith("def ")][-1] + func_name = func_lines[func_idx].split('(')[0].strip() + func_prefix = "\n".join(func_lines[:func_idx]) + return func_name, func_prefix + + func_name = func_lines[-1].split('{')[0].strip() + func_prefix = "\n".join(func_lines[:-1]) + return func_name, func_prefix + +def extract_generation_code(example: str, lang_code: str, verbose: bool=False): + task_id = example['task_id'] + output = example.get('output', example.get("gpt_completion")) + question = example["prompt"].strip() + setting = languge_settings[lang_code] + lang = setting['full_name'] + indent = setting['indent'] + + try: + code_block: str = re.findall(f'```{lang.lower()}\n(.*?)```', output, re.DOTALL | re.IGNORECASE)[0] + if verbose: + print(">>> Task: {}\n{}".format(task_id, code_block)) + + # Remove main + if setting.get('main', None) and setting['main'] in code_block: + main_start = code_block.index(setting['main']) + code_block = code_block[:main_start] + + func_name, func_prefix = get_function_name(question, lang) + + try: + start = code_block.lower().index(func_name.lower()) + indent = 0 + while start - indent >= 0 and code_block[start - indent-1] == ' ': + indent += 1 + + try: + end = code_block.rindex('\n' + ' '*indent + '}') + except: + end = len(code_block) + except: + start = 0 + try: + end = code_block.rindex('\n' + ' '*indent + '}') + except: + end = len(code_block) + + body = code_block[start:end] + + if lang_code.lower() in ['php', 'ts', 'js']: + body += '\n' + ' '*indent + '}' + + generation = func_prefix + '\n' + body + '\n' + example['generation'] = generation + + except Exception as ex: + print("Failed to extract code block with error `{}`:\n>>> Task: {}\n>>> Output:\n{}".format( + ex, task_id, output + )) + example['generation'] = example['prompt'] + '\n' + output + + return example + def cleanup_code( code: str, language_type: str = None, diff --git a/finetune/README.md b/finetune/README.md new file mode 100644 index 0000000..4e71e06 --- /dev/null +++ b/finetune/README.md @@ -0,0 +1,44 @@ +## How to Fine-tune DeepSeek-Coder + +We provide script `finetune_deepseekcoder.py` for users to finetune our models on downstream tasks. + +The script supports the training with [DeepSpeed](https://github.com/microsoft/DeepSpeed). You need install required packages by: + +```bash +pip install -r requirements.txt +``` + +Please follow [Sample Dataset Format](https://huggingface.co/datasets/nickrosh/Evol-Instruct-Code-80k-v1) to prepare your training data. +Each line is a json-serialized string with two required fields `instruction` and `output`. + +After data preparation, you can use the sample shell script to finetune `deepseek-ai/deepseek-coder-6.7b-instruct`. +Remember to specify `DATA_PATH`, `OUTPUT_PATH`. +And please choose appropriate hyper-parameters(e.g., `learning_rate`, `per_device_train_batch_size`) according to your scenario. + +```bash +DATA_PATH="" +OUTPUT_PATH="" +MODEL="deepseek-ai/deepseek-coder-6.7b-instruct" + +deepspeed finetune_deepseekcoder.py \ + --model_name_or_path $MODEL_PATH \ + --data_path $DATA_PATH \ + --output_dir $OUTPUT_PATH \ + --num_train_epochs 3 \ + --model_max_length 1024 \ + --per_device_train_batch_size 16 \ + --per_device_eval_batch_size 1 \ + --gradient_accumulation_steps 4 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 100 \ + --save_total_limit 100 \ + --learning_rate 2e-5 \ + --warmup_steps 10 \ + --logging_steps 1 \ + --lr_scheduler_type "cosine" \ + --gradient_checkpointing True \ + --report_to "tensorboard" \ + --deepspeed configs/ds_config_zero3.json \ + --bf16 True +``` \ No newline at end of file diff --git a/finetune/configs/ds_config_zero3.json b/finetune/configs/ds_config_zero3.json new file mode 100644 index 0000000..73f3b5f --- /dev/null +++ b/finetune/configs/ds_config_zero3.json @@ -0,0 +1,51 @@ +{ + "bf16": { + "enabled": "auto" + }, + "optimizer": { + "type": "AdamW", + "params": { + "lr": "auto", + "betas": "auto", + "eps": "auto", + "weight_decay": "auto" + } + }, + + "scheduler": { + "type": "WarmupLR", + "params": { + "warmup_min_lr": "auto", + "warmup_max_lr": "auto", + "warmup_num_steps": "auto" + } + }, + + "zero_optimization": { + "stage": 3, + "offload_optimizer": { + "device": "cpu", + "pin_memory": true + }, + "offload_param": { + "device": "cpu", + "pin_memory": true + }, + "overlap_comm": true, + "contiguous_gradients": true, + "sub_group_size": 1e9, + "reduce_bucket_size": "auto", + "stage3_prefetch_bucket_size": "auto", + "stage3_param_persistence_threshold": "auto", + "stage3_max_live_parameters": 1e9, + "stage3_max_reuse_distance": 1e9, + "stage3_gather_16bit_weights_on_model_save": true + }, + + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "steps_per_print": 20, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "wall_clock_breakdown": false +} \ No newline at end of file diff --git a/finetune/finetune_deepseekcoder.py b/finetune/finetune_deepseekcoder.py new file mode 100644 index 0000000..91bd297 --- /dev/null +++ b/finetune/finetune_deepseekcoder.py @@ -0,0 +1,207 @@ +# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import random +from dataclasses import dataclass, field +from typing import Optional, Dict, Sequence + +import torch +import torch.distributed +import transformers +from transformers import Trainer +from datasets import load_dataset + + +IGNORE_INDEX = -100 +EOT_TOKEN = "<|EOT|>" + +def build_instruction_prompt(instruction: str): + return ''' +You are an AI programming assistant, utilizing the DeepSeek Coder model, developed by DeepSeek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer. +### Instruction: +{} +### Response: +'''.format(instruction.strip()).lstrip() + +@dataclass +class ModelArguments: + model_name_or_path: Optional[str] = field(default="deepseek-ai/deepseek-coder-6.7b-instruct") + +@dataclass +class DataArguments: + data_path: str = field(default=None, metadata={"help": "Path to the training data."}) + + +@dataclass +class TrainingArguments(transformers.TrainingArguments): + cache_dir: Optional[str] = field(default=None) + optim: str = field(default="adamw_torch") + model_max_length: int = field( + default=512, + metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."}, + ) + +def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str): + """Collects the state dict and dump to disk.""" + state_dict = trainer.model.state_dict() + if trainer.args.should_save: + cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()} + del state_dict + trainer._save(output_dir, state_dict=cpu_state_dict) # noqa + + +def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict: + """Tokenize a list of strings.""" + tokenized_list = [ + tokenizer( + text, + return_tensors="pt", + padding="longest", + max_length=tokenizer.model_max_length, + truncation=True, + ) + for text in strings + ] + + input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list] + input_ids_lens = labels_lens = [ + tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list + ] + + return dict( + input_ids=input_ids, + labels=labels, + input_ids_lens=input_ids_lens, + labels_lens=labels_lens, + ) + + +def preprocess( + sources: Sequence[str], + targets: Sequence[str], + tokenizer: transformers.PreTrainedTokenizer, +) -> Dict: + """Preprocess the data by tokenizing.""" + examples = [s + t for s, t in zip(sources, targets)] + examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)] + input_ids = examples_tokenized["input_ids"] + + labels = copy.deepcopy(input_ids) + for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]): + label[:source_len] = IGNORE_INDEX + return dict(input_ids=input_ids, labels=labels) + +@dataclass +class DataCollatorForSupervisedDataset(object): + """Collate examples for supervised fine-tuning.""" + tokenizer: transformers.PreTrainedTokenizer + + def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: + input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels")) + input_ids = [torch.tensor(x) for x in input_ids] + input_ids = torch.nn.utils.rnn.pad_sequence( + input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id + ) + labels = [torch.tensor(x) for x in labels] + labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) + + return dict( + input_ids=input_ids, + labels=labels, + attention_mask=input_ids.ne(self.tokenizer.pad_token_id), + ) + +def train_tokenize_function(examples, tokenizer): + sources = [ + build_instruction_prompt(instruction) + for instruction in examples['instruction'] + ] + targets = [f"{output}\n{EOT_TOKEN}" for output in examples['output']] + data_dict = preprocess(sources, targets, tokenizer) + return data_dict + +def train(): + parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + if training_args.local_rank == 0: + print('='*100) + print(training_args) + + tokenizer = transformers.AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + model_max_length=training_args.model_max_length, + padding_side="right", + use_fast=True, + trust_remote_code=True + ) + + print("PAD Token:", tokenizer.pad_token, tokenizer.pad_token_id) + print("BOS Token", tokenizer.bos_token, tokenizer.bos_token_id) + print("EOS Token", tokenizer.eos_token, tokenizer.eos_token_id) + + if training_args.local_rank == 0: + print("Load tokenizer from {} over.".format(model_args.model_name_or_path)) + + model = transformers.AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, + torch_dtype=torch.bfloat16 + ) + + if training_args.local_rank == 0: + print("Load model from {} over.".format(model_args.model_name_or_path)) + + + raw_train_datasets = load_dataset( + 'json', + data_files=data_args.data_path, + split="train", + cache_dir=training_args.cache_dir + ) + if training_args.local_rank > 0: + torch.distributed.barrier() + + train_dataset = raw_train_datasets.map( + train_tokenize_function, + batched=True, + batch_size=3000, + num_proc=32, + remove_columns=raw_train_datasets.column_names, + load_from_cache_file=True, # not args.overwrite_cache + desc="Running Encoding", + fn_kwargs={ "tokenizer": tokenizer } + ) + + if training_args.local_rank == 0: + torch.distributed.barrier() + + if training_args.local_rank == 0: + print("Training dataset samples:", len(train_dataset)) + for index in random.sample(range(len(train_dataset)), 3): + print(f"Sample {index} of the training set: {train_dataset[index]['input_ids']}, {train_dataset[index]['labels']}.") + print(f"Sample {index} of the training set: {tokenizer.decode(list(train_dataset[index]['input_ids']))}.") + + data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) + data_module = dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator) + + trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module) + + trainer.train() + trainer.save_state() + safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir) + + +if __name__ == "__main__": + train() \ No newline at end of file diff --git a/finetune/requirements.txt b/finetune/requirements.txt new file mode 100644 index 0000000..f6e03b7 --- /dev/null +++ b/finetune/requirements.txt @@ -0,0 +1,10 @@ +torch>=2.0.1 +tokenizers>=0.14.0 +transformers>=4.35.0 +accelerate +attrdict +tqdm + +deepspeed +datasets +tensorboardX