mirror of https://github.com/deepseek-ai/ESFT.git
synced 2025-05-23 02:37:07 -04:00

streamline code

This commit is contained in:
parent f38f67706c
commit 3746ca7441

 .gitignore (vendored) | 3
@@ -1,4 +1,5 @@
 __pycache__/
 wandb
 all_models/
 results/checkpoints/
+results/expert_scores/*/rank_*
README.md

@@ -92,10 +92,11 @@ python train.py \

 torchrun --nproc-per-node=8 train_ep.py \
     --base_model_path=deepseek-ai/ESFT-vanilla-lite \
-    --expert_config=results/expert_configs/translation.json \
-    --train_dataset=translation \
+    --expert_config=results/expert_configs/intent.json \
+    --train_dataset=intent \
+    --save_opt_states \
     --train_config=configs/base.yaml \
-    --output_dir=results/checkpoints/translation
+    --output_dir=results/checkpoints/test/eval_intent

 ```
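Side note on the example: the README command now targets the intent task and passes the new --save_opt_states flag, so the output directory should also receive an optimizer.pt (written by the train_ep.py changes further down). A minimal, hypothetical resume sketch follows; the placeholder model, the AdamW settings (taken from configs/base.yaml below), and the path are stand-ins, not repo API:

```python
# Hypothetical sketch (not in the repo): reload the optimizer state that
# train_ep.py writes to <output_dir>/optimizer.pt when --save_opt_states
# is set. `model` is a placeholder; use the actual fine-tuned model.
import torch

model = torch.nn.Linear(8, 8)  # placeholder
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=0.1)

ckpt_dir = "results/checkpoints/test/eval_intent"
optimizer.load_state_dict(torch.load(f"{ckpt_dir}/optimizer.pt"))
# load_state_dict only succeeds when the parameter layout matches the one
# that produced optimizer.pt, i.e. with the real model in place.
```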
configs/base.yaml

@@ -8,7 +8,7 @@ n_device: 8 # Number of devices

 # Training settings
 optim: adamw_torch_fused
-steps: 1 # Number of training steps
+steps: 500 # Number of training steps
 learning_rate: 0.00001 # Learning rate
 weight_decay: 0.1 # Weight decay for optimizer
 warmup_steps: 0 # Number of warmup steps for learning rate scheduler
@@ -19,8 +19,8 @@ random_concat_ratio: 0.2 # Ratio of random concatenation

 # Evaluation settings
-eval_steps: 1 # Evaluate every X steps
-save_steps: 1 # Save model every X steps
+eval_steps: 100 # Evaluate every X steps
+save_steps: 100 # Save model every X steps

 # Tokenizer settings
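For reference, these settings are plain YAML consumed via --train_config. A sketch of reading them; the actual loading code is not part of this diff, so PyYAML here is an assumption:

```python
# Sketch: reading the fields changed above. Assumes PyYAML is available;
# the field names come directly from configs/base.yaml.
import yaml

with open("configs/base.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["steps"])       # 500 after this commit (was 1)
print(cfg["eval_steps"])  # 100 (was 1)
print(cfg["save_steps"])  # 100 (was 1)
```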
@@ -7,6 +7,7 @@ torchrun --nproc-per-node=8 train_ep.py \
     --base_model_path=${base_model_path} \
     --expert_config=results/expert_configs/intent.json \
     --train_dataset=intent \
+    --save_opt_states \
     --train_config=configs/base.yaml \
     --output_dir=results/checkpoints/${exp_name}
 train_ep.py | 105
@@ -28,6 +28,7 @@ def main():
     parser.add_argument("--train_dataset", type=str, required=True)
     parser.add_argument("--output_dir", type=str, required=True)
     parser.add_argument("--train_config", type=str, required=True)
+    parser.add_argument("--save_opt_states", action="store_true", help="Whether to save optimizer states")
     parser.add_argument("--wandb_api_key", type=str, required=False)
     args = parser.parse_args()
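Since the new argument is declared with action="store_true", it defaults to False: optimizer states are only written when the flag is passed explicitly. A standalone check of that semantics:

```python
# Standalone check: store_true flags default to False when omitted.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--save_opt_states", action="store_true",
                    help="Whether to save optimizer states")

assert parser.parse_args([]).save_opt_states is False
assert parser.parse_args(["--save_opt_states"]).save_opt_states is True
```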
@@ -145,77 +146,81 @@ def main():
         edp_rank = edp_group.rank()
         os.makedirs(output_dir, exist_ok=True)

-        if local_rank < ep_size and edp_rank == 0:
+        if ep_rank > 0 and edp_rank == 0:
             # Save expert model state
             expert_state = {k: v for k, v in self.model.state_dict().items() if ".expert" in k}
             expert_save_path = os.path.join(output_dir, f"expert_state_{ep_rank}.bin")

-            # Save expert optimizer state using parameter names instead of ids
-            optimizer = self.optimizer
-            opt_state_dict = optimizer.state_dict()
-
-            # Create a mapping from parameter id to parameter name
-            id_to_name = {}
-            for name, param in self.model.named_parameters():
-                if ".expert" in name:
-                    id_to_name[id(param)] = name
-
-            # Get the mapping from optimizer state index to parameter
-            param_to_idx = {param: idx for idx, param in enumerate(optimizer.param_groups[0]['params'], 1)}
-
-            # Save optimizer state using parameter names as keys
-            expert_opt_state = {'state': {}, 'param_groups': opt_state_dict['param_groups']}
-            for param, idx in param_to_idx.items():
-                if id(param) in id_to_name:
-                    param_name = id_to_name[id(param)]
-                    if idx in opt_state_dict['state']:
-                        expert_opt_state['state'][param_name] = opt_state_dict['state'][idx]
-
-            expert_opt_path = os.path.join(output_dir, f"expert_optimizer_{ep_rank}.bin")
-
-            # Save both states atomically
+            # Save expert optimizer state only if save_opt_states is True
+            if args.save_opt_states:
+                optimizer = self.optimizer
+                opt_state_dict = optimizer.state_dict()
+
+                # Create a mapping from parameter id to parameter name
+                id_to_name = {}
+                for name, param in self.model.named_parameters():
+                    if ".expert" in name:
+                        id_to_name[id(param)] = name
+
+                # Get the mapping from optimizer state index to parameter
+                param_to_idx = {param: idx for idx, param in enumerate(optimizer.param_groups[0]['params'], 1)}
+
+                # Save optimizer state using parameter names as keys
+                expert_opt_state = {'state': {}, 'param_groups': opt_state_dict['param_groups']}
+                for param, idx in param_to_idx.items():
+                    if id(param) in id_to_name:
+                        param_name = id_to_name[id(param)]
+                        if idx in opt_state_dict['state']:
+                            expert_opt_state['state'][param_name] = opt_state_dict['state'][idx]
+
+                expert_opt_path = os.path.join(output_dir, f"expert_optimizer_{ep_rank}.bin")
+                # Save optimizer state
+                temp_opt_path = expert_opt_path + ".tmp"
+                torch.save(expert_opt_state, temp_opt_path)
+                os.replace(temp_opt_path, expert_opt_path)
+
+            # Save model state
             temp_expert_path = expert_save_path + ".tmp"
-            temp_opt_path = expert_opt_path + ".tmp"
             torch.save(expert_state, temp_expert_path)
-            torch.save(expert_opt_state, temp_opt_path)
             os.sync()
             os.replace(temp_expert_path, expert_save_path)
-            os.replace(temp_opt_path, expert_opt_path)

         dist.barrier()

         if local_rank == 0:
             original_state = self.model.state_dict()
-            optimizer_state = self.optimizer.state_dict()

-            # Create a mapping from parameter name to optimizer index for the current session
-            name_to_idx = {}
-            for name, param in self.model.named_parameters():
-                if ".expert" in name:
-                    idx = next((i for i, p in enumerate(self.optimizer.param_groups[0]['params'], 1) if id(p) == id(param)), None)
-                    if idx is not None:
-                        name_to_idx[name] = idx
+            if args.save_opt_states:
+                optimizer_state = self.optimizer.state_dict()
+                # Create a mapping from parameter name to optimizer index for the current session
+                name_to_idx = {}
+                for name, param in self.model.named_parameters():
+                    if ".expert" in name:
+                        idx = next((i for i, p in enumerate(self.optimizer.param_groups[0]['params'], 1) if id(p) == id(param)), None)
+                        if idx is not None:
+                            name_to_idx[name] = idx

             time.sleep(1)

+            # load expert state and optimizer state from all ranks
             for rank in range(1, ep_size):
                 expert_path = os.path.join(output_dir, f"expert_state_{rank}.bin")
-                opt_path = os.path.join(output_dir, f"expert_optimizer_{rank}.bin")
                 max_retries = 3
                 for retry in range(max_retries):
                     try:
                         expert_state = torch.load(expert_path)
-                        expert_opt_state = torch.load(opt_path)
+                        if args.save_opt_states:
+                            opt_path = os.path.join(output_dir, f"expert_optimizer_{rank}.bin")
+                            expert_opt_state = torch.load(opt_path)
+
+                            # Convert parameter names back to indices for the optimizer state
+                            for param_name, state in expert_opt_state['state'].items():
+                                if param_name in name_to_idx:
+                                    idx = name_to_idx[param_name]
+                                    optimizer_state['state'][idx] = state
+
                         # Update model state
                         original_state.update(expert_state)

-                        # Convert parameter names back to indices for the optimizer state
-                        for param_name, state in expert_opt_state['state'].items():
-                            if param_name in name_to_idx:
-                                idx = name_to_idx[param_name]
-                                optimizer_state['state'][idx] = state
-
                         break
                     except Exception as e:
                         if retry == max_retries - 1:
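The core trick above is re-keying optimizer state by parameter name so it survives the trip across expert-parallel ranks, then mapping the names back to the loading process's own indices. A self-contained sketch of that round trip with a toy model (the real code additionally filters names containing ".expert"). Note the sketch enumerates from 0, matching how torch's Optimizer.state_dict() keys its 'state' dict, whereas the patch enumerates from 1:

```python
# Toy round trip: positional optimizer-state keys -> parameter names -> back.
import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))
opt = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=0.1)
model(torch.randn(1, 4)).sum().backward()
opt.step()  # populates per-parameter state (step, exp_avg, exp_avg_sq)

# Save side: re-key 'state' from 0-based position to parameter name.
sd = opt.state_dict()
pos = {id(p): i for i, p in enumerate(opt.param_groups[0]["params"])}
by_name = {name: sd["state"][pos[id(p)]]
           for name, p in model.named_parameters()
           if pos[id(p)] in sd["state"]}

# Load side (think: a fresh process): map names back to local positions.
restored = opt.state_dict()
for name, p in model.named_parameters():
    if name in by_name:
        restored["state"][pos[id(p)]] = by_name[name]
opt.load_state_dict(restored)

print(sorted(by_name))  # ['0.bias', '0.weight', '1.bias', '1.weight']
```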
@@ -223,13 +228,17 @@ def main():
                     time.sleep(1)

         original_save(output_dir, state_dict=original_state)
-        # Save complete optimizer state
-        opt_save_path = os.path.join(output_dir, "optimizer.pt")
-        torch.save(optimizer_state, opt_save_path)
+
+        # Save complete optimizer state if enabled
+        if args.save_opt_states:
+            opt_save_path = os.path.join(output_dir, "optimizer.pt")
+            torch.save(optimizer_state, opt_save_path)

         # remove those intermediate .bin files
         for rank in range(1, ep_size):
             os.remove(os.path.join(output_dir, f"expert_state_{rank}.bin"))
-            os.remove(os.path.join(output_dir, f"expert_optimizer_{rank}.bin"))
+            if args.save_opt_states:
+                os.remove(os.path.join(output_dir, f"expert_optimizer_{rank}.bin"))

     dist.barrier()
     tokenizer.save_pretrained(output_dir)
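One pattern worth pulling out: every per-rank file above is written to a .tmp path and then os.replace'd into place, so rank 0's retry loop never observes a partially written checkpoint. A minimal sketch of that pattern in isolation (POSIX-only because of os.sync):

```python
# Minimal sketch of the atomic-save pattern used in the patch: write to a
# temp file, flush, then rename. os.replace is an atomic rename, so a
# concurrent reader sees either the old file or the new one, never a mix.
import os
import torch

def atomic_save(obj, path: str) -> None:
    tmp = path + ".tmp"
    torch.save(obj, tmp)
    os.sync()              # flush buffered writes before swapping (POSIX)
    os.replace(tmp, path)  # atomic rename into the final name

atomic_save({"weight": torch.zeros(4)}, "expert_state_demo.bin")
print(os.path.exists("expert_state_demo.bin"))  # True
```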