diff --git a/finetune/finetune.py b/finetune/finetune.py index 244ff4e..49c63b2 100644 --- a/finetune/finetune.py +++ b/finetune/finetune.py @@ -184,8 +184,6 @@ def build_model(model_args, training_args, checkpoint_dir): compute_dtype = (torch.bfloat16 if training_args.bf16 else torch.float16) model = transformers.AutoModelForCausalLM.from_pretrained( model_args.model_name_or_path, - load_in_4bit=model_args.bits == 4, - load_in_8bit=model_args.bits == 8, quantization_config=BitsAndBytesConfig( load_in_4bit=model_args.bits == 4, load_in_8bit=model_args.bits == 8,