import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from torch.utils.data.dataset import Dataset
import torch.optim as optim
from transformers import get_cosine_schedule_with_warmup
import wandb
import json
import time
import math
import random
import numpy as np
from typing import Dict, List, Optional
import argparse
from pathlib import Path
import tiktoken

from model import DeepSeekMathConfig, DeepSeekMathForCausalLM


class MathDataset(Dataset):
    """Dataset for mathematical text data"""

    def __init__(self, data_path: str, tokenizer, max_length: int = 4096, split: str = "train"):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = []

        # Load mathematical datasets
        # This would include sources like:
        # - Mathematical papers (arXiv)
        # - Mathematical problem-solution pairs
        # - Formal mathematical proofs
        # - Mathematical textbooks
        # - Code for mathematical computations
        if os.path.exists(data_path):
            with open(data_path, 'r', encoding='utf-8') as f:
                for line in f:
                    try:
                        sample = json.loads(line)
                        if 'text' in sample:
                            self.data.append(sample['text'])
                    except json.JSONDecodeError:
                        continue
        else:
            # Dummy data for demonstration
            self.data = self._generate_dummy_math_data()

    def _generate_dummy_math_data(self):
        """Generate dummy mathematical data for demonstration"""
        dummy_data = [
            "Theorem: For any real numbers a and b, (a + b)² = a² + 2ab + b². Proof: (a + b)² = (a + b)(a + b) = a(a + b) + b(a + b) = a² + ab + ba + b² = a² + 2ab + b².",
            "Problem: Solve the equation 2x + 5 = 13. Solution: 2x + 5 = 13, 2x = 13 - 5, 2x = 8, x = 4.",
            "Definition: A prime number is a natural number greater than 1 that has no positive divisors other than 1 and itself.",
            "Lemma: If n is a composite number, then n has a prime divisor less than or equal to √n.",
            "Calculate the derivative of f(x) = x³ + 2x² - 5x + 1. f'(x) = 3x² + 4x - 5.",
        ] * 1000  # Repeat for more data
        return dummy_data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data[idx]

        # Tokenize the text
        tokens = self.tokenizer.encode(text)

        # Truncate or pad to max_length
        if len(tokens) > self.max_length:
            tokens = tokens[:self.max_length]
            labels = list(tokens)
        else:
            # Pad with the tokenizer's pad token if available, otherwise use 0.
            # Padded positions in the labels are set to -100 so that they are ignored
            # by the loss, assuming the model uses the standard CrossEntropyLoss
            # ignore_index convention.
            pad_token = getattr(self.tokenizer, 'pad_token_id', 0)
            num_pad = self.max_length - len(tokens)
            labels = list(tokens) + [-100] * num_pad
            tokens = tokens + [pad_token] * num_pad

        return {
            'input_ids': torch.tensor(tokens, dtype=torch.long),
            'labels': torch.tensor(labels, dtype=torch.long)
        }
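
# The JSONL training files are expected to contain one JSON object per line with a
# 'text' field, e.g. {"text": "Problem: Solve 2x + 5 = 13. Solution: x = 4."}.
# A minimal sanity check for the dataset (an illustrative sketch; the path below is
# hypothetical, and the dataset falls back to dummy data when the file is missing):
#
#   enc = tiktoken.get_encoding("cl100k_base")
#   ds = MathDataset("data/train.jsonl", enc, max_length=128)
#   print(ds[0]['input_ids'].shape)  # torch.Size([128])
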

class DeepSeekMathTrainer:
    def __init__(self, config: Dict):
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.setup_distributed_training()
        self.setup_model()
        self.setup_tokenizer()
        self.setup_datasets()
        self.setup_optimizer()
        self.setup_logging()

    def setup_distributed_training(self):
        """Setup distributed training if available"""
        self.is_distributed = False

        if 'WORLD_SIZE' in os.environ:
            self.is_distributed = True
            self.world_size = int(os.environ['WORLD_SIZE'])
            self.rank = int(os.environ['RANK'])
            self.local_rank = int(os.environ['LOCAL_RANK'])

            # Initialize distributed training
            dist.init_process_group(backend='nccl')
            torch.cuda.set_device(self.local_rank)
            self.device = torch.device(f'cuda:{self.local_rank}')
        else:
            self.world_size = 1
            self.rank = 0
            self.local_rank = 0

    def setup_model(self):
        """Initialize the DeepSeek-Math model"""
        model_config = DeepSeekMathConfig(
            vocab_size=self.config['vocab_size'],
            hidden_size=self.config['hidden_size'],
            intermediate_size=self.config['intermediate_size'],
            num_hidden_layers=self.config['num_hidden_layers'],
            num_attention_heads=self.config['num_attention_heads'],
            max_position_embeddings=self.config['max_position_embeddings'],
            rms_norm_eps=self.config['rms_norm_eps'],
            rope_theta=self.config['rope_theta'],
        )

        self.model = DeepSeekMathForCausalLM(model_config)
        self.model = self.model.to(self.device)

        if self.is_distributed:
            self.model = DDP(self.model, device_ids=[self.local_rank])

        # Print model statistics
        if self.rank == 0:
            total_params = sum(p.numel() for p in self.model.parameters())
            trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
            print(f"Total parameters: {total_params:,}")
            print(f"Trainable parameters: {trainable_params:,}")

    def setup_tokenizer(self):
        """Setup tokenizer (using tiktoken for the GPT-4 tokenizer)"""
        try:
            self.tokenizer = tiktoken.get_encoding("cl100k_base")
        except Exception:
            # Fallback to a simple character-level tokenizer
            self.tokenizer = SimpleTokenizer(vocab_size=self.config['vocab_size'])

    def setup_datasets(self):
        """Setup training and validation datasets"""
        train_dataset = MathDataset(
            data_path=self.config['train_data_path'],
            tokenizer=self.tokenizer,
            max_length=self.config['max_length'],
            split='train'
        )

        val_dataset = MathDataset(
            data_path=self.config['val_data_path'],
            tokenizer=self.tokenizer,
            max_length=self.config['max_length'],
            split='val'
        )

        # Setup distributed samplers
        train_sampler = DistributedSampler(train_dataset) if self.is_distributed else None
        val_sampler = DistributedSampler(val_dataset) if self.is_distributed else None

        self.train_dataloader = DataLoader(
            train_dataset,
            batch_size=self.config['batch_size'],
            sampler=train_sampler,
            shuffle=(train_sampler is None),
            num_workers=self.config['num_workers'],
            pin_memory=True
        )

        self.val_dataloader = DataLoader(
            val_dataset,
            batch_size=self.config['batch_size'],
            sampler=val_sampler,
            shuffle=False,
            num_workers=self.config['num_workers'],
            pin_memory=True
        )

    def setup_optimizer(self):
        """Setup optimizer and learning rate scheduler"""
        # AdamW optimizer with weight decay (no decay on biases and norm weights)
        no_decay = ['bias', 'LayerNorm.weight', 'layernorm.weight']
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in self.model.named_parameters()
                           if not any(nd in n for nd in no_decay)],
                'weight_decay': self.config['weight_decay']
            },
            {
                'params': [p for n, p in self.model.named_parameters()
                           if any(nd in n for nd in no_decay)],
                'weight_decay': 0.0
            }
        ]

        self.optimizer = optim.AdamW(
            optimizer_grouped_parameters,
            lr=self.config['learning_rate'],
            betas=(self.config['adam_beta1'], self.config['adam_beta2']),
            eps=self.config['adam_epsilon']
        )

        # Calculate total training steps
        self.total_steps = len(self.train_dataloader) * self.config['num_epochs']
        self.warmup_steps = int(self.total_steps * self.config['warmup_ratio'])

        # Cosine learning rate scheduler with warmup
        self.scheduler = get_cosine_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=self.warmup_steps,
            num_training_steps=self.total_steps
        )

    def setup_logging(self):
        """Setup logging with wandb"""
        if self.rank == 0:
            wandb.init(
                project=self.config['project_name'],
                name=self.config['run_name'],
                config=self.config
            )

    def save_checkpoint(self, epoch: int, step: int, loss: float):
        """Save model checkpoint"""
        if self.rank == 0:
            checkpoint_dir = Path(self.config['checkpoint_dir'])
            checkpoint_dir.mkdir(parents=True, exist_ok=True)

            model_to_save = self.model.module if hasattr(self.model, 'module') else self.model

            checkpoint = {
                'epoch': epoch,
                'step': step,
                'model_state_dict': model_to_save.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'scheduler_state_dict': self.scheduler.state_dict(),
                'loss': loss,
                'config': self.config
            }

            checkpoint_path = checkpoint_dir / f'checkpoint_epoch_{epoch}_step_{step}.pt'
            torch.save(checkpoint, checkpoint_path)

            # Keep only the last N checkpoints, sorted by modification time so that
            # lexicographic ordering of epoch/step numbers does not delete recent files
            checkpoints = sorted(checkpoint_dir.glob('checkpoint_*.pt'),
                                 key=lambda p: p.stat().st_mtime)
            if len(checkpoints) > self.config['max_checkpoints']:
                for old_checkpoint in checkpoints[:-self.config['max_checkpoints']]:
                    old_checkpoint.unlink()

            print(f"Checkpoint saved: {checkpoint_path}")

    def load_checkpoint(self, checkpoint_path: str):
        """Load model checkpoint"""
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        model_to_load = self.model.module if hasattr(self.model, 'module') else self.model
        model_to_load.load_state_dict(checkpoint['model_state_dict'])

        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        return checkpoint['epoch'], checkpoint['step'], checkpoint['loss']

    def compute_loss(self, batch):
        """Compute the loss for a batch"""
        input_ids = batch['input_ids'].to(self.device)
        labels = batch['labels'].to(self.device)

        outputs = self.model(input_ids=input_ids, labels=labels)
        return outputs['loss']

    def train_epoch(self, epoch: int):
        """Train for one epoch"""
        self.model.train()
        total_loss = 0
        num_batches = len(self.train_dataloader)

        if self.is_distributed:
            self.train_dataloader.sampler.set_epoch(epoch)

        for step, batch in enumerate(self.train_dataloader):
            # Forward pass
            loss = self.compute_loss(batch)

            # Backward pass
            loss.backward()

            # Gradient clipping
            if self.config['max_grad_norm'] > 0:
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(),
                    self.config['max_grad_norm']
                )

            # Optimizer step
            self.optimizer.step()
            self.scheduler.step()
            self.optimizer.zero_grad()

            total_loss += loss.item()
            # Logging
            if step % self.config['log_interval'] == 0 and self.rank == 0:
                avg_loss = total_loss / (step + 1)
                lr = self.scheduler.get_last_lr()[0]

                print(f"Epoch {epoch}, Step {step}/{num_batches}, "
                      f"Loss: {loss.item():.4f}, Avg Loss: {avg_loss:.4f}, "
                      f"LR: {lr:.2e}")

                wandb.log({
                    'train/loss': loss.item(),
                    'train/avg_loss': avg_loss,
                    'train/learning_rate': lr,
                    'train/epoch': epoch,
                    'train/step': step
                })

            # Save checkpoint
            if step % self.config['save_interval'] == 0 and step > 0:
                self.save_checkpoint(epoch, step, loss.item())

        return total_loss / num_batches

    def validate(self, epoch: int):
        """Validate the model"""
        self.model.eval()
        total_loss = 0
        num_batches = len(self.val_dataloader)

        with torch.no_grad():
            for step, batch in enumerate(self.val_dataloader):
                loss = self.compute_loss(batch)
                total_loss += loss.item()

        avg_loss = total_loss / num_batches
        perplexity = math.exp(avg_loss)

        if self.rank == 0:
            print(f"Validation - Epoch {epoch}, Loss: {avg_loss:.4f}, "
                  f"Perplexity: {perplexity:.2f}")

            wandb.log({
                'val/loss': avg_loss,
                'val/perplexity': perplexity,
                'val/epoch': epoch
            })

        return avg_loss

    def train(self):
        """Main training loop"""
        print("Starting training...")

        best_val_loss = float('inf')

        for epoch in range(self.config['num_epochs']):
            print(f"\n=== Epoch {epoch + 1}/{self.config['num_epochs']} ===")

            # Training
            train_loss = self.train_epoch(epoch)

            # Validation
            val_loss = self.validate(epoch)

            # Save best model
            if val_loss < best_val_loss and self.rank == 0:
                best_val_loss = val_loss
                best_model_path = Path(self.config['checkpoint_dir']) / 'best_model.pt'
                model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
                torch.save(model_to_save.state_dict(), best_model_path)
                print(f"Best model saved with validation loss: {best_val_loss:.4f}")

            # Save epoch checkpoint
            self.save_checkpoint(epoch, len(self.train_dataloader), train_loss)

        if self.rank == 0:
            print("Training completed!")
            wandb.finish()
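
# Restoring a previous run from disk (an illustrative sketch; the checkpoint file name
# is hypothetical). load_checkpoint restores model, optimizer, and scheduler state and
# returns the saved (epoch, step, loss); note that train() itself always starts from
# epoch 0, so true mid-run resumption would need extra bookkeeping on top of this:
#
#   trainer = DeepSeekMathTrainer(config)
#   epoch, step, loss = trainer.load_checkpoint('checkpoints/checkpoint_epoch_1_step_1000.pt')
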

class SimpleTokenizer:
    """Simple character-level tokenizer as fallback"""

    def __init__(self, vocab_size: int = 50000):
        self.vocab_size = vocab_size
        self.pad_token_id = 0

    def encode(self, text: str) -> List[int]:
        # Simple character-level encoding
        return [min(ord(c), self.vocab_size - 1) for c in text]


def parse_args():
    parser = argparse.ArgumentParser(description='DeepSeek-Math Pretraining')

    # Model configuration
    parser.add_argument('--vocab_size', type=int, default=102400)
    parser.add_argument('--hidden_size', type=int, default=4096)
    parser.add_argument('--intermediate_size', type=int, default=11008)
    parser.add_argument('--num_hidden_layers', type=int, default=30)
    parser.add_argument('--num_attention_heads', type=int, default=32)
    parser.add_argument('--max_position_embeddings', type=int, default=4096)
    parser.add_argument('--max_length', type=int, default=2048)

    # Training configuration
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--num_epochs', type=int, default=3)
    parser.add_argument('--learning_rate', type=float, default=1e-4)
    parser.add_argument('--weight_decay', type=float, default=0.1)
    parser.add_argument('--adam_beta1', type=float, default=0.9)
    parser.add_argument('--adam_beta2', type=float, default=0.95)
    parser.add_argument('--adam_epsilon', type=float, default=1e-8)
    parser.add_argument('--warmup_ratio', type=float, default=0.03)
    parser.add_argument('--max_grad_norm', type=float, default=1.0)

    # Data paths
    parser.add_argument('--train_data_path', type=str, default='data/train.jsonl')
    parser.add_argument('--val_data_path', type=str, default='data/val.jsonl')

    # Logging and checkpointing
    parser.add_argument('--checkpoint_dir', type=str, default='checkpoints')
    parser.add_argument('--log_interval', type=int, default=100)
    parser.add_argument('--save_interval', type=int, default=1000)
    parser.add_argument('--max_checkpoints', type=int, default=5)
    parser.add_argument('--project_name', type=str, default='deepseek-math')
    parser.add_argument('--run_name', type=str, default='pretraining')

    # System configuration
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--rms_norm_eps', type=float, default=1e-6)
    parser.add_argument('--rope_theta', type=float, default=10000.0)

    return parser.parse_args()


def set_seed(seed: int = 42):
    """Set random seed for reproducibility"""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def main():
    args = parse_args()
    set_seed(42)

    # Convert args to config dict
    config = vars(args)

    # Initialize trainer
    trainer = DeepSeekMathTrainer(config)

    # Start training
    trainer.train()


# Example usage with distributed training:
# torchrun --nproc_per_node=8 --nnodes=1 deepseek_math_training.py \
#     --batch_size 4 \
#     --learning_rate 1e-4 \
#     --num_epochs 3 \
#     --train_data_path /path/to/train.jsonl \
#     --val_data_path /path/to/val.jsonl \
#     --checkpoint_dir ./checkpoints \
#     --project_name deepseek-math-7b \
#     --run_name pretraining_run_1


# Data preprocessing script for mathematical datasets
def preprocess_math_data():
    """
    Example preprocessing script for mathematical datasets

    This would typically process:
    - ArXiv papers in mathematics
    - Mathematical problem-solution pairs
    - Formal proofs
    - Mathematical textbooks
    - Code for mathematical computations
    """
    import re
    from pathlib import Path

    def clean_math_text(text: str) -> str:
        """Clean and normalize mathematical text"""
        # Remove excessive whitespace
        text = re.sub(r'\s+', ' ', text)

        # Normalize mathematical notation
        text = re.sub(r'\\frac\{([^}]+)\}\{([^}]+)\}', r'(\1)/(\2)', text)
        text = re.sub(r'\\sqrt\{([^}]+)\}', r'sqrt(\1)', text)

        # Clean up common LaTeX commands
        text = re.sub(r'\\[a-zA-Z]+\{([^}]*)\}', r'\1', text)
        text = re.sub(r'\\[a-zA-Z]+', '', text)

        return text.strip()

    def process_file(input_path: str, output_path: str):
        """Process a single file and save cleaned data"""
        with open(input_path, 'r', encoding='utf-8') as f:
            content = f.read()

        # Split into chunks (could be paragraphs, sections, etc.)
        chunks = content.split('\n\n')

        processed_data = []
        for chunk in chunks:
            if len(chunk.strip()) > 50:  # Filter out very short chunks
                cleaned = clean_math_text(chunk)
                if cleaned:
                    processed_data.append({'text': cleaned})

        # Save as JSONL
        with open(output_path, 'w', encoding='utf-8') as f:
            for item in processed_data:
                f.write(json.dumps(item) + '\n')

    # Example usage
    input_dir = Path('raw_data')
    output_dir = Path('processed_data')
    output_dir.mkdir(exist_ok=True)

    for file_path in input_dir.glob('*.txt'):
        output_path = output_dir / f"{file_path.stem}.jsonl"
        process_file(str(file_path), str(output_path))
        print(f"Processed {file_path} -> {output_path}")


if __name__ == "__main__":
    # Uncomment to run data preprocessing
    # preprocess_math_data()

    # Run main training
    main()
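
# For a quick single-process debug run (no torchrun; the script falls back to dummy
# data when the data paths do not exist), an invocation along these lines should work.
# The values are illustrative only and far smaller than a real pretraining setup:
#
#   python deepseek_math_training.py \
#       --hidden_size 256 --intermediate_size 1024 --num_hidden_layers 4 \
#       --num_attention_heads 4 --max_length 256 --batch_size 2 --num_epochs 1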