Mirror of https://github.com/deepseek-ai/DeepSeek-MoE.git
Synced 2025-02-23 06:09:05 -05:00

improve dataset mapping

parent 66edeee5a4
commit 44bf0630ef
@@ -180,12 +180,16 @@ def train_tokenize_function(examples, tokenizer):
     return data_dict
 
 
 def build_model(model_args, training_args, checkpoint_dir):
-    if not model_args.use_lora: assert model_args.bits in [16, 32]
+    logger.info("Starting model building process...")
+    if not model_args.use_lora:
+        assert model_args.bits in [16, 32]
+        logger.info(f"Not using LoRA. Model bits: {model_args.bits}")
     compute_dtype = (torch.bfloat16 if training_args.bf16 else torch.float16)
+    logger.info(f"Compute dtype: {compute_dtype}")
+
+    logger.info(f"Loading model from: {model_args.model_name_or_path}")
     model = transformers.AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path,
         load_in_4bit=model_args.bits == 4,
         load_in_8bit=model_args.bits == 8,
         quantization_config=BitsAndBytesConfig(
             load_in_4bit=model_args.bits == 4,
             load_in_8bit=model_args.bits == 8,
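The hunk above wraps the model-loading step, which fetches the base checkpoint with optional 4-/8-bit quantization through bitsandbytes. As a reference, here is a minimal, self-contained sketch of that loading pattern; the checkpoint name, bit width, and quantization options are assumptions, not values taken from this repository. Note that passing load_in_4bit/load_in_8bit both as from_pretrained arguments and inside BitsAndBytesConfig is redundant; recent transformers releases warn about the duplicated flags and prefer the config alone.

import torch
import transformers
from transformers import BitsAndBytesConfig

bits = 4  # hypothetical: 4, 8, or 16/32 (no quantization)
compute_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

model = transformers.AutoModelForCausalLM.from_pretrained(
    "deepseek-ai/deepseek-moe-16b-base",  # assumed checkpoint name
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=(bits == 4),
        load_in_8bit=(bits == 8),
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=True,   # assumed quantization options
        bnb_4bit_quant_type="nf4",
    ) if bits in (4, 8) else None,
    torch_dtype=compute_dtype,
    trust_remote_code=True,
)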
@@ -197,22 +201,29 @@ def build_model(model_args, training_args, checkpoint_dir):
         ) if model_args.use_lora else None,
         torch_dtype=compute_dtype,
         trust_remote_code=True,
         attn_implementation=model_args.attn_implementation,
     )
+    logger.info("Model loaded successfully")
 
     if compute_dtype == torch.float16 and model_args.bits == 4:
         if torch.cuda.is_bf16_supported():
             logger.info('='*80)
             logger.info('Your GPU supports bfloat16, you can accelerate training with the argument --bf16')
             logger.info('='*80)
 
+    logger.info("Setting model attributes...")
     setattr(model, 'model_parallel', True)
     setattr(model, 'is_parallelizable', True)
     model.config.torch_dtype=torch.bfloat16 if training_args.bf16 else torch.float32
     # Tokenizer
+    logger.info(f"Model torch dtype set to: {model.config.torch_dtype}")
 
     if model_args.use_lora and model_args.bits < 16:
+        logger.info("Preparing model for k-bit training...")
         model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)
+        logger.info("Model prepared for k-bit training")
 
     if model_args.use_lora:
+        logger.info("LoRA is enabled. Proceeding with LoRA setup...")
         if checkpoint_dir is not None:
             logger.info(f"Loading adapters from {checkpoint_dir}.")
             # os.path.join(checkpoint_dir, 'adapter_model')
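This hunk adds logging around prepare_model_for_kbit_training, the PEFT helper that makes a quantized model trainable (casting layer norms, enabling input gradients, and optionally turning on gradient checkpointing). A minimal sketch of that call, assuming the quantized `model` from the previous example; the gradient-checkpointing flag is an assumption standing in for training_args.gradient_checkpointing.

from peft import prepare_model_for_kbit_training

model.config.use_cache = False  # the KV cache is not useful during training
model = prepare_model_for_kbit_training(
    model,
    use_gradient_checkpointing=True,  # assumed; mirrors training_args.gradient_checkpointing
)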
@@ -220,12 +231,20 @@ def build_model(model_args, training_args, checkpoint_dir):
         else:
             logger.info(f'Init LoRA modules...')
             target_modules = model_args.trainable.split(',')
+            logger.info(f"Target modules for LoRA: {target_modules}")
+
             modules_to_save = model_args.modules_to_save
             if modules_to_save is not None:
                 modules_to_save = modules_to_save.split(',')
+                logger.info(f"Modules to save: {modules_to_save}")
+            else:
+                logger.info("No modules to save specified")
+
             lora_rank = model_args.lora_rank
             lora_dropout = model_args.lora_dropout
             lora_alpha = model_args.lora_alpha
+            logger.info(f"LoRA parameters: rank={lora_rank}, dropout={lora_dropout}, alpha={lora_alpha}")
+
             peft_config = LoraConfig(
                 task_type=TaskType.CAUSAL_LM,
                 target_modules=target_modules,
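The hunk above logs the LoRA target modules, the optional modules_to_save list, and the rank/alpha/dropout hyperparameters before constructing the LoraConfig. The following sketch shows how such comma-separated arguments are typically wired into peft; the module names and hyperparameter values are placeholders, not this repository's defaults.

from peft import LoraConfig, TaskType, get_peft_model

trainable = "q_proj,k_proj,v_proj,o_proj"   # hypothetical --trainable value
modules_to_save = "embed_tokens,lm_head"    # hypothetical --modules_to_save value

peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    target_modules=trainable.split(","),
    r=8,                # lora_rank
    lora_alpha=16,      # lora_alpha
    lora_dropout=0.05,  # lora_dropout
    modules_to_save=modules_to_save.split(","),
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()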
@@ -233,7 +252,10 @@ def build_model(model_args, training_args, checkpoint_dir):
                 r=lora_rank, lora_alpha=lora_alpha,
                 lora_dropout=lora_dropout,
                 modules_to_save=modules_to_save)
+            logger.info(f"LoRA configuration: {peft_config}")
+
             model = get_peft_model(model, peft_config)
+            logger.info("LoRA model preparation completed")
 
     for name, module in model.named_modules():
         if isinstance(module, LoraLayer):
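The hunk ends where the script walks model.named_modules() and special-cases LoraLayer instances. In QLoRA-style fine-tuning scripts this loop usually normalizes parameter dtypes after the PEFT wrapping; the sketch below shows that common pattern under the assumption that bf16 training is enabled, and it may differ from the exact rules in this finetune.py.

import torch
from peft.tuners.lora import LoraLayer

bf16 = True  # assumed training flag
for name, module in model.named_modules():
    if isinstance(module, LoraLayer) and bf16:
        module.to(torch.bfloat16)      # keep LoRA adapters in bf16
    if "norm" in name:
        module.to(torch.float32)       # norm layers are more stable in fp32
    if ("lm_head" in name or "embed_tokens" in name) and hasattr(module, "weight"):
        if bf16 and module.weight.dtype == torch.float32:
            module.to(torch.bfloat16)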
@@ -291,7 +313,7 @@ def train():
         train_tokenize_function,
         batched=True,
         batch_size=3000,
-        num_proc=32,
+        num_proc=os.cpu_count(),
         remove_columns=raw_train_datasets.column_names,
         load_from_cache_file=True, # not args.overwrite_cache
         desc="Running Encoding",
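The actual "improve dataset mapping" change is in this last hunk: the hard-coded num_proc=32 becomes os.cpu_count(), so tokenization parallelism follows the host machine instead of assuming 32 cores. A minimal sketch of the mapping call follows; it assumes train_tokenize_function and tokenizer are defined as in the script, and the data file name is a placeholder.

import os
from datasets import load_dataset

raw_train_datasets = load_dataset("json", data_files="train.jsonl", split="train")  # assumed data file

train_dataset = raw_train_datasets.map(
    train_tokenize_function,             # tokenization function from the script
    batched=True,
    batch_size=3000,
    num_proc=os.cpu_count(),             # previously num_proc=32
    remove_columns=raw_train_datasets.column_names,
    load_from_cache_file=True,
    desc="Running Encoding",
    fn_kwargs={"tokenizer": tokenizer},  # tokenizer assumed to be created earlier
)

On shared nodes it can be worth capping the worker count, e.g. num_proc=min(32, os.cpu_count() or 1), since os.cpu_count() reports every logical core on the machine and can return None in restricted environments.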