improve dataset mapping

muhtasham 2024-08-25 01:47:30 +02:00 committed by GitHub
parent 66edeee5a4
commit 44bf0630ef


@@ -180,12 +180,16 @@ def train_tokenize_function(examples, tokenizer):
     return data_dict

 def build_model(model_args, training_args, checkpoint_dir):
-    if not model_args.use_lora: assert model_args.bits in [16, 32]
+    logger.info("Starting model building process...")
+    if not model_args.use_lora:
+        assert model_args.bits in [16, 32]
+        logger.info(f"Not using LoRA. Model bits: {model_args.bits}")
     compute_dtype = (torch.bfloat16 if training_args.bf16 else torch.float16)
+    logger.info(f"Compute dtype: {compute_dtype}")
+    logger.info(f"Loading model from: {model_args.model_name_or_path}")
     model = transformers.AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path,
-        load_in_4bit=model_args.bits == 4,
-        load_in_8bit=model_args.bits == 8,
         quantization_config=BitsAndBytesConfig(
             load_in_4bit=model_args.bits == 4,
             load_in_8bit=model_args.bits == 8,
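For reference, this hunk converges on the standard transformers quantized-loading path, where a BitsAndBytesConfig passed to from_pretrained describes the whole 4-/8-bit setup. A minimal sketch of that pattern; the checkpoint name and quantization options below are illustrative assumptions, not values from this repository:

# Sketch of 4-bit loading with bitsandbytes; model id and settings are placeholders.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # quantize base weights to 4 bit
    bnb_4bit_quant_type="nf4",              # NormalFloat4 quantization
    bnb_4bit_use_double_quant=True,         # also quantize the quantization constants
    bnb_4bit_compute_dtype=torch.bfloat16,  # dtype used for matmuls at runtime
)

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-1.3b",            # placeholder checkpoint
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)

Recent transformers releases reject passing load_in_4bit/load_in_8bit as kwargs alongside quantization_config, which is one reason to keep the flags only inside the config object.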
@@ -197,22 +201,29 @@ def build_model(model_args, training_args, checkpoint_dir):
         ) if model_args.use_lora else None,
         torch_dtype=compute_dtype,
         trust_remote_code=True,
+        attn_implementation=model_args.attn_implementation,
     )
+    logger.info("Model loaded successfully")
     if compute_dtype == torch.float16 and model_args.bits == 4:
         if torch.cuda.is_bf16_supported():
             logger.info('='*80)
             logger.info('Your GPU supports bfloat16, you can accelerate training with the argument --bf16')
             logger.info('='*80)
+    logger.info("Setting model attributes...")
     setattr(model, 'model_parallel', True)
     setattr(model, 'is_parallelizable', True)
     model.config.torch_dtype=torch.bfloat16 if training_args.bf16 else torch.float32
-    # Tokenizer
+    logger.info(f"Model torch dtype set to: {model.config.torch_dtype}")
     if model_args.use_lora and model_args.bits < 16:
+        logger.info("Preparing model for k-bit training...")
         model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)
+        logger.info("Model prepared for k-bit training")
     if model_args.use_lora:
+        logger.info("LoRA is enabled. Proceeding with LoRA setup...")
         if checkpoint_dir is not None:
             logger.info(f"Loading adapters from {checkpoint_dir}.")
             # os.path.join(checkpoint_dir, 'adapter_model')
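The k-bit preparation step instrumented above is peft's prepare_model_for_kbit_training. A short sketch of how it is typically wrapped; the helper name and its default are assumptions for illustration, not code from this repository:

# `model` is assumed to be a 4-/8-bit quantized transformers model such as the
# one loaded in the previous sketch.
from peft import prepare_model_for_kbit_training

def prepare_for_training(model, gradient_checkpointing: bool = True):
    # Freezes the base weights, upcasts the remaining float params (e.g. layer
    # norms) for numerical stability, enables input gradients, and optionally
    # turns on gradient checkpointing so gradients can flow into the adapters.
    return prepare_model_for_kbit_training(
        model, use_gradient_checkpointing=gradient_checkpointing
    )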
@@ -220,12 +231,20 @@ def build_model(model_args, training_args, checkpoint_dir):
         else:
             logger.info(f'Init LoRA modules...')
             target_modules = model_args.trainable.split(',')
+            logger.info(f"Target modules for LoRA: {target_modules}")
             modules_to_save = model_args.modules_to_save
             if modules_to_save is not None:
                 modules_to_save = modules_to_save.split(',')
+                logger.info(f"Modules to save: {modules_to_save}")
+            else:
+                logger.info("No modules to save specified")
             lora_rank = model_args.lora_rank
             lora_dropout = model_args.lora_dropout
             lora_alpha = model_args.lora_alpha
+            logger.info(f"LoRA parameters: rank={lora_rank}, dropout={lora_dropout}, alpha={lora_alpha}")
             peft_config = LoraConfig(
                 task_type=TaskType.CAUSAL_LM,
                 target_modules=target_modules,
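The LoRA setup logged above boils down to constructing a peft LoraConfig. A sketch with placeholder target modules and hyperparameters; the real values come from model_args.trainable, lora_rank, lora_alpha, lora_dropout and modules_to_save:

# Illustrative LoraConfig mirroring the fields logged in the hunk above.
from peft import LoraConfig, TaskType

peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # layers that receive adapters (placeholder)
    r=8,                # LoRA rank
    lora_alpha=16,      # scaling factor; effective scale is lora_alpha / r
    lora_dropout=0.05,
    modules_to_save=["embed_tokens", "lm_head"],  # trained and saved in full precision (placeholder)
)

Since the effective adapter scale is lora_alpha / r, raising the rank without adjusting alpha weakens each adapter direction.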
@@ -233,7 +252,10 @@ def build_model(model_args, training_args, checkpoint_dir):
                 r=lora_rank, lora_alpha=lora_alpha,
                 lora_dropout=lora_dropout,
                 modules_to_save=modules_to_save)
+            logger.info(f"LoRA configuration: {peft_config}")
             model = get_peft_model(model, peft_config)
+            logger.info("LoRA model preparation completed")
     for name, module in model.named_modules():
         if isinstance(module, LoraLayer):
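Once the config is built, the wrapping itself is get_peft_model plus, in this script, a pass over the LoraLayer modules. A sketch assuming the `model` and `peft_config` from the previous snippets; the bfloat16 cast is an illustrative stand-in for whatever the loop in the diff does with each layer:

import torch
from peft import get_peft_model
from peft.tuners.lora import LoraLayer

model = get_peft_model(model, peft_config)
model.print_trainable_parameters()  # logs trainable vs. total parameter counts

# Keep adapter weights in bfloat16 when training with --bf16.
for name, module in model.named_modules():
    if isinstance(module, LoraLayer):
        module.to(torch.bfloat16)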
@@ -291,7 +313,7 @@ def train():
         train_tokenize_function,
         batched=True,
         batch_size=3000,
-        num_proc=32,
+        num_proc=os.cpu_count(),
         remove_columns=raw_train_datasets.column_names,
         load_from_cache_file=True, # not args.overwrite_cache
         desc="Running Encoding",