DeepSeek-Coder-V2/finetune/finetune_deepseekcoder.py

import copy
import random
import logging
import os
import torch
import torch.distributed
import transformers
import datasets
import numpy as np
from dataclasses import dataclass, field
from typing import Optional, Dict, Sequence
from transformers import Trainer, BitsAndBytesConfig
from datasets import load_dataset
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training, PeftModel
from peft.tuners.lora import LoraLayer
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from rich import print
from loguru import logger

IGNORE_INDEX = -100
EOT_TOKEN = "<|EOT|>"

def build_instruction_prompt(instruction: str):
    return '''
You are an AI programming assistant, utilizing the DeepSeek Coder model, developed by DeepSeek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.
### Instruction:
{}
### Response:
'''.format(instruction.strip()).lstrip()
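# Example (with an arbitrary, illustrative instruction):
# build_instruction_prompt("Write a quicksort in Python") renders the system preamble
# above followed by:
#
#   ### Instruction:
#   Write a quicksort in Python
#   ### Response:
#
# During training, the target appended after this prompt is the dataset's `output`
# field terminated by EOT_TOKEN (see train_tokenize_function below).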

@dataclass
class ModelArguments:
    trainable: Optional[str] = field(default="q_proj,v_proj,k_proj,o_proj,gate_proj,down_proj,up_proj")
    lora_rank: Optional[int] = field(default=8)
    lora_dropout: Optional[float] = field(default=0.1)
    lora_alpha: Optional[float] = field(default=32.)
    modules_to_save: Optional[str] = field(default="embed_tokens,lm_head")
    use_lora: Optional[bool] = field(default=False)
    model_name_or_path: Optional[str] = field(default="deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct")
    attn_implementation: Optional[str] = field(default="flash_attention_2")
    double_quant: bool = field(
        default=True,
        metadata={"help": "Compress the quantization statistics through double quantization."}
    )
    quant_type: str = field(
        default="nf4",
        metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
    )
    bits: int = field(
        default=16,
        metadata={"help": "How many bits to use."}
    )
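# Note: the 4/8-bit settings (`bits`, `double_quant`, `quant_type`) only take effect
# when `use_lora` is enabled, since build_model() passes a BitsAndBytesConfig only in
# that case and otherwise asserts bits in [16, 32]. `trainable` and `modules_to_save`
# are comma-separated module-name lists that build_model() splits before building the
# LoraConfig.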

@dataclass
class DataArguments:
    data_path: str = field(default=None, metadata={"help": "Path to the training data."})

@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=512,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )

class SavePeftModelCallback(transformers.TrainerCallback):
    def save_model(self, args, state, kwargs):
        logger.info('Saving PEFT checkpoint...')
        if state.best_model_checkpoint is not None:
            checkpoint_folder = os.path.join(state.best_model_checkpoint, "adapter_model")
        else:
            checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")

        peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
        kwargs["model"].save_pretrained(peft_model_path)
        kwargs["tokenizer"].save_pretrained(peft_model_path)

    def on_save(self, args, state, control, **kwargs):
        self.save_model(args, state, kwargs)
        return control

    def on_train_end(self, args, state, control, **kwargs):
        def touch(fname, times=None):
            with open(fname, 'a'):
                os.utime(fname, times)

        touch(os.path.join(args.output_dir, 'completed'))
        self.save_model(args, state, kwargs)
        logger.info("Saved model successfully")

def get_last_checkpoint(checkpoint_dir):
    if os.path.isdir(checkpoint_dir):
        is_completed = os.path.exists(os.path.join(checkpoint_dir, 'completed'))
        if is_completed: return None  # already finished
        max_step = 0
        for filename in os.listdir(checkpoint_dir):
            if os.path.isdir(os.path.join(checkpoint_dir, filename)) and filename.startswith(PREFIX_CHECKPOINT_DIR):
                max_step = max(max_step, int(filename.replace(PREFIX_CHECKPOINT_DIR + '-', '')))
        if max_step == 0: return None
        latest_ckpt_dir = os.path.join(checkpoint_dir, f'{PREFIX_CHECKPOINT_DIR}-{max_step}')
        logger.info(f"Found a previous checkpoint at: {checkpoint_dir}")
        return latest_ckpt_dir
    return None  # first training
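# Returns None when training already finished (the 'completed' marker exists) or no
# checkpoint-<step> folder is found; otherwise returns the newest checkpoint directory,
# which train() feeds to build_model() (to reload LoRA adapters) and to
# trainer.train(resume_from_checkpoint=...).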

def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Collect the state dict and dump it to disk."""
    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa

def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        for text in strings
    ]
    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )
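# Each string is tokenized on its own (a batch of one), so `input_ids_lens` counts the
# non-pad tokens per string; preprocess() uses the source lengths to know how many
# leading tokens to mask out of the labels.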

def preprocess(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Preprocess the data by tokenizing."""
    examples = [s + t for s, t in zip(sources, targets)]
    examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
    input_ids = examples_tokenized["input_ids"]
    labels = copy.deepcopy(input_ids)
    for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
        label[:source_len] = IGNORE_INDEX
    return dict(input_ids=input_ids, labels=labels)
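# Labels are a copy of input_ids with the prompt (source) positions overwritten by
# IGNORE_INDEX (-100), which the cross-entropy loss ignores, so the loss is computed
# only on the response tokens and the trailing EOT_TOKEN.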

@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""
    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        input_ids = [torch.tensor(x) for x in input_ids]
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = [torch.tensor(x) for x in labels]
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )
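# A minimal usage sketch (the token ids below are made up for illustration):
#
#   collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
#   batch = collator([
#       {"input_ids": [5, 6, 7], "labels": [-100, -100, 7]},
#       {"input_ids": [5, 6],    "labels": [-100, 6]},
#   ])
#   # batch["input_ids"]      -> right-padded with tokenizer.pad_token_id
#   # batch["labels"]         -> padded with IGNORE_INDEX (-100)
#   # batch["attention_mask"] -> True where input_ids != pad_token_id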

def train_tokenize_function(examples, tokenizer):
    sources = [
        build_instruction_prompt(instruction)
        for instruction in examples['instruction']
    ]
    targets = [f"{output}\n{EOT_TOKEN}" for output in examples['output']]
    data_dict = preprocess(sources, targets, tokenizer)
    return data_dict
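# Called below through datasets.map(batched=True, fn_kwargs={"tokenizer": tokenizer});
# the raw dataset is expected to provide 'instruction' and 'output' columns.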

def build_model(model_args, training_args, checkpoint_dir):
    logger.info("Starting model building process...")
    if not model_args.use_lora:
        assert model_args.bits in [16, 32]
        logger.info(f"Not using LoRA. Model bits: {model_args.bits}")

    compute_dtype = (torch.bfloat16 if training_args.bf16 else torch.float16)
    logger.info(f"Compute dtype: {compute_dtype}")
    logger.info(f"Loading model from: {model_args.model_name_or_path}")
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=model_args.bits == 4,
            load_in_8bit=model_args.bits == 8,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=model_args.double_quant,
            bnb_4bit_quant_type=model_args.quant_type,
        ) if model_args.use_lora else None,
        torch_dtype=compute_dtype,
        trust_remote_code=True,
        attn_implementation=model_args.attn_implementation,
    )
    logger.info("Model loaded successfully")

    if compute_dtype == torch.float16 and model_args.bits == 4:
        if torch.cuda.is_bf16_supported():
            logger.info('='*80)
            logger.info('Your GPU supports bfloat16, you can accelerate training with the argument --bf16')
            logger.info('='*80)

    logger.info("Setting model attributes...")
    setattr(model, 'model_parallel', True)
    setattr(model, 'is_parallelizable', True)
    model.config.torch_dtype = torch.bfloat16 if training_args.bf16 else torch.float32
    logger.info(f"Model torch dtype set to: {model.config.torch_dtype}")

    if model_args.use_lora and model_args.bits < 16:
        logger.info("Preparing model for k-bit training...")
        model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)
        logger.info("Model prepared for k-bit training")

    if model_args.use_lora:
        logger.info("LoRA is enabled. Proceeding with LoRA setup...")
        if checkpoint_dir is not None:
            logger.info(f"Loading adapters from {checkpoint_dir}.")
            model = PeftModel.from_pretrained(model, checkpoint_dir, is_trainable=True)
        else:
            logger.info('Init LoRA modules...')
            target_modules = model_args.trainable.split(',')
            logger.info(f"Target modules for LoRA: {target_modules}")
            modules_to_save = model_args.modules_to_save
            if modules_to_save is not None:
                modules_to_save = modules_to_save.split(',')
                logger.info(f"Modules to save: {modules_to_save}")
            else:
                logger.info("No modules to save specified")
            lora_rank = model_args.lora_rank
            lora_dropout = model_args.lora_dropout
            lora_alpha = model_args.lora_alpha
            logger.info(f"LoRA parameters: rank={lora_rank}, dropout={lora_dropout}, alpha={lora_alpha}")
            peft_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM,
                target_modules=target_modules,
                inference_mode=False,
                r=lora_rank,
                lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
                modules_to_save=modules_to_save,
            )
            logger.info(f"LoRA configuration: {peft_config}")
            model = get_peft_model(model, peft_config)
        logger.info("LoRA model preparation completed")

        for name, module in model.named_modules():
            if isinstance(module, LoraLayer):
                if training_args.bf16:
                    module = module.to(torch.bfloat16)
            if 'norm' in name or 'gate' in name:
                module = module.to(torch.float32)
            if 'lm_head' in name or 'embed_tokens' in name:
                if hasattr(module, 'weight'):
                    if training_args.bf16 and module.weight.dtype == torch.float32:
                        module = module.to(torch.bfloat16)

    return model
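# In short: with use_lora=False this returns a full fp16/bf16 model; with use_lora=True
# it returns a PEFT-wrapped model, optionally loaded in 4/8-bit via bitsandbytes
# (QLoRA-style), resuming adapters from checkpoint_dir when one exists.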

def train():
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    log_level = training_args.get_process_log_level()
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    if training_args.local_rank == 0:
        logger.info('='*100)
        logger.info(f"Training arguments: {training_args}")

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=True,
        trust_remote_code=True,
    )
    print("PAD Token:", tokenizer.pad_token, tokenizer.pad_token_id)
    print("BOS Token:", tokenizer.bos_token, tokenizer.bos_token_id)
    print("EOS Token:", tokenizer.eos_token, tokenizer.eos_token_id)

    if training_args.local_rank == 0:
        logger.info("Load tokenizer from {} over.".format(model_args.model_name_or_path))

    resume_from_checkpoint_dir = get_last_checkpoint(training_args.output_dir)
    model = build_model(model_args, training_args, resume_from_checkpoint_dir)
    if training_args.local_rank == 0:
        logger.info("Load model from {} over.".format(model_args.model_name_or_path))

    raw_train_datasets = load_dataset(
        'json',  # can also be parquet, csv, etc.
        data_files=data_args.data_path,
        split="train",
        cache_dir=training_args.cache_dir,
    )
    if training_args.local_rank > 0:
        torch.distributed.barrier()
logger.info("Starting dataset mapping")
train_dataset = raw_train_datasets.map(
train_tokenize_function,
batched=True,
batch_size=3000,
num_proc=os.cpu_count(),
remove_columns=raw_train_datasets.column_names,
load_from_cache_file=True, # not args.overwrite_cache
desc="Running Encoding",
fn_kwargs={ "tokenizer": tokenizer }
)
logger.info("Dataset mapping completed")
if training_args.local_rank == 0:
torch.distributed.barrier()
if training_args.local_rank == 0:
logger.info(f"Training dataset samples: {len(train_dataset)}")
for index in random.sample(range(len(train_dataset)), 3):
logger.info(f"Sample {index} of the training set: {train_dataset[index]['input_ids']}, {train_dataset[index]['labels']}.")
logger.info(f"Sample {index} decoded: {tokenizer.decode(list(train_dataset[index]['input_ids']))}")
logger.info("Creating data collator")
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
logger.info("Setting up data module")
data_module = dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
logger.info("Initializing Trainer")
trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
if model_args.use_lora:
logger.info("Adding SavePeftModelCallback for LoRA")
trainer.add_callback(SavePeftModelCallback)
logger.info("Starting training")
trainer.train(resume_from_checkpoint = resume_from_checkpoint_dir)
logger.info("Saving trainer state")
trainer.save_state()
if not model_args.use_lora:
logger.info("Saving full model (non-LoRA)")
safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
logger.info("Training completed")
if __name__ == "__main__":
train()
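
# Example launch (a minimal sketch; the paths and the DeepSpeed config are placeholders
# you must supply, and the flag values are illustrative, not prescriptive):
#
#   deepspeed finetune_deepseekcoder.py \
#       --model_name_or_path deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct \
#       --data_path /path/to/train.json \
#       --output_dir /path/to/output \
#       --model_max_length 1024 \
#       --bf16 True \
#       --use_lora True \
#       --deepspeed /path/to/ds_config.json
#
# Every flag maps onto the ModelArguments / DataArguments / TrainingArguments
# dataclasses parsed by HfArgumentParser in train().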