import argparse import json import yaml import os import random import torch import torch.distributed as dist from types import MethodType from torch.utils.data import TensorDataset from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, logging from benchmarks import * from utils import get_formatted_input_and_target, get_examples_from_buffer_pad, init_parallel_groups from esft import to_esft from deepseek.modeling_deepseek import DeepseekV2ForCausalLM import time os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["NCCL_AVOID_RECORD_STREAMS"] = "1" logging.set_verbosity_error() def main(): parser = argparse.ArgumentParser() parser.add_argument("--base_model_path", type=str, required=True) parser.add_argument("--expert_config", type=str, required=True) parser.add_argument("--train_dataset", type=str, required=True) parser.add_argument("--output_dir", type=str, required=True) parser.add_argument("--train_config", type=str, required=True) parser.add_argument("--wandb_api_key", type=str, required=False) args = parser.parse_args() expert_config = json.load(open(args.expert_config)) output_dir = args.output_dir base_model_path = args.base_model_path config = yaml.safe_load(open(args.train_config)) os.makedirs(args.output_dir, exist_ok=True) seed = config['seed'] torch.manual_seed(seed) torch.cuda.manual_seed(seed) random.seed(seed) if args.wandb_api_key is not None: import wandb wandb.login(key=args.wandb_api_key) ep_size = config.get("ep_size", 1) world_size, local_rank, ep_group, edp_group = init_parallel_groups(ep_size) edp_size = world_size // ep_size # Prepare data tokenizer = AutoTokenizer.from_pretrained(base_model_path) samples = [json.loads(i) for i in open(f"datasets/train/{args.train_dataset}.jsonl").readlines()] buffer = [] for instance in samples: input_ids, target_ids = get_formatted_input_and_target(instance['messages'], tokenizer, -100) buffer.append((input_ids, target_ids)) seq_length = config['seq_length'] random_concat_ratio = config['random_concat_ratio'] concated_examples = get_examples_from_buffer_pad(buffer, seq_length, tokenizer, random_concat_ratio) dataset = TensorDataset(concated_examples['input_ids'], concated_examples['labels']) train_dataset, valid_dataset = torch.utils.data.random_split(dataset, [int(len(dataset) * 0.98), len(dataset) - int(len(dataset) * 0.98)]) # Training arguments training_args = TrainingArguments( output_dir=output_dir, max_steps=config['steps'], per_device_train_batch_size=config['per_device_batch_size'], per_device_eval_batch_size=config['per_device_batch_size'], warmup_steps=config['warmup_steps'], weight_decay=config['weight_decay'], logging_dir=f"{output_dir}/logs", logging_steps=config['logging_steps'], save_steps=config['save_steps'], eval_strategy="steps", eval_steps=config['eval_steps'], gradient_accumulation_steps=config['gradient_accumulation_steps'], load_best_model_at_end=True, metric_for_best_model="loss", greater_is_better=False, bf16=True, lr_scheduler_type='constant', save_total_limit=5, learning_rate=config['learning_rate'], optim=config['optim'], adam_beta1=config['adam_beta1'], adam_beta2=config['adam_beta2'], disable_tqdm=False, gradient_checkpointing=config['gradient_checkpointing'], gradient_checkpointing_kwargs={"use_reentrant": False} if config['gradient_checkpointing'] else {}, # if set to True, backward will raise bug ) def data_collator(data): input_ids = torch.stack([item[0] for item in data]) labels = torch.stack([item[1] for item in data]) return {"input_ids": input_ids, "labels": labels} model = DeepseekV2ForCausalLM.from_pretrained(base_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16, ep_size=ep_size, attn_implementation="flash_attention_2") model._ddp_params_and_buffers_to_ignore = [n for n, _ in model.named_parameters() if ".expert" in n] # we manage grad synchronization of expert parameters to_esft(model, expert_config) model.dummy = torch.nn.Parameter(torch.zeros(1, dtype=model.dtype)) # prevent DDP from having no trainable parameters model._keys_to_ignore_on_save = ["dummy"] expert_params = [p for n, p in model.named_parameters() if p.requires_grad and ".expert" in n] for layer in model.model.layers: if type(layer.mlp).__name__ != "DeepseekV2MoE": continue layer.mlp.ep_group = ep_group # Force all2all backward the same number of times if ep_size > 1 and not expert_config["non_expert_modules"]: min_layer_id = min(int(k) for k, v in expert_config["experts"].items() if v) mlp = model.model.layers[min_layer_id].mlp forward = mlp.forward def custom_forward(self, hidden_states: torch.Tensor): return forward(hidden_states.requires_grad_(torch.is_grad_enabled())) mlp.forward = MethodType(custom_forward, mlp) # Initialize Trainer trainer = Trainer( model=model, args=training_args, train_dataset=train_dataset, eval_dataset=valid_dataset, data_collator=data_collator, ) original_save_model = trainer.save_model def custom_save_model(self, output_dir=None, _internal_call=False): if output_dir is None: output_dir = self.args.output_dir # Ensure all ranks participate in saving self._save(output_dir) dist.barrier() trainer.save_model = MethodType(custom_save_model, trainer) original_save = trainer._save def custom_save(self, output_dir=None, state_dict=None): ep_rank = ep_group.rank() edp_rank = edp_group.rank() os.makedirs(output_dir, exist_ok=True) if local_rank < ep_size and edp_rank == 0: # Save expert model state expert_state = {k: v for k, v in self.model.state_dict().items() if ".expert" in k} expert_save_path = os.path.join(output_dir, f"expert_state_{ep_rank}.bin") # Save expert optimizer state using parameter names instead of ids optimizer = self.optimizer opt_state_dict = optimizer.state_dict() # Create a mapping from parameter id to parameter name id_to_name = {} for name, param in self.model.named_parameters(): if ".expert" in name: id_to_name[id(param)] = name # Get the mapping from optimizer state index to parameter param_to_idx = {param: idx for idx, param in enumerate(optimizer.param_groups[0]['params'], 1)} # Save optimizer state using parameter names as keys expert_opt_state = {'state': {}, 'param_groups': opt_state_dict['param_groups']} for param, idx in param_to_idx.items(): if id(param) in id_to_name: param_name = id_to_name[id(param)] if idx in opt_state_dict['state']: expert_opt_state['state'][param_name] = opt_state_dict['state'][idx] expert_opt_path = os.path.join(output_dir, f"expert_optimizer_{ep_rank}.bin") # Save both states atomically temp_expert_path = expert_save_path + ".tmp" temp_opt_path = expert_opt_path + ".tmp" torch.save(expert_state, temp_expert_path) torch.save(expert_opt_state, temp_opt_path) os.sync() os.replace(temp_expert_path, expert_save_path) os.replace(temp_opt_path, expert_opt_path) dist.barrier() if local_rank == 0: original_state = self.model.state_dict() optimizer_state = self.optimizer.state_dict() # Create a mapping from parameter name to optimizer index for the current session name_to_idx = {} for name, param in self.model.named_parameters(): if ".expert" in name: idx = next((i for i, p in enumerate(self.optimizer.param_groups[0]['params'], 1) if id(p) == id(param)), None) if idx is not None: name_to_idx[name] = idx time.sleep(1) for rank in range(1, ep_size): expert_path = os.path.join(output_dir, f"expert_state_{rank}.bin") opt_path = os.path.join(output_dir, f"expert_optimizer_{rank}.bin") max_retries = 3 for retry in range(max_retries): try: expert_state = torch.load(expert_path) expert_opt_state = torch.load(opt_path) # Update model state original_state.update(expert_state) # Convert parameter names back to indices for the optimizer state for param_name, state in expert_opt_state['state'].items(): if param_name in name_to_idx: idx = name_to_idx[param_name] optimizer_state['state'][idx] = state break except Exception as e: if retry == max_retries - 1: raise time.sleep(1) original_save(output_dir, state_dict=original_state) # Save complete optimizer state opt_save_path = os.path.join(output_dir, "optimizer.pt") torch.save(optimizer_state, opt_save_path) # remove those intermediate .bin files for rank in range(1, ep_size): os.remove(os.path.join(output_dir, f"expert_state_{rank}.bin")) os.remove(os.path.join(output_dir, f"expert_optimizer_{rank}.bin")) dist.barrier() tokenizer.save_pretrained(output_dir) trainer._save = MethodType(custom_save, trainer) accelerator = trainer.accelerator backward = accelerator.backward def custom_backward(self, loss, **kwargs): backward(loss, **kwargs) if not self.sync_gradients or edp_size == 1: return for p in expert_params: g = p.grad if p.grad is not None else torch.zeros_like(p) dist.all_reduce(g, op=dist.ReduceOp.AVG, group=edp_group) if p.grad is not g: p.grad = g accelerator.backward = MethodType(custom_backward, accelerator) trainer.train() print("Training complete") if __name__ == "__main__": main()