diff --git a/.gitignore b/.gitignore
index 9321f07..d6c087e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,5 @@
 __pycache__/
 wandb
 all_models/
-results/checkpoints/
\ No newline at end of file
+results/checkpoints/
+results/expert_scores/*/rank_*
\ No newline at end of file
diff --git a/README.md b/README.md
index 60042ba..622d327 100644
--- a/README.md
+++ b/README.md
@@ -92,10 +92,11 @@ python train.py \
 
 torchrun --nproc-per-node=8 train_ep.py \
     --base_model_path=deepseek-ai/ESFT-vanilla-lite \
-    --expert_config=results/expert_configs/translation.json \
-    --train_dataset=translation \
+    --expert_config=results/expert_configs/intent.json \
+    --train_dataset=intent \
+    --save_opt_states \
     --train_config=configs/base.yaml \
-    --output_dir=results/checkpoints/translation
+    --output_dir=results/checkpoints/test/eval_intent
 ```
 
 
diff --git a/configs/base.yaml b/configs/base.yaml
index 748fa37..a2a3077 100644
--- a/configs/base.yaml
+++ b/configs/base.yaml
@@ -8,7 +8,7 @@ n_device: 8 # Number of devices
 
 # Training settings
 optim: adamw_torch_fused
-steps: 1 # Number of training steps
+steps: 500 # Number of training steps
 learning_rate: 0.00001 # Learning rate
 weight_decay: 0.1 # Weight decay for optimizer
 warmup_steps: 0 # Number of warmup steps for learning rate scheduler
@@ -19,8 +19,8 @@ random_concat_ratio: 0.2 # Ratio of random concatenation
 
 
 # Evaluation settings
-eval_steps: 1 # Evaluate every X steps
-save_steps: 1 # Save model every X steps
+eval_steps: 100 # Evaluate every X steps
+save_steps: 100 # Save model every X steps
 
 
 # Tokenizer settings
diff --git a/scripts/train_ep.sh b/scripts/train_ep.sh
index 298055c..7be4d06 100644
--- a/scripts/train_ep.sh
+++ b/scripts/train_ep.sh
@@ -7,6 +7,7 @@ torchrun --nproc-per-node=8 train_ep.py \
     --base_model_path=${base_model_path} \
     --expert_config=results/expert_configs/intent.json \
     --train_dataset=intent \
+    --save_opt_states \
     --train_config=configs/base.yaml \
     --output_dir=results/checkpoints/${exp_name}
 
diff --git a/train_ep.py b/train_ep.py
index 6c608e4..7a91e28 100644
--- a/train_ep.py
+++ b/train_ep.py
@@ -28,6 +28,7 @@ def main():
     parser.add_argument("--train_dataset", type=str, required=True)
     parser.add_argument("--output_dir", type=str, required=True)
     parser.add_argument("--train_config", type=str, required=True)
+    parser.add_argument("--save_opt_states", action="store_true", help="Whether to save optimizer states")
     parser.add_argument("--wandb_api_key", type=str, required=False)
     args = parser.parse_args()
 
@@ -145,77 +146,81 @@ def main():
         edp_rank = edp_group.rank()
         os.makedirs(output_dir, exist_ok=True)
 
-        if local_rank < ep_size and edp_rank == 0:
+        if ep_rank > 0 and edp_rank == 0:
             # Save expert model state
             expert_state = {k: v for k, v in self.model.state_dict().items() if ".expert" in k}
             expert_save_path = os.path.join(output_dir, f"expert_state_{ep_rank}.bin")
 
-            # Save expert optimizer state using parameter names instead of ids
-            optimizer = self.optimizer
-            opt_state_dict = optimizer.state_dict()
+            # Save expert optimizer state only if save_opt_states is True
+            if args.save_opt_states:
+                optimizer = self.optimizer
+                opt_state_dict = optimizer.state_dict()
+
+                # Create a mapping from parameter id to parameter name
+                id_to_name = {}
+                for name, param in self.model.named_parameters():
+                    if ".expert" in name:
+                        id_to_name[id(param)] = name
+
+                # Get the mapping from optimizer state index to parameter
+                param_to_idx = {param: idx for idx, param in enumerate(optimizer.param_groups[0]['params'], 1)}
+
+                # Save optimizer state using parameter names as keys
+                expert_opt_state = {'state': {}, 'param_groups': opt_state_dict['param_groups']}
+                for param, idx in param_to_idx.items():
+                    if id(param) in id_to_name:
+                        param_name = id_to_name[id(param)]
+                        if idx in opt_state_dict['state']:
+                            expert_opt_state['state'][param_name] = opt_state_dict['state'][idx]
+
+                expert_opt_path = os.path.join(output_dir, f"expert_optimizer_{ep_rank}.bin")
+                # Save optimizer state
+                temp_opt_path = expert_opt_path + ".tmp"
+                torch.save(expert_opt_state, temp_opt_path)
+                os.replace(temp_opt_path, expert_opt_path)
 
-            # Create a mapping from parameter id to parameter name
-            id_to_name = {}
-            for name, param in self.model.named_parameters():
-                if ".expert" in name:
-                    id_to_name[id(param)] = name
-
-            # Get the mapping from optimizer state index to parameter
-            param_to_idx = {param: idx for idx, param in enumerate(optimizer.param_groups[0]['params'], 1)}
-
-            # Save optimizer state using parameter names as keys
-            expert_opt_state = {'state': {}, 'param_groups': opt_state_dict['param_groups']}
-            for param, idx in param_to_idx.items():
-                if id(param) in id_to_name:
-                    param_name = id_to_name[id(param)]
-                    if idx in opt_state_dict['state']:
-                        expert_opt_state['state'][param_name] = opt_state_dict['state'][idx]
-
-            expert_opt_path = os.path.join(output_dir, f"expert_optimizer_{ep_rank}.bin")
-
-            # Save both states atomically
+            # Save model state
             temp_expert_path = expert_save_path + ".tmp"
-            temp_opt_path = expert_opt_path + ".tmp"
             torch.save(expert_state, temp_expert_path)
-            torch.save(expert_opt_state, temp_opt_path)
             os.sync()
             os.replace(temp_expert_path, expert_save_path)
-            os.replace(temp_opt_path, expert_opt_path)
 
         dist.barrier()
 
         if local_rank == 0:
             original_state = self.model.state_dict()
 
-            optimizer_state = self.optimizer.state_dict()
-            # Create a mapping from parameter name to optimizer index for the current session
-            name_to_idx = {}
-            for name, param in self.model.named_parameters():
-                if ".expert" in name:
-                    idx = next((i for i, p in enumerate(self.optimizer.param_groups[0]['params'], 1) if id(p) == id(param)), None)
-                    if idx is not None:
-                        name_to_idx[name] = idx
+            if args.save_opt_states:
+                optimizer_state = self.optimizer.state_dict()
+                # Create a mapping from parameter name to optimizer index for the current session
+                name_to_idx = {}
+                for name, param in self.model.named_parameters():
+                    if ".expert" in name:
+                        idx = next((i for i, p in enumerate(self.optimizer.param_groups[0]['params'], 1) if id(p) == id(param)), None)
+                        if idx is not None:
+                            name_to_idx[name] = idx
 
             time.sleep(1)
 
+            # load expert state and optimizer state from all ranks
             for rank in range(1, ep_size):
                 expert_path = os.path.join(output_dir, f"expert_state_{rank}.bin")
-                opt_path = os.path.join(output_dir, f"expert_optimizer_{rank}.bin")
                 max_retries = 3
                 for retry in range(max_retries):
                     try:
                         expert_state = torch.load(expert_path)
-                        expert_opt_state = torch.load(opt_path)
+                        if args.save_opt_states:
+                            opt_path = os.path.join(output_dir, f"expert_optimizer_{rank}.bin")
+                            expert_opt_state = torch.load(opt_path)
+
+                            # Convert parameter names back to indices for the optimizer state
+                            for param_name, state in expert_opt_state['state'].items():
+                                if param_name in name_to_idx:
+                                    idx = name_to_idx[param_name]
+                                    optimizer_state['state'][idx] = state
 
                         # Update model state
                         original_state.update(expert_state)
-
-                        # Convert parameter names back to indices for the optimizer state
-                        for param_name, state in expert_opt_state['state'].items():
-                            if param_name in name_to_idx:
-                                idx = name_to_idx[param_name]
-                                optimizer_state['state'][idx] = state
-
                         break
                     except Exception as e:
                         if retry == max_retries - 1:
@@ -223,13 +228,17 @@ def main():
                         time.sleep(1)
 
             original_save(output_dir, state_dict=original_state)
-            # Save complete optimizer state
-            opt_save_path = os.path.join(output_dir, "optimizer.pt")
-            torch.save(optimizer_state, opt_save_path)
+
+            # Save complete optimizer state if enabled
+            if args.save_opt_states:
+                opt_save_path = os.path.join(output_dir, "optimizer.pt")
+                torch.save(optimizer_state, opt_save_path)
+
             # remove those intermediate .bin files
             for rank in range(1, ep_size):
                 os.remove(os.path.join(output_dir, f"expert_state_{rank}.bin"))
-                os.remove(os.path.join(output_dir, f"expert_optimizer_{rank}.bin"))
+                if args.save_opt_states:
+                    os.remove(os.path.join(output_dir, f"expert_optimizer_{rank}.bin"))
 
         dist.barrier()
         tokenizer.save_pretrained(output_dir)
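Note on resuming (a sketch, not part of the patch): when --save_opt_states is set, the merged state written to optimizer.pt by the rank-0 path above is an ordinary torch.optim state dict (index-keyed 'state' plus 'param_groups'), so a resumed run can hand it straight to Optimizer.load_state_dict(). The helper below is a minimal illustration under that assumption; the function name and the checkpoint_dir argument are hypothetical, and it only works if the resuming process registers the same parameters in the same order as the saving run.

import os
import torch

def maybe_load_optimizer_state(optimizer, checkpoint_dir):
    # optimizer.pt is only written when train_ep.py ran with --save_opt_states.
    opt_path = os.path.join(checkpoint_dir, "optimizer.pt")
    if not os.path.isfile(opt_path):
        return False
    # The file holds a regular optimizer state_dict, so load_state_dict suffices,
    # provided the parameter registration order matches the saving run.
    optimizer.load_state_dict(torch.load(opt_path, map_location="cpu"))
    return True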