streamline code

ZihanWang314 2025-05-22 07:47:10 +00:00
parent f38f67706c
commit 3746ca7441
5 changed files with 67 additions and 55 deletions

.gitignore

@@ -2,3 +2,4 @@ __pycache__/
wandb
all_models/
results/checkpoints/
results/expert_scores/*/rank_*


@@ -92,10 +92,11 @@ python train.py \
torchrun --nproc-per-node=8 train_ep.py \
--base_model_path=deepseek-ai/ESFT-vanilla-lite \
--expert_config=results/expert_configs/translation.json \
--train_dataset=translation \
--expert_config=results/expert_configs/intent.json \
--train_dataset=intent \
--save_opt_states \
--train_config=configs/base.yaml \
--output_dir=results/checkpoints/translation
--output_dir=results/checkpoints/test/eval_intent
```
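With `--save_opt_states`, the run above also leaves an `optimizer.pt` in the output directory (written by train_ep.py below). A minimal sketch of reusing it to warm-start an optimizer; the model loading call and the AdamW hyperparameters (taken from configs/base.yaml) are illustrative assumptions, not code from this repo:

```python
import os
import torch
from transformers import AutoModelForCausalLM

# Illustrative only: how the final checkpoint is reloaded is not part of this diff.
output_dir = "results/checkpoints/test/eval_intent"  # output_dir from the command above
model = AutoModelForCausalLM.from_pretrained(output_dir, trust_remote_code=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=0.1)

# optimizer.pt is only written when --save_opt_states is passed (see train_ep.py below)
opt_path = os.path.join(output_dir, "optimizer.pt")
if os.path.exists(opt_path):
    optimizer.load_state_dict(torch.load(opt_path, map_location="cpu"))
```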


@@ -8,7 +8,7 @@ n_device: 8 # Number of devices
# Training settings
optim: adamw_torch_fused
steps: 1 # Number of training steps
steps: 500 # Number of training steps
learning_rate: 0.00001 # Learning rate
weight_decay: 0.1 # Weight decay for optimizer
warmup_steps: 0 # Number of warmup steps for learning rate scheduler
@@ -19,8 +19,8 @@ random_concat_ratio: 0.2 # Ratio of random concatenation
# Evaluation settings
eval_steps: 1 # Evaluate every X steps
save_steps: 1 # Save model every X steps
eval_steps: 100 # Evaluate every X steps
save_steps: 100 # Save model every X steps
# Tokenizer settings
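For reference, a generic sketch of reading the keys changed above from configs/base.yaml; the actual config parsing in train_ep.py is not part of this diff:

```python
import yaml

with open("configs/base.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["steps"], cfg["eval_steps"], cfg["save_steps"])  # 500, 100, 100 after this commit
print(cfg["learning_rate"], cfg["weight_decay"])           # 1e-05, 0.1
```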


@@ -7,6 +7,7 @@ torchrun --nproc-per-node=8 train_ep.py \
--base_model_path=${base_model_path} \
--expert_config=results/expert_configs/intent.json \
--train_dataset=intent \
--save_opt_states \
--train_config=configs/base.yaml \
--output_dir=results/checkpoints/${exp_name}


@@ -28,6 +28,7 @@ def main():
parser.add_argument("--train_dataset", type=str, required=True)
parser.add_argument("--output_dir", type=str, required=True)
parser.add_argument("--train_config", type=str, required=True)
parser.add_argument("--save_opt_states", action="store_true", help="Whether to save optimizer states")
parser.add_argument("--wandb_api_key", type=str, required=False)
args = parser.parse_args()
@@ -145,12 +146,13 @@ def main():
edp_rank = edp_group.rank()
os.makedirs(output_dir, exist_ok=True)
if local_rank < ep_size and edp_rank == 0:
if ep_rank > 0 and edp_rank == 0:
# Save expert model state
expert_state = {k: v for k, v in self.model.state_dict().items() if ".expert" in k}
expert_save_path = os.path.join(output_dir, f"expert_state_{ep_rank}.bin")
# Save expert optimizer state using parameter names instead of ids
# Save expert optimizer state only if save_opt_states is True
if args.save_opt_states:
optimizer = self.optimizer
opt_state_dict = optimizer.state_dict()
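The expert optimizer shard saved here is keyed by parameter name rather than by the integer ids that `optimizer.state_dict()` uses. A standalone sketch of that re-keying, assuming the optimizer was built from `model.parameters()` in `named_parameters()` order; the helper name is illustrative:

```python
def expert_opt_state_by_name(model, optimizer):
    """Re-key optimizer state from integer ids to parameter names (expert params only)."""
    opt_state_dict = optimizer.state_dict()
    # Assumes state ids line up with the enumeration of named_parameters(), i.e. the
    # optimizer was constructed from model.parameters() in that order.
    idx_to_name = {i: name for i, (name, _) in enumerate(model.named_parameters())}
    named_state = {"state": {}, "param_groups": opt_state_dict["param_groups"]}
    for idx, state in opt_state_dict["state"].items():
        name = idx_to_name.get(idx)
        if name is not None and ".expert" in name:  # same filter as the expert model state above
            named_state["state"][name] = state
    return named_state
```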
@@ -172,22 +174,24 @@ def main():
expert_opt_state['state'][param_name] = opt_state_dict['state'][idx]
expert_opt_path = os.path.join(output_dir, f"expert_optimizer_{ep_rank}.bin")
# Save both states atomically
temp_expert_path = expert_save_path + ".tmp"
# Save optimizer state
temp_opt_path = expert_opt_path + ".tmp"
torch.save(expert_state, temp_expert_path)
torch.save(expert_opt_state, temp_opt_path)
os.replace(temp_opt_path, expert_opt_path)
# Save model state
temp_expert_path = expert_save_path + ".tmp"
torch.save(expert_state, temp_expert_path)
os.sync()
os.replace(temp_expert_path, expert_save_path)
os.replace(temp_opt_path, expert_opt_path)
dist.barrier()
if local_rank == 0:
original_state = self.model.state_dict()
optimizer_state = self.optimizer.state_dict()
if args.save_opt_states:
optimizer_state = self.optimizer.state_dict()
# Create a mapping from parameter name to optimizer index for the current session
name_to_idx = {}
for name, param in self.model.named_parameters():
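The hunk above writes each file under a `.tmp` name and promotes it with `os.replace`, so readers on other ranks never see a partially written file. A standalone sketch of that pattern; the helper name is illustrative:

```python
import os
import torch

def atomic_torch_save(obj, path):
    """Write to a temporary file, then atomically swap it into place."""
    tmp_path = path + ".tmp"
    torch.save(obj, tmp_path)   # the full write happens under the temporary name
    os.sync()                   # flush filesystem buffers (Unix-only), as the hunk above does
    os.replace(tmp_path, path)  # atomic rename: the final path is either old or complete
```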
@@ -198,24 +202,25 @@ def main():
time.sleep(1)
# load expert state and optimizer state from all ranks
for rank in range(1, ep_size):
expert_path = os.path.join(output_dir, f"expert_state_{rank}.bin")
opt_path = os.path.join(output_dir, f"expert_optimizer_{rank}.bin")
max_retries = 3
for retry in range(max_retries):
try:
expert_state = torch.load(expert_path)
if args.save_opt_states:
opt_path = os.path.join(output_dir, f"expert_optimizer_{rank}.bin")
expert_opt_state = torch.load(opt_path)
# Update model state
original_state.update(expert_state)
# Convert parameter names back to indices for the optimizer state
for param_name, state in expert_opt_state['state'].items():
if param_name in name_to_idx:
idx = name_to_idx[param_name]
optimizer_state['state'][idx] = state
# Update model state
original_state.update(expert_state)
break
except Exception as e:
if retry == max_retries - 1:
@@ -223,12 +228,16 @@ def main():
time.sleep(1)
original_save(output_dir, state_dict=original_state)
# Save complete optimizer state
# Save complete optimizer state if enabled
if args.save_opt_states:
opt_save_path = os.path.join(output_dir, "optimizer.pt")
torch.save(optimizer_state, opt_save_path)
# remove those intermediate .bin files
for rank in range(1, ep_size):
os.remove(os.path.join(output_dir, f"expert_state_{rank}.bin"))
if args.save_opt_states:
os.remove(os.path.join(output_dir, f"expert_optimizer_{rank}.bin"))
dist.barrier()
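On rank 0, the name-keyed shards loaded from the other ranks are converted back to index keys before `optimizer.pt` is written. A sketch of that merge step, under the same ordering assumption as the save-side sketch above; `model`, `optimizer`, and `shard` stand in for the real objects and the helper name is illustrative:

```python
def merge_expert_opt_shard(model, optimizer, shard):
    """Fold a name-keyed expert optimizer shard back into an index-keyed state dict."""
    optimizer_state = optimizer.state_dict()
    # Same name -> index mapping that rank 0 rebuilds from named_parameters() above.
    name_to_idx = {name: i for i, (name, _) in enumerate(model.named_parameters())}
    for param_name, state in shard["state"].items():
        if param_name in name_to_idx:
            optimizer_state["state"][name_to_idx[param_name]] = state
    return optimizer_state
```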