From 3746ca74410c92163cc521571dbd17579bfdc962 Mon Sep 17 00:00:00 2001
From: ZihanWang314 <510642032wzh@gmail.com>
Date: Thu, 22 May 2025 07:47:10 +0000
Subject: [PATCH] streamline code

---
 .gitignore          |   3 +-
 README.md           |   7 +--
 configs/base.yaml   |   6 +--
 scripts/train_ep.sh |   1 +
 train_ep.py         | 105 ++++++++++++++++++++++++--------------------
 5 files changed, 67 insertions(+), 55 deletions(-)

diff --git a/.gitignore b/.gitignore
index 9321f07..d6c087e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,5 @@
 __pycache__/
 wandb
 all_models/
-results/checkpoints/
\ No newline at end of file
+results/checkpoints/
+results/expert_scores/*/rank_*
\ No newline at end of file
diff --git a/README.md b/README.md
index 60042ba..622d327 100644
--- a/README.md
+++ b/README.md
@@ -92,10 +92,11 @@ python train.py \
     
 torchrun --nproc-per-node=8 train_ep.py \
     --base_model_path=deepseek-ai/ESFT-vanilla-lite \
-    --expert_config=results/expert_configs/translation.json \
-    --train_dataset=translation \
+    --expert_config=results/expert_configs/intent.json \
+    --train_dataset=intent \
+    --save_opt_states \
     --train_config=configs/base.yaml \
-    --output_dir=results/checkpoints/translation
+    --output_dir=results/checkpoints/test/eval_intent
 
 ```
 
diff --git a/configs/base.yaml b/configs/base.yaml
index 748fa37..a2a3077 100644
--- a/configs/base.yaml
+++ b/configs/base.yaml
@@ -8,7 +8,7 @@ n_device: 8  # Number of devices
 
 # Training settings
 optim: adamw_torch_fused
-steps: 1  # Number of training steps
+steps: 500  # Number of training steps
 learning_rate: 0.00001  # Learning rate
 weight_decay: 0.1  # Weight decay for optimizer
 warmup_steps: 0  # Number of warmup steps for learning rate scheduler
@@ -19,8 +19,8 @@ random_concat_ratio: 0.2  # Ratio of random concatenation
 
 
 # Evaluation settings
-eval_steps: 1  # Evaluate every X steps
-save_steps: 1  # Save model every X steps
+eval_steps: 100  # Evaluate every X steps
+save_steps: 100  # Save model every X steps
 
 # Tokenizer settings
 
diff --git a/scripts/train_ep.sh b/scripts/train_ep.sh
index 298055c..7be4d06 100644
--- a/scripts/train_ep.sh
+++ b/scripts/train_ep.sh
@@ -7,6 +7,7 @@ torchrun --nproc-per-node=8 train_ep.py \
     --base_model_path=${base_model_path} \
     --expert_config=results/expert_configs/intent.json \
     --train_dataset=intent \
+    --save_opt_states \
     --train_config=configs/base.yaml \
     --output_dir=results/checkpoints/${exp_name}
 
diff --git a/train_ep.py b/train_ep.py
index 6c608e4..7a91e28 100644
--- a/train_ep.py
+++ b/train_ep.py
@@ -28,6 +28,7 @@ def main():
     parser.add_argument("--train_dataset", type=str, required=True)
     parser.add_argument("--output_dir", type=str, required=True)
     parser.add_argument("--train_config", type=str, required=True)
+    parser.add_argument("--save_opt_states", action="store_true", help="Whether to save optimizer states")
     parser.add_argument("--wandb_api_key", type=str, required=False)
     args = parser.parse_args()
 
@@ -145,77 +146,81 @@ def main():
         edp_rank = edp_group.rank()        
         os.makedirs(output_dir, exist_ok=True)
         
-        if local_rank < ep_size and edp_rank == 0:
+        if ep_rank > 0 and edp_rank == 0:
             # Save expert model state
             expert_state = {k: v for k, v in self.model.state_dict().items() if ".expert" in k}
             expert_save_path = os.path.join(output_dir, f"expert_state_{ep_rank}.bin")
             
-            # Save expert optimizer state using parameter names instead of ids
-            optimizer = self.optimizer
-            opt_state_dict = optimizer.state_dict()
+            # Save expert optimizer state only if save_opt_states is True
+            if args.save_opt_states:
+                optimizer = self.optimizer
+                opt_state_dict = optimizer.state_dict()
+                
+                # Create a mapping from parameter id to parameter name
+                id_to_name = {}
+                for name, param in self.model.named_parameters():
+                    if ".expert" in name:
+                        id_to_name[id(param)] = name
+                
+                # Get the mapping from optimizer state index to parameter
+                param_to_idx = {param: idx for idx, param in enumerate(optimizer.param_groups[0]['params'], 1)}
+                
+                # Save optimizer state using parameter names as keys
+                expert_opt_state = {'state': {}, 'param_groups': opt_state_dict['param_groups']}
+                for param, idx in param_to_idx.items():
+                    if id(param) in id_to_name:
+                        param_name = id_to_name[id(param)]
+                        if idx in opt_state_dict['state']:
+                            expert_opt_state['state'][param_name] = opt_state_dict['state'][idx]
+                
+                expert_opt_path = os.path.join(output_dir, f"expert_optimizer_{ep_rank}.bin")
+                # Save optimizer state
+                temp_opt_path = expert_opt_path + ".tmp"
+                torch.save(expert_opt_state, temp_opt_path)
+                os.replace(temp_opt_path, expert_opt_path)
             
-            # Create a mapping from parameter id to parameter name
-            id_to_name = {}
-            for name, param in self.model.named_parameters():
-                if ".expert" in name:
-                    id_to_name[id(param)] = name
-            
-            # Get the mapping from optimizer state index to parameter
-            param_to_idx = {param: idx for idx, param in enumerate(optimizer.param_groups[0]['params'], 1)}
-            
-            # Save optimizer state using parameter names as keys
-            expert_opt_state = {'state': {}, 'param_groups': opt_state_dict['param_groups']}
-            for param, idx in param_to_idx.items():
-                if id(param) in id_to_name:
-                    param_name = id_to_name[id(param)]
-                    if idx in opt_state_dict['state']:
-                        expert_opt_state['state'][param_name] = opt_state_dict['state'][idx]
-            
-            expert_opt_path = os.path.join(output_dir, f"expert_optimizer_{ep_rank}.bin")
-            
-            # Save both states atomically
+            # Save model state
             temp_expert_path = expert_save_path + ".tmp"
-            temp_opt_path = expert_opt_path + ".tmp"
             torch.save(expert_state, temp_expert_path)
-            torch.save(expert_opt_state, temp_opt_path)
             os.sync()
             os.replace(temp_expert_path, expert_save_path)
-            os.replace(temp_opt_path, expert_opt_path)
         
         dist.barrier()
             
         if local_rank == 0:
             original_state = self.model.state_dict()
-            optimizer_state = self.optimizer.state_dict()
             
-            # Create a mapping from parameter name to optimizer index for the current session
-            name_to_idx = {}
-            for name, param in self.model.named_parameters():
-                if ".expert" in name:
-                    idx = next((i for i, p in enumerate(self.optimizer.param_groups[0]['params'], 1) if id(p) == id(param)), None)
-                    if idx is not None:
-                        name_to_idx[name] = idx
+            if args.save_opt_states:
+                optimizer_state = self.optimizer.state_dict()
+                # Create a mapping from parameter name to optimizer index for the current session
+                name_to_idx = {}
+                for name, param in self.model.named_parameters():
+                    if ".expert" in name:
+                        idx = next((i for i, p in enumerate(self.optimizer.param_groups[0]['params'], 1) if id(p) == id(param)), None)
+                        if idx is not None:
+                            name_to_idx[name] = idx
             
             time.sleep(1)
             
+            # load expert state and optimizer state from all ranks
             for rank in range(1, ep_size):
                 expert_path = os.path.join(output_dir, f"expert_state_{rank}.bin")
-                opt_path = os.path.join(output_dir, f"expert_optimizer_{rank}.bin")
                 max_retries = 3
                 for retry in range(max_retries):
                     try:
                         expert_state = torch.load(expert_path)
-                        expert_opt_state = torch.load(opt_path)
+                        if args.save_opt_states:
+                            opt_path = os.path.join(output_dir, f"expert_optimizer_{rank}.bin")
+                            expert_opt_state = torch.load(opt_path)
+                            
+                            # Convert parameter names back to indices for the optimizer state
+                            for param_name, state in expert_opt_state['state'].items():
+                                if param_name in name_to_idx:
+                                    idx = name_to_idx[param_name]
+                                    optimizer_state['state'][idx] = state
                         
                         # Update model state
                         original_state.update(expert_state)
-                        
-                        # Convert parameter names back to indices for the optimizer state
-                        for param_name, state in expert_opt_state['state'].items():
-                            if param_name in name_to_idx:
-                                idx = name_to_idx[param_name]
-                                optimizer_state['state'][idx] = state
-                            
                         break
                     except Exception as e:
                         if retry == max_retries - 1:
@@ -223,13 +228,17 @@ def main():
                         time.sleep(1)
             
             original_save(output_dir, state_dict=original_state)
-            # Save complete optimizer state
-            opt_save_path = os.path.join(output_dir, "optimizer.pt")
-            torch.save(optimizer_state, opt_save_path)
+            
+            # Save complete optimizer state if enabled
+            if args.save_opt_states:
+                opt_save_path = os.path.join(output_dir, "optimizer.pt")
+                torch.save(optimizer_state, opt_save_path)
+            
             # remove those intermediate .bin files
             for rank in range(1, ep_size):
                 os.remove(os.path.join(output_dir, f"expert_state_{rank}.bin"))
-                os.remove(os.path.join(output_dir, f"expert_optimizer_{rank}.bin"))
+                if args.save_opt_states:
+                    os.remove(os.path.join(output_dir, f"expert_optimizer_{rank}.bin"))
         
         dist.barrier()
         tokenizer.save_pretrained(output_dir)