streamline code

ZihanWang314 2025-05-22 07:47:10 +00:00
parent f38f67706c
commit 3746ca7441
5 changed files with 67 additions and 55 deletions

.gitignore vendored

@@ -2,3 +2,4 @@ __pycache__/
 wandb
 all_models/
 results/checkpoints/
+results/expert_scores/*/rank_*

@@ -92,10 +92,11 @@ python train.py \
 torchrun --nproc-per-node=8 train_ep.py \
     --base_model_path=deepseek-ai/ESFT-vanilla-lite \
-    --expert_config=results/expert_configs/translation.json \
-    --train_dataset=translation \
+    --expert_config=results/expert_configs/intent.json \
+    --train_dataset=intent \
+    --save_opt_states \
     --train_config=configs/base.yaml \
-    --output_dir=results/checkpoints/translation
+    --output_dir=results/checkpoints/test/eval_intent
 ```

@@ -8,7 +8,7 @@ n_device: 8 # Number of devices
 
 # Training settings
 optim: adamw_torch_fused
-steps: 1 # Number of training steps
+steps: 500 # Number of training steps
 learning_rate: 0.00001 # Learning rate
 weight_decay: 0.1 # Weight decay for optimizer
 warmup_steps: 0 # Number of warmup steps for learning rate scheduler
@@ -19,8 +19,8 @@ random_concat_ratio: 0.2 # Ratio of random concatenation
 
 # Evaluation settings
-eval_steps: 1 # Evaluate every X steps
-save_steps: 1 # Save model every X steps
+eval_steps: 100 # Evaluate every X steps
+save_steps: 100 # Save model every X steps
 
 # Tokenizer settings

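As an aside, here is a rough sketch (not the repository's trainer) of how `steps`, `eval_steps`, and `save_steps` from a config like this typically gate a training loop; the helper functions are hypothetical stand-ins:

```python
# Illustrative only: steps / eval_steps / save_steps driving a training loop.
import yaml

def train_one_step(step): print(f"train step {step}")          # stand-in
def evaluate(step): print(f"eval at step {step}")               # stand-in
def save_checkpoint(step): print(f"checkpoint at step {step}")  # stand-in

with open("configs/base.yaml") as f:
    cfg = yaml.safe_load(f)

for step in range(1, cfg["steps"] + 1):      # 500 steps after this change
    train_one_step(step)
    if step % cfg["eval_steps"] == 0:        # every 100 steps
        evaluate(step)
    if step % cfg["save_steps"] == 0:        # every 100 steps
        save_checkpoint(step)
```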

@@ -7,6 +7,7 @@ torchrun --nproc-per-node=8 train_ep.py \
     --base_model_path=${base_model_path} \
     --expert_config=results/expert_configs/intent.json \
     --train_dataset=intent \
+    --save_opt_states \
     --train_config=configs/base.yaml \
     --output_dir=results/checkpoints/${exp_name}

@@ -28,6 +28,7 @@ def main():
     parser.add_argument("--train_dataset", type=str, required=True)
     parser.add_argument("--output_dir", type=str, required=True)
     parser.add_argument("--train_config", type=str, required=True)
+    parser.add_argument("--save_opt_states", action="store_true", help="Whether to save optimizer states")
     parser.add_argument("--wandb_api_key", type=str, required=False)
     args = parser.parse_args()
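A small aside on the flag added above (a sketch, not repository code): `action="store_true"` means `args.save_opt_states` defaults to `False` and only flips to `True` when the flag is passed, which is what the later `if args.save_opt_states:` branches rely on.

```python
# Minimal, self-contained illustration of the store_true flag semantics.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--save_opt_states", action="store_true",
                    help="Whether to save optimizer states")

print(parser.parse_args([]).save_opt_states)                     # False
print(parser.parse_args(["--save_opt_states"]).save_opt_states)  # True
```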
@@ -145,77 +146,81 @@ def main():
         edp_rank = edp_group.rank()
         os.makedirs(output_dir, exist_ok=True)
-        if local_rank < ep_size and edp_rank == 0:
+        if ep_rank > 0 and edp_rank == 0:
             # Save expert model state
             expert_state = {k: v for k, v in self.model.state_dict().items() if ".expert" in k}
             expert_save_path = os.path.join(output_dir, f"expert_state_{ep_rank}.bin")
-            # Save expert optimizer state using parameter names instead of ids
-            optimizer = self.optimizer
-            opt_state_dict = optimizer.state_dict()
+            # Save expert optimizer state only if save_opt_states is True
+            if args.save_opt_states:
+                optimizer = self.optimizer
+                opt_state_dict = optimizer.state_dict()
-            # Create a mapping from parameter id to parameter name
-            id_to_name = {}
-            for name, param in self.model.named_parameters():
-                if ".expert" in name:
-                    id_to_name[id(param)] = name
+                # Create a mapping from parameter id to parameter name
+                id_to_name = {}
+                for name, param in self.model.named_parameters():
+                    if ".expert" in name:
+                        id_to_name[id(param)] = name
-            # Get the mapping from optimizer state index to parameter
-            param_to_idx = {param: idx for idx, param in enumerate(optimizer.param_groups[0]['params'], 1)}
+                # Get the mapping from optimizer state index to parameter
+                param_to_idx = {param: idx for idx, param in enumerate(optimizer.param_groups[0]['params'], 1)}
-            # Save optimizer state using parameter names as keys
-            expert_opt_state = {'state': {}, 'param_groups': opt_state_dict['param_groups']}
-            for param, idx in param_to_idx.items():
-                if id(param) in id_to_name:
-                    param_name = id_to_name[id(param)]
-                    if idx in opt_state_dict['state']:
-                        expert_opt_state['state'][param_name] = opt_state_dict['state'][idx]
+                # Save optimizer state using parameter names as keys
+                expert_opt_state = {'state': {}, 'param_groups': opt_state_dict['param_groups']}
+                for param, idx in param_to_idx.items():
+                    if id(param) in id_to_name:
+                        param_name = id_to_name[id(param)]
+                        if idx in opt_state_dict['state']:
+                            expert_opt_state['state'][param_name] = opt_state_dict['state'][idx]
-            expert_opt_path = os.path.join(output_dir, f"expert_optimizer_{ep_rank}.bin")
+                expert_opt_path = os.path.join(output_dir, f"expert_optimizer_{ep_rank}.bin")
+                # Save optimizer state
+                temp_opt_path = expert_opt_path + ".tmp"
+                torch.save(expert_opt_state, temp_opt_path)
+                os.replace(temp_opt_path, expert_opt_path)
-            # Save both states atomically
+            # Save model state
             temp_expert_path = expert_save_path + ".tmp"
-            temp_opt_path = expert_opt_path + ".tmp"
             torch.save(expert_state, temp_expert_path)
-            torch.save(expert_opt_state, temp_opt_path)
             os.sync()
             os.replace(temp_expert_path, expert_save_path)
-            os.replace(temp_opt_path, expert_opt_path)
         dist.barrier()
         if local_rank == 0:
             original_state = self.model.state_dict()
-            optimizer_state = self.optimizer.state_dict()
-            # Create a mapping from parameter name to optimizer index for the current session
-            name_to_idx = {}
-            for name, param in self.model.named_parameters():
-                if ".expert" in name:
-                    idx = next((i for i, p in enumerate(self.optimizer.param_groups[0]['params'], 1) if id(p) == id(param)), None)
-                    if idx is not None:
-                        name_to_idx[name] = idx
+            if args.save_opt_states:
+                optimizer_state = self.optimizer.state_dict()
+                # Create a mapping from parameter name to optimizer index for the current session
+                name_to_idx = {}
+                for name, param in self.model.named_parameters():
+                    if ".expert" in name:
+                        idx = next((i for i, p in enumerate(self.optimizer.param_groups[0]['params'], 1) if id(p) == id(param)), None)
+                        if idx is not None:
+                            name_to_idx[name] = idx
             time.sleep(1)
             # load expert state and optimizer state from all ranks
             for rank in range(1, ep_size):
                 expert_path = os.path.join(output_dir, f"expert_state_{rank}.bin")
-                opt_path = os.path.join(output_dir, f"expert_optimizer_{rank}.bin")
                 max_retries = 3
                 for retry in range(max_retries):
                     try:
                         expert_state = torch.load(expert_path)
-                        expert_opt_state = torch.load(opt_path)
+                        if args.save_opt_states:
+                            opt_path = os.path.join(output_dir, f"expert_optimizer_{rank}.bin")
+                            expert_opt_state = torch.load(opt_path)
+                            # Convert parameter names back to indices for the optimizer state
+                            for param_name, state in expert_opt_state['state'].items():
+                                if param_name in name_to_idx:
+                                    idx = name_to_idx[param_name]
+                                    optimizer_state['state'][idx] = state
                         # Update model state
                         original_state.update(expert_state)
-                        # Convert parameter names back to indices for the optimizer state
-                        for param_name, state in expert_opt_state['state'].items():
-                            if param_name in name_to_idx:
-                                idx = name_to_idx[param_name]
-                                optimizer_state['state'][idx] = state
                         break
                     except Exception as e:
                         if retry == max_retries - 1:
@@ -223,13 +228,17 @@ def main():
                         time.sleep(1)
             original_save(output_dir, state_dict=original_state)
-            # Save complete optimizer state
-            opt_save_path = os.path.join(output_dir, "optimizer.pt")
-            torch.save(optimizer_state, opt_save_path)
+            # Save complete optimizer state if enabled
+            if args.save_opt_states:
+                opt_save_path = os.path.join(output_dir, "optimizer.pt")
+                torch.save(optimizer_state, opt_save_path)
             # remove those intermediate .bin files
             for rank in range(1, ep_size):
                 os.remove(os.path.join(output_dir, f"expert_state_{rank}.bin"))
-                os.remove(os.path.join(output_dir, f"expert_optimizer_{rank}.bin"))
+                if args.save_opt_states:
+                    os.remove(os.path.join(output_dir, f"expert_optimizer_{rank}.bin"))
         dist.barrier()
         tokenizer.save_pretrained(output_dir)
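The core technique in the hunks above is saving AdamW state keyed by parameter name rather than by positional index, so per-expert optimizer state written on one rank can later be folded back into an optimizer built in a different process. Below is a minimal standalone sketch of that round trip; the toy model and variable names are illustrative, not the repository's code (and it uses PyTorch's 0-based state indices, whereas the diff enumerates from 1).

```python
# Sketch: save optimizer state keyed by parameter *names*, then map it back to
# positional indices when loading into a freshly constructed optimizer.
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
opt = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=0.1)
model(torch.randn(2, 4)).sum().backward()
opt.step()  # populates exp_avg / exp_avg_sq / step for every parameter

# Save side: state_dict() keys states by 0-based parameter index;
# translate index -> parameter -> name.
name_of = {id(p): n for n, p in model.named_parameters()}
params = opt.param_groups[0]["params"]          # live Parameter objects
raw = opt.state_dict()
state_by_name = {name_of[id(p)]: raw["state"][i]
                 for i, p in enumerate(params) if i in raw["state"]}
# (the diff stores a name-keyed mapping like this under expert_opt_state['state'])

# Load side: map names back to the indices used by a new optimizer instance.
new_opt = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=0.1)
idx_of = {name_of[id(p)]: i for i, p in enumerate(new_opt.param_groups[0]["params"])}
restored = new_opt.state_dict()
restored["state"] = {idx_of[n]: s for n, s in state_by_name.items()}
new_opt.load_state_dict(restored)
print(sorted(restored["state"]))  # [0, 1, 2, 3]
```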