diff --git a/train/model.py b/train/model.py
new file mode 100644
index 0000000..509ac75
--- /dev/null
+++ b/train/model.py
@@ -0,0 +1,430 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+from typing import Optional, Tuple, List
+from dataclasses import dataclass
+
+@dataclass
+class DeepSeekMathConfig:
+    vocab_size: int = 102400
+    hidden_size: int = 4096
+    intermediate_size: int = 11008
+    num_hidden_layers: int = 30
+    num_attention_heads: int = 32
+    num_key_value_heads: int = 32  # For grouped query attention
+    max_position_embeddings: int = 4096
+    rms_norm_eps: float = 1e-6
+    rope_theta: float = 10000.0
+    attention_dropout: float = 0.0
+    hidden_dropout: float = 0.0
+    use_cache: bool = True
+    rope_scaling: Optional[dict] = None
+    tie_word_embeddings: bool = False
+    # output_attentions:bool=True
+    # output_hidden_states:int=12
+
+class RMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+
+class RotaryPositionalEmbedding(nn.Module):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+        super().__init__()
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+    def forward(self, x, seq_len=None):
+        if seq_len is None:
+            seq_len = x.shape[-2]
+        t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
+        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        emb = torch.cat((freqs, freqs), dim=-1)
+        return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype)
+
+def rotate_half(x):
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+    cos = cos[position_ids].unsqueeze(1)
+    sin = sin[position_ids].unsqueeze(1)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+class DeepSeekMathAttention(nn.Module):
+    def __init__(self, config: DeepSeekMathConfig):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.hidden_size // self.num_heads
+        self.num_key_value_heads = config.num_key_value_heads
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+        self.max_position_embeddings = config.max_position_embeddings
+        self.rope_theta = config.rope_theta
+
+        if (self.head_dim * self.num_heads) != self.hidden_size:
+            raise ValueError(
+                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+                f" and `num_heads`: {self.num_heads})."
+            )
+
+        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+
+        self.rotary_emb = RotaryPositionalEmbedding(
+            self.head_dim,
+            max_position_embeddings=self.max_position_embeddings,
+            base=self.rope_theta,
+        )
+
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value[0].shape[-2]
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+        if past_key_value is not None:
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+        past_key_value = (key_states, value_states) if use_cache else None
+
+        # Repeat k/v heads if n_kv_heads < n_heads
+        key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
+        value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)
+
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+        if attention_mask is not None:
+            attn_weights = attn_weights + attention_mask
+
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+class DeepSeekMathMLP(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        self.act_fn = nn.SiLU()
+
+    def forward(self, x):
+        # SwiGLU activation function
+        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+class DeepSeekMathDecoderLayer(nn.Module):
+    def __init__(self, config: DeepSeekMathConfig):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.self_attn = DeepSeekMathAttention(config=config)
+        self.mlp = DeepSeekMathMLP(config)
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        residual = hidden_states
+
+        hidden_states = self.input_layernorm(hidden_states)
+
+        # Self Attention
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+        )
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        return outputs
+
+class DeepSeekMathModel(nn.Module):
+    def __init__(self, config: DeepSeekMathConfig):
+        super().__init__()
+        self.config = config
+        self.padding_idx = config.vocab_size - 1
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        self.layers = nn.ModuleList([DeepSeekMathDecoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def get_input_embeddings(self):
+        return self.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.embed_tokens = value
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+    ):
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+        # retrieve input_ids and inputs_embeds
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            batch_size, seq_length = input_ids.shape
+        elif inputs_embeds is not None:
+            batch_size, seq_length, _ = inputs_embeds.shape
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        seq_length_with_past = seq_length
+        past_key_values_length = 0
+
+        if past_key_values is not None:
+            past_key_values_length = past_key_values[0][0].shape[2]
+            seq_length_with_past = seq_length_with_past + past_key_values_length
+
+        if position_ids is None:
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            position_ids = torch.arange(
+                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+            )
+            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+        else:
+            position_ids = position_ids.view(-1, seq_length).long()
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        # Attention mask
+        if attention_mask is None:
+            attention_mask = torch.ones(
+                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
+            )
+        attention_mask = self._prepare_decoder_attention_mask(
+            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+        )
+
+        hidden_states = inputs_embeds
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        next_decoder_cache = () if use_cache else None
+
+        for idx, decoder_layer in enumerate(self.layers):
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+            )
+
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+        hidden_states = self.norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = next_decoder_cache if use_cache else None
+
+        return {
+            "last_hidden_state": hidden_states,
+            "past_key_values": next_cache,
+            "hidden_states": all_hidden_states,
+            "attentions": all_self_attns,
+        }
+
+    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
+        # Create causal mask
+        batch_size, seq_length = input_shape
+        causal_mask = torch.full((seq_length, seq_length), fill_value=float("-inf"), device=inputs_embeds.device)
+        causal_mask = torch.triu(causal_mask, diagonal=1)
+        causal_mask = causal_mask.to(inputs_embeds.dtype)
+
+        if past_key_values_length > 0:
+            causal_mask = torch.cat(
+                [torch.zeros(seq_length, past_key_values_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device),
+                 causal_mask], dim=-1
+            )
+
+        expanded_attn_mask = attention_mask[:, None, None, :].expand(batch_size, 1, seq_length, seq_length + past_key_values_length)
+        expanded_attn_mask = expanded_attn_mask.to(inputs_embeds.dtype)
+        expanded_attn_mask = (1.0 - expanded_attn_mask) * torch.finfo(inputs_embeds.dtype).min
+
+        return expanded_attn_mask + causal_mask
+
+class DeepSeekMathForCausalLM(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.model = DeepSeekMathModel(config)
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+    ):
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+        )
+
+        hidden_states = outputs["last_hidden_state"]
+        logits = self.lm_head(hidden_states)
+        logits = logits.float()
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = nn.CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, self.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
+
+        return {
+            "loss": loss,
+            "logits": logits,
+            "past_key_values": outputs["past_key_values"],
+            "hidden_states": outputs["hidden_states"],
+            "attentions": outputs["attentions"],
+        }
+
+
+# config = DeepSeekMathConfig()
+# model = DeepSeekMathForCausalLM(config)
+
+# # Print model info
+# total_params = sum(p.numel() for p in model.parameters())
+# trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+# print(f"Total parameters: {total_params:,}")
+# print(f"Trainable parameters: {trainable_params:,}")
+# print(f"Model size: ~{total_params * 4 / 1e9:.1f}B parameters")
\ No newline at end of file
diff --git a/train/train.py b/train/train.py
new file mode 100644
index 0000000..715fce7
--- /dev/null
+++ b/train/train.py
@@ -0,0 +1,537 @@
+import os
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.data import DataLoader, DistributedSampler
+from torch.utils.data.dataset import Dataset
+import torch.optim as optim
+from transformers import get_cosine_schedule_with_warmup
+import wandb
+import json
+import time
+import math
+import random
+import numpy as np
+from typing import Dict, List, Optional
+import argparse
+from pathlib import Path
+import tiktoken
+from model import DeepSeekMathConfig, DeepSeekMathForCausalLM
+
+class MathDataset(Dataset):
+    """Dataset for mathematical text data"""
+    
+    def __init__(self, data_path: str, tokenizer, max_length: int = 4096, split: str = "train"):
+        self.tokenizer = tokenizer
+        self.max_length = max_length
+        self.data = []
+        
+        # Load mathematical datasets
+        # This would include sources like:
+        # - Mathematical papers (arXiv)
+        # - Mathematical problem-solution pairs
+        # - Formal mathematical proofs
+        # - Mathematical textbooks
+        # - Code for mathematical computations
+        
+        if os.path.exists(data_path):
+            with open(data_path, 'r', encoding='utf-8') as f:
+                for line in f:
+                    try:
+                        sample = json.loads(line)
+                        if 'text' in sample:
+                            self.data.append(sample['text'])
+                    except json.JSONDecodeError:
+                        continue
+        else:
+            # Dummy data for demonstration
+            self.data = self._generate_dummy_math_data()
+    
+    def _generate_dummy_math_data(self):
+        """Generate dummy mathematical data for demonstration"""
+        dummy_data = [
+            "Theorem: For any real numbers a and b, (a + b)² = a² + 2ab + b². Proof: (a + b)² = (a + b)(a + b) = a(a + b) + b(a + b) = a² + ab + ba + b² = a² + 2ab + b².",
+            "Problem: Solve the equation 2x + 5 = 13. Solution: 2x + 5 = 13, 2x = 13 - 5, 2x = 8, x = 4.",
+            "Definition: A prime number is a natural number greater than 1 that has no positive divisors other than 1 and itself.",
+            "Lemma: If n is a composite number, then n has a prime divisor less than or equal to √n.",
+            "Calculate the derivative of f(x) = x³ + 2x² - 5x + 1. f'(x) = 3x² + 4x - 5.",
+        ] * 1000  # Repeat for more data
+        return dummy_data
+    
+    def __len__(self):
+        return len(self.data)
+    
+    def __getitem__(self, idx):
+        text = self.data[idx]
+        
+        # Tokenize the text
+        tokens = self.tokenizer.encode(text)
+        
+        # Truncate or pad to max_length
+        if len(tokens) > self.max_length:
+            tokens = tokens[:self.max_length]
+        else:
+            # Pad with tokenizer's pad token if available, otherwise use 0
+            pad_token = getattr(self.tokenizer, 'pad_token_id', 0)
+            tokens.extend([pad_token] * (self.max_length - len(tokens)))
+        
+        return {
+            'input_ids': torch.tensor(tokens, dtype=torch.long),
+            'labels': torch.tensor(tokens, dtype=torch.long)
+        }
+
+class DeepSeekMathTrainer:
+    def __init__(self, config: Dict):
+        self.config = config
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.setup_distributed_training()
+        self.setup_model()
+        self.setup_tokenizer()
+        self.setup_datasets()
+        self.setup_optimizer()
+        self.setup_logging()
+        
+    def setup_distributed_training(self):
+        """Setup distributed training if available"""
+        self.is_distributed = False
+        if 'WORLD_SIZE' in os.environ:
+            self.is_distributed = True
+            self.world_size = int(os.environ['WORLD_SIZE'])
+            self.rank = int(os.environ['RANK'])
+            self.local_rank = int(os.environ['LOCAL_RANK'])
+            
+            # Initialize distributed training
+            dist.init_process_group(backend='nccl')
+            torch.cuda.set_device(self.local_rank)
+            self.device = torch.device(f'cuda:{self.local_rank}')
+        else:
+            self.world_size = 1
+            self.rank = 0
+            self.local_rank = 0
+    
+    def setup_model(self):
+        """Initialize the DeepSeek-Math model"""
+        model_config = DeepSeekMathConfig(
+            vocab_size=self.config['vocab_size'],
+            hidden_size=self.config['hidden_size'],
+            intermediate_size=self.config['intermediate_size'],
+            num_hidden_layers=self.config['num_hidden_layers'],
+            num_attention_heads=self.config['num_attention_heads'],
+            max_position_embeddings=self.config['max_position_embeddings'],
+            rms_norm_eps=self.config['rms_norm_eps'],
+            rope_theta=self.config['rope_theta'],
+        )
+        
+        self.model = DeepSeekMathForCausalLM(model_config)
+        self.model = self.model.to(self.device)
+        
+        if self.is_distributed:
+            self.model = DDP(self.model, device_ids=[self.local_rank])
+        
+        # Print model statistics
+        if self.rank == 0:
+            total_params = sum(p.numel() for p in self.model.parameters())
+            trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
+            print(f"Total parameters: {total_params:,}")
+            print(f"Trainable parameters: {trainable_params:,}")
+    
+    def setup_tokenizer(self):
+        """Setup tokenizer (using tiktoken for GPT-4 tokenizer)"""
+        try:
+            self.tokenizer = tiktoken.get_encoding("cl100k_base")
+        except:
+            # Fallback to a simple character-level tokenizer
+            self.tokenizer = SimpleTokenizer(vocab_size=self.config['vocab_size'])
+    
+    def setup_datasets(self):
+        """Setup training and validation datasets"""
+        train_dataset = MathDataset(
+            data_path=self.config['train_data_path'],
+            tokenizer=self.tokenizer,
+            max_length=self.config['max_length'],
+            split='train'
+        )
+        
+        val_dataset = MathDataset(
+            data_path=self.config['val_data_path'],
+            tokenizer=self.tokenizer,
+            max_length=self.config['max_length'],
+            split='val'
+        )
+        
+        # Setup distributed samplers
+        train_sampler = DistributedSampler(train_dataset) if self.is_distributed else None
+        val_sampler = DistributedSampler(val_dataset) if self.is_distributed else None
+        
+        self.train_dataloader = DataLoader(
+            train_dataset,
+            batch_size=self.config['batch_size'],
+            sampler=train_sampler,
+            shuffle=(train_sampler is None),
+            num_workers=self.config['num_workers'],
+            pin_memory=True
+        )
+        
+        self.val_dataloader = DataLoader(
+            val_dataset,
+            batch_size=self.config['batch_size'],
+            sampler=val_sampler,
+            shuffle=False,
+            num_workers=self.config['num_workers'],
+            pin_memory=True
+        )
+    
+    def setup_optimizer(self):
+        """Setup optimizer and learning rate scheduler"""
+        # AdamW optimizer with weight decay
+        no_decay = ['bias', 'LayerNorm.weight', 'layernorm.weight']
+        optimizer_grouped_parameters = [
+            {
+                'params': [p for n, p in self.model.named_parameters() 
+                          if not any(nd in n for nd in no_decay)],
+                'weight_decay': self.config['weight_decay']
+            },
+            {
+                'params': [p for n, p in self.model.named_parameters() 
+                          if any(nd in n for nd in no_decay)],
+                'weight_decay': 0.0
+            }
+        ]
+        
+        self.optimizer = optim.AdamW(
+            optimizer_grouped_parameters,
+            lr=self.config['learning_rate'],
+            betas=(self.config['adam_beta1'], self.config['adam_beta2']),
+            eps=self.config['adam_epsilon']
+        )
+        
+        # Calculate total training steps
+        self.total_steps = len(self.train_dataloader) * self.config['num_epochs']
+        self.warmup_steps = int(self.total_steps * self.config['warmup_ratio'])
+        
+        # Cosine learning rate scheduler with warmup
+        self.scheduler = get_cosine_schedule_with_warmup(
+            self.optimizer,
+            num_warmup_steps=self.warmup_steps,
+            num_training_steps=self.total_steps
+        )
+    
+    def setup_logging(self):
+        """Setup logging with wandb"""
+        if self.rank == 0:
+            wandb.init(
+                project=self.config['project_name'],
+                name=self.config['run_name'],
+                config=self.config
+            )
+    
+    def save_checkpoint(self, epoch: int, step: int, loss: float):
+        """Save model checkpoint"""
+        if self.rank == 0:
+            checkpoint_dir = Path(self.config['checkpoint_dir'])
+            checkpoint_dir.mkdir(parents=True, exist_ok=True)
+            
+            model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
+            
+            checkpoint = {
+                'epoch': epoch,
+                'step': step,
+                'model_state_dict': model_to_save.state_dict(),
+                'optimizer_state_dict': self.optimizer.state_dict(),
+                'scheduler_state_dict': self.scheduler.state_dict(),
+                'loss': loss,
+                'config': self.config
+            }
+            
+            checkpoint_path = checkpoint_dir / f'checkpoint_epoch_{epoch}_step_{step}.pt'
+            torch.save(checkpoint, checkpoint_path)
+            
+            # Keep only the last N checkpoints
+            checkpoints = sorted(checkpoint_dir.glob('checkpoint_*.pt'))
+            if len(checkpoints) > self.config['max_checkpoints']:
+                for old_checkpoint in checkpoints[:-self.config['max_checkpoints']]:
+                    old_checkpoint.unlink()
+            
+            print(f"Checkpoint saved: {checkpoint_path}")
+    
+    def load_checkpoint(self, checkpoint_path: str):
+        """Load model checkpoint"""
+        checkpoint = torch.load(checkpoint_path, map_location=self.device)
+        
+        model_to_load = self.model.module if hasattr(self.model, 'module') else self.model
+        model_to_load.load_state_dict(checkpoint['model_state_dict'])
+        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
+        
+        return checkpoint['epoch'], checkpoint['step'], checkpoint['loss']
+    
+    def compute_loss(self, batch):
+        """Compute the loss for a batch"""
+        input_ids = batch['input_ids'].to(self.device)
+        labels = batch['labels'].to(self.device)
+        
+        outputs = self.model(input_ids=input_ids, labels=labels)
+        return outputs['loss']
+    
+    def train_epoch(self, epoch: int):
+        """Train for one epoch"""
+        self.model.train()
+        total_loss = 0
+        num_batches = len(self.train_dataloader)
+        
+        if self.is_distributed:
+            self.train_dataloader.sampler.set_epoch(epoch)
+        
+        for step, batch in enumerate(self.train_dataloader):
+            # Forward pass
+            loss = self.compute_loss(batch)
+            
+            # Backward pass
+            loss.backward()
+            
+            # Gradient clipping
+            if self.config['max_grad_norm'] > 0:
+                torch.nn.utils.clip_grad_norm_(
+                    self.model.parameters(), 
+                    self.config['max_grad_norm']
+                )
+            
+            # Optimizer step
+            self.optimizer.step()
+            self.scheduler.step()
+            self.optimizer.zero_grad()
+            
+            total_loss += loss.item()
+            
+            # Logging
+            if step % self.config['log_interval'] == 0 and self.rank == 0:
+                avg_loss = total_loss / (step + 1)
+                lr = self.scheduler.get_last_lr()[0]
+                
+                print(f"Epoch {epoch}, Step {step}/{num_batches}, "
+                      f"Loss: {loss.item():.4f}, Avg Loss: {avg_loss:.4f}, "
+                      f"LR: {lr:.2e}")
+                
+                wandb.log({
+                    'train/loss': loss.item(),
+                    'train/avg_loss': avg_loss,
+                    'train/learning_rate': lr,
+                    'train/epoch': epoch,
+                    'train/step': step
+                })
+            
+            # Save checkpoint
+            if step % self.config['save_interval'] == 0 and step > 0:
+                self.save_checkpoint(epoch, step, loss.item())
+        
+        return total_loss / num_batches
+    
+    def validate(self, epoch: int):
+        """Validate the model"""
+        self.model.eval()
+        total_loss = 0
+        num_batches = len(self.val_dataloader)
+        
+        with torch.no_grad():
+            for step, batch in enumerate(self.val_dataloader):
+                loss = self.compute_loss(batch)
+                total_loss += loss.item()
+        
+        avg_loss = total_loss / num_batches
+        perplexity = math.exp(avg_loss)
+        
+        if self.rank == 0:
+            print(f"Validation - Epoch {epoch}, Loss: {avg_loss:.4f}, "
+                  f"Perplexity: {perplexity:.2f}")
+            
+            wandb.log({
+                'val/loss': avg_loss,
+                'val/perplexity': perplexity,
+                'val/epoch': epoch
+            })
+        
+        return avg_loss
+    
+    def train(self):
+        """Main training loop"""
+        print("Starting training...")
+        
+        best_val_loss = float('inf')
+        
+        for epoch in range(self.config['num_epochs']):
+            print(f"\n=== Epoch {epoch + 1}/{self.config['num_epochs']} ===")
+            
+            # Training
+            train_loss = self.train_epoch(epoch)
+            
+            # Validation
+            val_loss = self.validate(epoch)
+            
+            # Save best model
+            if val_loss < best_val_loss and self.rank == 0:
+                best_val_loss = val_loss
+                best_model_path = Path(self.config['checkpoint_dir']) / 'best_model.pt'
+                model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
+                torch.save(model_to_save.state_dict(), best_model_path)
+                print(f"Best model saved with validation loss: {best_val_loss:.4f}")
+            
+            # Save epoch checkpoint
+            self.save_checkpoint(epoch, len(self.train_dataloader), train_loss)
+        
+        if self.rank == 0:
+            print("Training completed!")
+            wandb.finish()
+
+class SimpleTokenizer:
+    """Simple character-level tokenizer as fallback"""
+    def __init__(self, vocab_size: int = 50000):
+        self.vocab_size = vocab_size
+        self.pad_token_id = 0
+        
+    def encode(self, text: str) -> List[int]:
+        # Simple character-level encoding
+        return [min(ord(c), self.vocab_size - 1) for c in text]
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='DeepSeek-Math Pretraining')
+    
+    # Model configuration
+    parser.add_argument('--vocab_size', type=int, default=102400)
+    parser.add_argument('--hidden_size', type=int, default=4096)
+    parser.add_argument('--intermediate_size', type=int, default=11008)
+    parser.add_argument('--num_hidden_layers', type=int, default=30)
+    parser.add_argument('--num_attention_heads', type=int, default=32)
+    parser.add_argument('--max_position_embeddings', type=int, default=4096)
+    parser.add_argument('--max_length', type=int, default=2048)
+    
+    # Training configuration
+    parser.add_argument('--batch_size', type=int, default=8)
+    parser.add_argument('--num_epochs', type=int, default=3)
+    parser.add_argument('--learning_rate', type=float, default=1e-4)
+    parser.add_argument('--weight_decay', type=float, default=0.1)
+    parser.add_argument('--adam_beta1', type=float, default=0.9)
+    parser.add_argument('--adam_beta2', type=float, default=0.95)
+    parser.add_argument('--adam_epsilon', type=float, default=1e-8)
+    parser.add_argument('--warmup_ratio', type=float, default=0.03)
+    parser.add_argument('--max_grad_norm', type=float, default=1.0)
+    
+    # Data paths
+    parser.add_argument('--train_data_path', type=str, default='data/train.jsonl')
+    parser.add_argument('--val_data_path', type=str, default='data/val.jsonl')
+    
+    # Logging and checkpointing
+    parser.add_argument('--checkpoint_dir', type=str, default='checkpoints')
+    parser.add_argument('--log_interval', type=int, default=100)
+    parser.add_argument('--save_interval', type=int, default=1000)
+    parser.add_argument('--max_checkpoints', type=int, default=5)
+    parser.add_argument('--project_name', type=str, default='deepseek-math')
+    parser.add_argument('--run_name', type=str, default='pretraining')
+    
+    # System configuration
+    parser.add_argument('--num_workers', type=int, default=4)
+    parser.add_argument('--rms_norm_eps', type=float, default=1e-6)
+    parser.add_argument('--rope_theta', type=float, default=10000.0)
+    
+    return parser.parse_args()
+
+def set_seed(seed: int = 42):
+    """Set random seed for reproducibility"""
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
+def main():
+    args = parse_args()
+    set_seed(42)
+    
+    # Convert args to config dict
+    config = vars(args)
+    
+    # Initialize trainer
+    trainer = DeepSeekMathTrainer(config)
+    
+    # Start training
+    trainer.train()
+
+if __name__ == "__main__":
+    main()
+
+# Example usage with distributed training:
+# torchrun --nproc_per_node=8 --nnodes=1 deepseek_math_training.py \
+#     --batch_size 4 \
+#     --learning_rate 1e-4 \
+#     --num_epochs 3 \
+#     --train_data_path /path/to/train.jsonl \
+#     --val_data_path /path/to/val.jsonl \
+#     --checkpoint_dir ./checkpoints \
+#     --project_name deepseek-math-7b \
+#     --run_name pretraining_run_1
+
+# Data preprocessing script for mathematical datasets
+def preprocess_math_data():
+    """
+    Example preprocessing script for mathematical datasets
+    This would typically process:
+    - ArXiv papers in mathematics
+    - Mathematical problem-solution pairs
+    - Formal proofs
+    - Mathematical textbooks
+    - Code for mathematical computations
+    """
+    
+    import re
+    from pathlib import Path
+    
+    def clean_math_text(text: str) -> str:
+        """Clean and normalize mathematical text"""
+        # Remove excessive whitespace
+        text = re.sub(r'\s+', ' ', text)
+        
+        # Normalize mathematical notation
+        text = re.sub(r'\\frac\{([^}]+)\}\{([^}]+)\}', r'(\1)/(\2)', text)
+        text = re.sub(r'\\sqrt\{([^}]+)\}', r'sqrt(\1)', text)
+        
+        # Clean up common LaTeX commands
+        text = re.sub(r'\\[a-zA-Z]+\{([^}]*)\}', r'\1', text)
+        text = re.sub(r'\\[a-zA-Z]+', '', text)
+        
+        return text.strip()
+    
+    def process_file(input_path: str, output_path: str):
+        """Process a single file and save cleaned data"""
+        with open(input_path, 'r', encoding='utf-8') as f:
+            content = f.read()
+        
+        # Split into chunks (could be paragraphs, sections, etc.)
+        chunks = content.split('\n\n')
+        
+        processed_data = []
+        for chunk in chunks:
+            if len(chunk.strip()) > 50:  # Filter out very short chunks
+                cleaned = clean_math_text(chunk)
+                if cleaned:
+                    processed_data.append({'text': cleaned})
+        
+        # Save as JSONL
+        with open(output_path, 'w', encoding='utf-8') as f:
+            for item in processed_data:
+                f.write(json.dumps(item) + '\n')
+    
+    # Example usage
+    input_dir = Path('raw_data')
+    output_dir = Path('processed_data')
+    output_dir.mkdir(exist_ok=True)
+    
+    for file_path in input_dir.glob('*.txt'):
+        output_path = output_dir / f"{file_path.stem}.jsonl"
+        process_file(str(file_path), str(output_path))
+        print(f"Processed {file_path} -> {output_path}")
+
+if __name__ == "__main__":
+    # Uncomment to run data preprocessing
+    # preprocess_math_data()
+    
+    # Run main training
+    main()
\ No newline at end of file