Implement DeepSeek-Math model and training pipeline with dataset handling and distributed training support

jayeshthk 2025-06-19 17:27:05 +05:30
parent b8b0f8ce09
commit b53e984052
2 changed files with 965 additions and 0 deletions

train/model.py Normal file

@@ -0,0 +1,428 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import List, Optional, Tuple
from dataclasses import dataclass
@dataclass
class DeepSeekMathConfig:
vocab_size: int = 102400
hidden_size: int = 4096
intermediate_size: int = 11008
num_hidden_layers: int = 30
num_attention_heads: int = 32
num_key_value_heads: int = 32 # For grouped query attention
max_position_embeddings: int = 4096
rms_norm_eps: float = 1e-6
rope_theta: float = 10000.0
attention_dropout: float = 0.0
hidden_dropout: float = 0.0
use_cache: bool = True
rope_scaling: Optional[dict] = None
tie_word_embeddings: bool = False
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class RotaryPositionalEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
def forward(self, x, seq_len=None):
if seq_len is None:
seq_len = x.shape[-2]
t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype)
def rotate_half(x):
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
cos = cos[position_ids].unsqueeze(1)
sin = sin[position_ids].unsqueeze(1)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
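# Illustrative shape note (example numbers chosen for clarity, not part of the
# original file): with bsz=2, num_heads=4, seq_len=16, head_dim=64, q and k have
# shape (2, 4, 16, 64); RotaryPositionalEmbedding(64) returns cos/sin of shape
# (16, 64); indexing with position_ids of shape (2, 16) and unsqueeze(1) gives
# (2, 1, 16, 64), which broadcasts so q_embed and k_embed keep (2, 4, 16, 64).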
class DeepSeekMathAttention(nn.Module):
def __init__(self, config: DeepSeekMathConfig):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
self.rotary_emb = RotaryPositionalEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# Repeat k/v heads if n_kv_heads < n_heads
key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
class DeepSeekMathMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = nn.SiLU()
def forward(self, x):
# SwiGLU activation function
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
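# Note (added for clarity): the forward above implements the SwiGLU feed-forward,
# SwiGLU(x) = down_proj(SiLU(gate_proj(x)) * up_proj(x)), so intermediate_size
# (11008 in the default config) is the width of both gate_proj and up_proj outputs.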
class DeepSeekMathDecoderLayer(nn.Module):
def __init__(self, config: DeepSeekMathConfig):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = DeepSeekMathAttention(config=config)
self.mlp = DeepSeekMathMLP(config)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
class DeepSeekMathModel(nn.Module):
def __init__(self, config: DeepSeekMathConfig):
super().__init__()
self.config = config
self.padding_idx = config.vocab_size - 1
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList([DeepSeekMathDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
):
        # DeepSeekMathConfig does not define output_attentions/output_hidden_states,
        # so default both flags to False when the caller does not pass them.
        output_attentions = output_attentions if output_attentions is not None else False
        output_hidden_states = output_hidden_states if output_hidden_states is not None else False
use_cache = use_cache if use_cache is not None else self.config.use_cache
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# Attention mask
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
)
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
hidden_states = inputs_embeds
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
return {
"last_hidden_state": hidden_states,
"past_key_values": next_cache,
"hidden_states": all_hidden_states,
"attentions": all_self_attns,
}
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# Create causal mask
batch_size, seq_length = input_shape
causal_mask = torch.full((seq_length, seq_length), fill_value=float("-inf"), device=inputs_embeds.device)
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask = causal_mask.to(inputs_embeds.dtype)
if past_key_values_length > 0:
causal_mask = torch.cat(
[torch.zeros(seq_length, past_key_values_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device),
causal_mask], dim=-1
)
expanded_attn_mask = attention_mask[:, None, None, :].expand(batch_size, 1, seq_length, seq_length + past_key_values_length)
expanded_attn_mask = expanded_attn_mask.to(inputs_embeds.dtype)
expanded_attn_mask = (1.0 - expanded_attn_mask) * torch.finfo(inputs_embeds.dtype).min
return expanded_attn_mask + causal_mask
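# Illustrative example of the combined mask (example values, added for clarity):
# for seq_length=3, no cache, and attention_mask all ones, causal_mask is
#   [[0, -inf, -inf],
#    [0,    0, -inf],
#    [0,    0,    0]]
# while expanded_attn_mask is all zeros, so each position attends only to itself
# and earlier positions; padded positions in attention_mask instead contribute
# torch.finfo(dtype).min.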
class DeepSeekMathForCausalLM(nn.Module):
def __init__(self, config):
super().__init__()
self.model = DeepSeekMathModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
):
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
hidden_states = outputs["last_hidden_state"]
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
return {
"loss": loss,
"logits": logits,
"past_key_values": outputs["past_key_values"],
"hidden_states": outputs["hidden_states"],
"attentions": outputs["attentions"],
}
# config = DeepSeekMathConfig()
# model = DeepSeekMathForCausalLM(config)
# # Print model info
# total_params = sum(p.numel() for p in model.parameters())
# trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
# print(f"Total parameters: {total_params:,}")
# print(f"Trainable parameters: {trainable_params:,}")
# print(f"Model size: ~{total_params * 4 / 1e9:.1f}B parameters")

train/train.py Normal file

@@ -0,0 +1,537 @@
import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from torch.utils.data.dataset import Dataset
import torch.optim as optim
from transformers import get_cosine_schedule_with_warmup
import wandb
import json
import time
import math
import random
import numpy as np
from typing import Dict, List, Optional
import argparse
from pathlib import Path
import tiktoken
from model import DeepSeekMathConfig, DeepSeekMathForCausalLM
class MathDataset(Dataset):
"""Dataset for mathematical text data"""
def __init__(self, data_path: str, tokenizer, max_length: int = 4096, split: str = "train"):
self.tokenizer = tokenizer
self.max_length = max_length
self.data = []
# Load mathematical datasets
# This would include sources like:
# - Mathematical papers (arXiv)
# - Mathematical problem-solution pairs
# - Formal mathematical proofs
# - Mathematical textbooks
# - Code for mathematical computations
if os.path.exists(data_path):
with open(data_path, 'r', encoding='utf-8') as f:
for line in f:
try:
sample = json.loads(line)
if 'text' in sample:
self.data.append(sample['text'])
except json.JSONDecodeError:
continue
else:
# Dummy data for demonstration
self.data = self._generate_dummy_math_data()
def _generate_dummy_math_data(self):
"""Generate dummy mathematical data for demonstration"""
dummy_data = [
"Theorem: For any real numbers a and b, (a + b)² = a² + 2ab + b². Proof: (a + b)² = (a + b)(a + b) = a(a + b) + b(a + b) = a² + ab + ba + b² = a² + 2ab + b².",
"Problem: Solve the equation 2x + 5 = 13. Solution: 2x + 5 = 13, 2x = 13 - 5, 2x = 8, x = 4.",
"Definition: A prime number is a natural number greater than 1 that has no positive divisors other than 1 and itself.",
"Lemma: If n is a composite number, then n has a prime divisor less than or equal to √n.",
"Calculate the derivative of f(x) = x³ + 2x² - 5x + 1. f'(x) = 3x² + 4x - 5.",
] * 1000 # Repeat for more data
return dummy_data
def __len__(self):
return len(self.data)
    def __getitem__(self, idx):
        text = self.data[idx]
        # Tokenize the text
        tokens = self.tokenizer.encode(text)
        # Truncate or pad to max_length
        if len(tokens) > self.max_length:
            tokens = tokens[:self.max_length]
            labels = list(tokens)
        else:
            # Pad with tokenizer's pad token if available, otherwise use 0;
            # padded label positions are set to -100 so CrossEntropyLoss ignores them
            pad_token = getattr(self.tokenizer, 'pad_token_id', 0)
            num_pad = self.max_length - len(tokens)
            labels = list(tokens) + [-100] * num_pad
            tokens = tokens + [pad_token] * num_pad
        return {
            'input_ids': torch.tensor(tokens, dtype=torch.long),
            'labels': torch.tensor(labels, dtype=torch.long)
        }
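# Expected on-disk format (inferred from the loader above): each line of the
# train/val .jsonl file is a JSON object with a "text" field, e.g.
#   {"text": "Problem: Solve 2x + 5 = 13. Solution: x = 4."}
# Lines that fail to parse or lack a "text" field are skipped.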
class DeepSeekMathTrainer:
def __init__(self, config: Dict):
self.config = config
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.setup_distributed_training()
self.setup_model()
self.setup_tokenizer()
self.setup_datasets()
self.setup_optimizer()
self.setup_logging()
def setup_distributed_training(self):
"""Setup distributed training if available"""
self.is_distributed = False
if 'WORLD_SIZE' in os.environ:
self.is_distributed = True
self.world_size = int(os.environ['WORLD_SIZE'])
self.rank = int(os.environ['RANK'])
self.local_rank = int(os.environ['LOCAL_RANK'])
# Initialize distributed training
dist.init_process_group(backend='nccl')
torch.cuda.set_device(self.local_rank)
self.device = torch.device(f'cuda:{self.local_rank}')
else:
self.world_size = 1
self.rank = 0
self.local_rank = 0
def setup_model(self):
"""Initialize the DeepSeek-Math model"""
model_config = DeepSeekMathConfig(
vocab_size=self.config['vocab_size'],
hidden_size=self.config['hidden_size'],
intermediate_size=self.config['intermediate_size'],
num_hidden_layers=self.config['num_hidden_layers'],
num_attention_heads=self.config['num_attention_heads'],
max_position_embeddings=self.config['max_position_embeddings'],
rms_norm_eps=self.config['rms_norm_eps'],
rope_theta=self.config['rope_theta'],
)
self.model = DeepSeekMathForCausalLM(model_config)
self.model = self.model.to(self.device)
if self.is_distributed:
self.model = DDP(self.model, device_ids=[self.local_rank])
# Print model statistics
if self.rank == 0:
total_params = sum(p.numel() for p in self.model.parameters())
trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
def setup_tokenizer(self):
"""Setup tokenizer (using tiktoken for GPT-4 tokenizer)"""
try:
self.tokenizer = tiktoken.get_encoding("cl100k_base")
        except Exception:
# Fallback to a simple character-level tokenizer
self.tokenizer = SimpleTokenizer(vocab_size=self.config['vocab_size'])
def setup_datasets(self):
"""Setup training and validation datasets"""
train_dataset = MathDataset(
data_path=self.config['train_data_path'],
tokenizer=self.tokenizer,
max_length=self.config['max_length'],
split='train'
)
val_dataset = MathDataset(
data_path=self.config['val_data_path'],
tokenizer=self.tokenizer,
max_length=self.config['max_length'],
split='val'
)
# Setup distributed samplers
train_sampler = DistributedSampler(train_dataset) if self.is_distributed else None
val_sampler = DistributedSampler(val_dataset) if self.is_distributed else None
self.train_dataloader = DataLoader(
train_dataset,
batch_size=self.config['batch_size'],
sampler=train_sampler,
shuffle=(train_sampler is None),
num_workers=self.config['num_workers'],
pin_memory=True
)
self.val_dataloader = DataLoader(
val_dataset,
batch_size=self.config['batch_size'],
sampler=val_sampler,
shuffle=False,
num_workers=self.config['num_workers'],
pin_memory=True
)
def setup_optimizer(self):
"""Setup optimizer and learning rate scheduler"""
# AdamW optimizer with weight decay
        # Biases and all RMSNorm weights (including the final model.norm) are excluded from weight decay
        no_decay = ['bias', 'norm.weight']
optimizer_grouped_parameters = [
{
'params': [p for n, p in self.model.named_parameters()
if not any(nd in n for nd in no_decay)],
'weight_decay': self.config['weight_decay']
},
{
'params': [p for n, p in self.model.named_parameters()
if any(nd in n for nd in no_decay)],
'weight_decay': 0.0
}
]
self.optimizer = optim.AdamW(
optimizer_grouped_parameters,
lr=self.config['learning_rate'],
betas=(self.config['adam_beta1'], self.config['adam_beta2']),
eps=self.config['adam_epsilon']
)
# Calculate total training steps
self.total_steps = len(self.train_dataloader) * self.config['num_epochs']
self.warmup_steps = int(self.total_steps * self.config['warmup_ratio'])
# Cosine learning rate scheduler with warmup
self.scheduler = get_cosine_schedule_with_warmup(
self.optimizer,
num_warmup_steps=self.warmup_steps,
num_training_steps=self.total_steps
)
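        # Illustrative numbers (example values for intuition only): with 1,000
        # optimizer steps per epoch and num_epochs=3, total_steps is 3,000 and
        # warmup_ratio=0.03 gives 90 linear warmup steps before the cosine decay.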
def setup_logging(self):
"""Setup logging with wandb"""
if self.rank == 0:
wandb.init(
project=self.config['project_name'],
name=self.config['run_name'],
config=self.config
)
def save_checkpoint(self, epoch: int, step: int, loss: float):
"""Save model checkpoint"""
if self.rank == 0:
checkpoint_dir = Path(self.config['checkpoint_dir'])
checkpoint_dir.mkdir(parents=True, exist_ok=True)
model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
checkpoint = {
'epoch': epoch,
'step': step,
'model_state_dict': model_to_save.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'scheduler_state_dict': self.scheduler.state_dict(),
'loss': loss,
'config': self.config
}
checkpoint_path = checkpoint_dir / f'checkpoint_epoch_{epoch}_step_{step}.pt'
torch.save(checkpoint, checkpoint_path)
# Keep only the last N checkpoints
            # Sort by modification time so lexicographic ordering (e.g. step_1000 vs step_200) cannot delete the newest files
            checkpoints = sorted(checkpoint_dir.glob('checkpoint_*.pt'), key=lambda p: p.stat().st_mtime)
if len(checkpoints) > self.config['max_checkpoints']:
for old_checkpoint in checkpoints[:-self.config['max_checkpoints']]:
old_checkpoint.unlink()
print(f"Checkpoint saved: {checkpoint_path}")
def load_checkpoint(self, checkpoint_path: str):
"""Load model checkpoint"""
checkpoint = torch.load(checkpoint_path, map_location=self.device)
model_to_load = self.model.module if hasattr(self.model, 'module') else self.model
model_to_load.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
return checkpoint['epoch'], checkpoint['step'], checkpoint['loss']
def compute_loss(self, batch):
"""Compute the loss for a batch"""
input_ids = batch['input_ids'].to(self.device)
labels = batch['labels'].to(self.device)
outputs = self.model(input_ids=input_ids, labels=labels)
return outputs['loss']
def train_epoch(self, epoch: int):
"""Train for one epoch"""
self.model.train()
total_loss = 0
num_batches = len(self.train_dataloader)
if self.is_distributed:
self.train_dataloader.sampler.set_epoch(epoch)
for step, batch in enumerate(self.train_dataloader):
# Forward pass
loss = self.compute_loss(batch)
# Backward pass
loss.backward()
# Gradient clipping
if self.config['max_grad_norm'] > 0:
torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
self.config['max_grad_norm']
)
# Optimizer step
self.optimizer.step()
self.scheduler.step()
self.optimizer.zero_grad()
total_loss += loss.item()
# Logging
if step % self.config['log_interval'] == 0 and self.rank == 0:
avg_loss = total_loss / (step + 1)
lr = self.scheduler.get_last_lr()[0]
print(f"Epoch {epoch}, Step {step}/{num_batches}, "
f"Loss: {loss.item():.4f}, Avg Loss: {avg_loss:.4f}, "
f"LR: {lr:.2e}")
wandb.log({
'train/loss': loss.item(),
'train/avg_loss': avg_loss,
'train/learning_rate': lr,
'train/epoch': epoch,
'train/step': step
})
# Save checkpoint
if step % self.config['save_interval'] == 0 and step > 0:
self.save_checkpoint(epoch, step, loss.item())
return total_loss / num_batches
def validate(self, epoch: int):
"""Validate the model"""
self.model.eval()
total_loss = 0
num_batches = len(self.val_dataloader)
with torch.no_grad():
for step, batch in enumerate(self.val_dataloader):
loss = self.compute_loss(batch)
total_loss += loss.item()
avg_loss = total_loss / num_batches
perplexity = math.exp(avg_loss)
if self.rank == 0:
print(f"Validation - Epoch {epoch}, Loss: {avg_loss:.4f}, "
f"Perplexity: {perplexity:.2f}")
wandb.log({
'val/loss': avg_loss,
'val/perplexity': perplexity,
'val/epoch': epoch
})
return avg_loss
def train(self):
"""Main training loop"""
print("Starting training...")
best_val_loss = float('inf')
for epoch in range(self.config['num_epochs']):
print(f"\n=== Epoch {epoch + 1}/{self.config['num_epochs']} ===")
# Training
train_loss = self.train_epoch(epoch)
# Validation
val_loss = self.validate(epoch)
# Save best model
if val_loss < best_val_loss and self.rank == 0:
best_val_loss = val_loss
best_model_path = Path(self.config['checkpoint_dir']) / 'best_model.pt'
model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
torch.save(model_to_save.state_dict(), best_model_path)
print(f"Best model saved with validation loss: {best_val_loss:.4f}")
# Save epoch checkpoint
self.save_checkpoint(epoch, len(self.train_dataloader), train_loss)
if self.rank == 0:
print("Training completed!")
wandb.finish()
class SimpleTokenizer:
"""Simple character-level tokenizer as fallback"""
def __init__(self, vocab_size: int = 50000):
self.vocab_size = vocab_size
self.pad_token_id = 0
def encode(self, text: str) -> List[int]:
# Simple character-level encoding
return [min(ord(c), self.vocab_size - 1) for c in text]
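    # Note (added for clarity): this fallback maps each character to
    # min(ord(c), vocab_size - 1) and provides no decode; it only keeps the
    # pipeline runnable when tiktoken is unavailable.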
def parse_args():
parser = argparse.ArgumentParser(description='DeepSeek-Math Pretraining')
# Model configuration
parser.add_argument('--vocab_size', type=int, default=102400)
parser.add_argument('--hidden_size', type=int, default=4096)
parser.add_argument('--intermediate_size', type=int, default=11008)
parser.add_argument('--num_hidden_layers', type=int, default=30)
parser.add_argument('--num_attention_heads', type=int, default=32)
parser.add_argument('--max_position_embeddings', type=int, default=4096)
parser.add_argument('--max_length', type=int, default=2048)
# Training configuration
parser.add_argument('--batch_size', type=int, default=8)
parser.add_argument('--num_epochs', type=int, default=3)
parser.add_argument('--learning_rate', type=float, default=1e-4)
parser.add_argument('--weight_decay', type=float, default=0.1)
parser.add_argument('--adam_beta1', type=float, default=0.9)
parser.add_argument('--adam_beta2', type=float, default=0.95)
parser.add_argument('--adam_epsilon', type=float, default=1e-8)
parser.add_argument('--warmup_ratio', type=float, default=0.03)
parser.add_argument('--max_grad_norm', type=float, default=1.0)
# Data paths
parser.add_argument('--train_data_path', type=str, default='data/train.jsonl')
parser.add_argument('--val_data_path', type=str, default='data/val.jsonl')
# Logging and checkpointing
parser.add_argument('--checkpoint_dir', type=str, default='checkpoints')
parser.add_argument('--log_interval', type=int, default=100)
parser.add_argument('--save_interval', type=int, default=1000)
parser.add_argument('--max_checkpoints', type=int, default=5)
parser.add_argument('--project_name', type=str, default='deepseek-math')
parser.add_argument('--run_name', type=str, default='pretraining')
# System configuration
parser.add_argument('--num_workers', type=int, default=4)
parser.add_argument('--rms_norm_eps', type=float, default=1e-6)
parser.add_argument('--rope_theta', type=float, default=10000.0)
return parser.parse_args()
def set_seed(seed: int = 42):
"""Set random seed for reproducibility"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def main():
args = parse_args()
set_seed(42)
# Convert args to config dict
config = vars(args)
# Initialize trainer
trainer = DeepSeekMathTrainer(config)
# Start training
trainer.train()
# Example usage with distributed training:
# torchrun --nproc_per_node=8 --nnodes=1 train/train.py \
# --batch_size 4 \
# --learning_rate 1e-4 \
# --num_epochs 3 \
# --train_data_path /path/to/train.jsonl \
# --val_data_path /path/to/val.jsonl \
# --checkpoint_dir ./checkpoints \
# --project_name deepseek-math-7b \
# --run_name pretraining_run_1
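# Single-process usage (illustrative flag values, assuming one GPU or CPU):
# python train/train.py \
#     --batch_size 2 \
#     --max_length 512 \
#     --num_epochs 1 \
#     --checkpoint_dir ./checkpoints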
# Data preprocessing script for mathematical datasets
def preprocess_math_data():
"""
Example preprocessing script for mathematical datasets
This would typically process:
- ArXiv papers in mathematics
- Mathematical problem-solution pairs
- Formal proofs
- Mathematical textbooks
- Code for mathematical computations
"""
import re
from pathlib import Path
def clean_math_text(text: str) -> str:
"""Clean and normalize mathematical text"""
# Remove excessive whitespace
text = re.sub(r'\s+', ' ', text)
# Normalize mathematical notation
text = re.sub(r'\\frac\{([^}]+)\}\{([^}]+)\}', r'(\1)/(\2)', text)
text = re.sub(r'\\sqrt\{([^}]+)\}', r'sqrt(\1)', text)
# Clean up common LaTeX commands
text = re.sub(r'\\[a-zA-Z]+\{([^}]*)\}', r'\1', text)
text = re.sub(r'\\[a-zA-Z]+', '', text)
return text.strip()
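    # Illustrative input/output of clean_math_text (example string, added for clarity):
    #   "\frac{a+b}{2} \leq \sqrt{ab}"  ->  "(a+b)/(2)  sqrt(ab)"
    # (\frac and \sqrt are rewritten, while bare commands such as \leq are dropped,
    # which leaves a double space here).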
def process_file(input_path: str, output_path: str):
"""Process a single file and save cleaned data"""
with open(input_path, 'r', encoding='utf-8') as f:
content = f.read()
# Split into chunks (could be paragraphs, sections, etc.)
chunks = content.split('\n\n')
processed_data = []
for chunk in chunks:
if len(chunk.strip()) > 50: # Filter out very short chunks
cleaned = clean_math_text(chunk)
if cleaned:
processed_data.append({'text': cleaned})
# Save as JSONL
with open(output_path, 'w', encoding='utf-8') as f:
for item in processed_data:
f.write(json.dumps(item) + '\n')
# Example usage
input_dir = Path('raw_data')
output_dir = Path('processed_data')
output_dir.mkdir(exist_ok=True)
for file_path in input_dir.glob('*.txt'):
output_path = output_dir / f"{file_path.stem}.jsonl"
process_file(str(file_path), str(output_path))
print(f"Processed {file_path} -> {output_path}")
if __name__ == "__main__":
# Uncomment to run data preprocessing
# preprocess_math_data()
# Run main training
main()