mirror of
https://github.com/deepseek-ai/DeepSeek-Math.git
synced 2025-06-20 16:33:47 -04:00
Implement DeepSeek-Math model and training pipeline with dataset handling and distributed training support
This commit is contained in:
parent
b8b0f8ce09
commit
b53e984052
428
train/model.py
Normal file
428
train/model.py
Normal file
@ -0,0 +1,428 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import math
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
@dataclass
class DeepSeekMathConfig:
    """Hyper-parameters for the DeepSeek-Math decoder-only transformer.

    All fields have defaults, so ``DeepSeekMathConfig()`` is a valid
    configuration. ``output_attentions`` and ``output_hidden_states`` are the
    fallbacks ``DeepSeekMathModel.forward`` reads when the matching call
    argument is left as ``None``.
    """

    vocab_size: int = 102400
    hidden_size: int = 4096
    intermediate_size: int = 11008
    num_hidden_layers: int = 30
    num_attention_heads: int = 32
    num_key_value_heads: int = 32  # For grouped query attention
    max_position_embeddings: int = 4096
    rms_norm_eps: float = 1e-6
    rope_theta: float = 10000.0
    attention_dropout: float = 0.0
    hidden_dropout: float = 0.0
    use_cache: bool = True
    rope_scaling: Optional[dict] = None
    tie_word_embeddings: bool = False
    # Appended last so existing positional construction keeps working.
    output_attentions: bool = False
    output_hidden_states: bool = False
|
||||||
|
|
||||||
|
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learnable gain."""

    def __init__(self, hidden_size, eps=1e-6):
        """Create a gain vector of ``hidden_size`` ones; ``eps`` is added to the mean square."""
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalise ``hidden_states`` in float32, cast back, then apply the gain."""
        original_dtype = hidden_states.dtype
        # Statistics are taken in float32 regardless of the incoming dtype.
        upcast = hidden_states.to(torch.float32)
        mean_square = torch.pow(upcast, 2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normalised = (upcast * inv_rms).to(original_dtype)
        return self.weight * normalised
|
||||||
|
|
||||||
|
class RotaryPositionalEmbedding(nn.Module):
    """Produces the cos/sin tables used by rotary position embeddings.

    The rotation angles are always computed in float32. If the module has been
    cast to a reduced-precision dtype (``.half()`` / ``.to(torch.bfloat16)``),
    the ``inv_freq`` buffer is cast too; building position indices in that
    dtype rounds large positions, so the frequencies are recomputed in float32
    whenever the buffer is not float32 or lives on another device than ``x``.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        """Store the settings and register the non-persistent ``inv_freq`` buffer.

        Args:
            dim: size of the rotated feature axis; ``dim // 2`` (rounded up)
                frequencies are generated.
            max_position_embeddings: stored on the module; not used to size
                any table here.
            base: base of the geometric frequency progression.
            device: device on which to create the buffer.
        """
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.register_buffer("inv_freq", self._compute_inv_freq(device), persistent=False)

    def _compute_inv_freq(self, device=None):
        """Return ``1 / base ** (2i / dim)`` as a float32 tensor on ``device``."""
        exponents = torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim
        return 1.0 / (self.base ** exponents)

    def forward(self, x, seq_len=None):
        """Return ``(cos, sin)`` tables of shape ``(seq_len, 2 * len(inv_freq))``.

        Args:
            x: tensor supplying the output dtype and device; when ``seq_len``
                is ``None`` its second-to-last axis gives the sequence length.
            seq_len: number of positions to generate, starting at 0.

        Returns:
            Two tensors in ``x.dtype``: the cosine and sine of the angles.
        """
        if seq_len is None:
            seq_len = x.shape[-2]
        inv_freq = self.inv_freq
        if inv_freq.dtype != torch.float32 or inv_freq.device != x.device:
            # Buffer was cast or moved with the module; rebuild it exactly.
            inv_freq = self._compute_inv_freq(x.device)
        t = torch.arange(seq_len, device=x.device, dtype=torch.float32)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        # Each frequency appears twice so the table lines up with rotate_half.
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype)
|
||||||
|
|
||||||
|
def rotate_half(x):
    """Return ``x`` with the halves of its last axis swapped and the moved-forward half negated.

    The split point is ``x.shape[-1] // 2``; for an odd-sized axis the second
    part is the longer one.
    """
    split_at = x.shape[-1] // 2
    leading = x[..., :split_at]
    trailing = x[..., split_at:]
    return torch.cat((trailing.neg(), leading), dim=-1)
|
||||||
|
|
||||||
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    """Rotate query and key tensors by the angles selected with ``position_ids``.

    ``cos`` and ``sin`` are indexed by ``position_ids`` along their first axis
    and given an extra axis at dim 1 before being broadcast against ``q`` and
    ``k`` (presumably the head axis -- confirm against the caller).

    Returns:
        ``(q_embed, k_embed)`` with the same shapes as ``q`` and ``k``.
    """
    cos_at_pos = cos[position_ids].unsqueeze(1)
    sin_at_pos = sin[position_ids].unsqueeze(1)

    def _rotate(states):
        # Same formula for queries and keys: x*cos + rotate_half(x)*sin.
        return (states * cos_at_pos) + (rotate_half(states) * sin_at_pos)

    return _rotate(q), _rotate(k)
|
||||||
|
|
||||||
|
class DeepSeekMathAttention(nn.Module):
    """Multi-head self-attention with rotary embeddings and optional grouped-query K/V heads.

    Queries use ``num_attention_heads`` heads; keys and values use
    ``num_key_value_heads`` heads, each repeated ``num_key_value_groups`` times
    so every query head has a matching key/value head.
    """

    def __init__(self, config: DeepSeekMathConfig):
        """Build the projections and the rotary-embedding helper.

        Raises:
            ValueError: if ``hidden_size`` is not divisible by the number of
                heads, or the number of heads is not a positive multiple of
                ``num_key_value_heads``.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        # getattr keeps duck-typed configs without this field working.
        self.attention_dropout = getattr(config, "attention_dropout", 0.0)

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        # A group count of 0 (more KV heads than query heads) would make
        # repeat_interleave return empty tensors instead of failing loudly.
        if self.num_key_value_heads <= 0 or self.num_heads % self.num_key_value_heads != 0:
            raise ValueError(
                f"num_heads must be a positive multiple of num_key_value_heads (got `num_heads`: "
                f"{self.num_heads} and `num_key_value_heads`: {self.num_key_value_heads})."
            )
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.rotary_emb = RotaryPositionalEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape ``(bsz, seq_len, hidden)`` to a contiguous ``(bsz, num_heads, seq_len, head_dim)``."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Run self-attention over ``hidden_states``.

        Args:
            hidden_states: ``(batch, q_len, hidden_size)`` input.
            attention_mask: additive mask added to the raw scores; must
                broadcast to ``(batch, num_heads, q_len, kv_len)``.
            position_ids: positions of the new tokens. When ``None`` they
                default to ``past_len .. past_len + q_len - 1``.
            past_key_value: cached ``(key, value)`` from earlier steps, each
                ``(batch, num_key_value_heads, past_len, head_dim)``.
            output_attentions: return the attention probabilities when True.
            use_cache: return the updated ``(key, value)`` cache when True.

        Returns:
            ``(attn_output, attn_weights or None, cache or None)``.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        past_len = past_key_value[0].shape[-2] if past_key_value is not None else 0
        kv_seq_len = q_len + past_len

        if position_ids is None:
            # Without this, a cached decode step would index the rotary table
            # with None and produce tables of the wrong length.
            position_ids = torch.arange(
                past_len, kv_seq_len, dtype=torch.long, device=hidden_states.device
            ).unsqueeze(0)

        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            # Cached keys were already rotated when they were first computed.
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        # The cache stores the un-repeated K/V heads.
        past_key_value = (key_states, value_states) if use_cache else None

        # Repeat k/v heads if n_kv_heads < n_heads
        if self.num_key_value_groups > 1:
            key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
            value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        # Softmax in float32 for stability, then back to the working dtype.
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        # Honour config.attention_dropout; identity when p == 0 or in eval mode.
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
|
||||||
|
|
||||||
|
class DeepSeekMathMLP(nn.Module):
    """Gated feed-forward block: ``down_proj(SiLU(gate_proj(x)) * up_proj(x))``."""

    def __init__(self, config):
        """Create the three bias-free projections from ``config.hidden_size`` / ``config.intermediate_size``."""
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        model_width, inner_width = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(model_width, inner_width, bias=False)
        self.up_proj = nn.Linear(model_width, inner_width, bias=False)
        self.down_proj = nn.Linear(inner_width, model_width, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        """Apply the SwiGLU-style gated transformation to ``x``."""
        gate = self.act_fn(self.gate_proj(x))
        lifted = self.up_proj(x)
        return self.down_proj(gate * lifted)
|
||||||
|
|
||||||
|
class DeepSeekMathDecoderLayer(nn.Module):
    """One pre-norm transformer block: attention and MLP, each wrapped in a residual connection."""

    def __init__(self, config: DeepSeekMathConfig):
        """Instantiate the attention, MLP and the two RMSNorm layers from ``config``."""
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = DeepSeekMathAttention(config=config)
        self.mlp = DeepSeekMathMLP(config)
        norm_eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Apply the block to ``hidden_states``.

        Returns:
            A tuple whose first item is the new hidden states, followed by the
            attention weights if ``output_attentions`` and then the key/value
            cache if ``use_cache``.
        """
        # Self Attention (normalise first, then add the result to the input)
        attn_out, attn_probs, kv_cache = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        after_attention = hidden_states + attn_out

        # Fully Connected
        mlp_out = self.mlp(self.post_attention_layernorm(after_attention))
        block_out = after_attention + mlp_out

        extras = ()
        if output_attentions:
            extras += (attn_probs,)
        if use_cache:
            extras += (kv_cache,)

        return (block_out,) + extras
|
||||||
|
|
||||||
|
class DeepSeekMathModel(nn.Module):
    """Decoder stack: token embedding, ``num_hidden_layers`` decoder layers and a final RMSNorm."""

    def __init__(self, config: DeepSeekMathConfig):
        """Build the embedding table, the decoder layers and the output norm.

        The embedding's padding index is ``vocab_size - 1``.
        """
        super().__init__()
        self.config = config
        self.padding_idx = config.vocab_size - 1
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([DeepSeekMathDecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        # Annotated with Tuple (already imported) rather than List: `List` is
        # not imported by this file, and the cache produced below is in fact a
        # tuple of (key, value) pairs, one per layer.
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ):
        """Run the decoder stack.

        Exactly one of ``input_ids`` (``(batch, seq)``) or ``inputs_embeds``
        (``(batch, seq, hidden)``) must be given.

        Args:
            attention_mask: ``(batch, past_len + seq)`` mask, 1/True for
                positions to attend to. Defaults to all ones.
            position_ids: positions of the new tokens; defaults to
                ``past_len .. past_len + seq - 1``.
            past_key_values: per-layer ``(key, value)`` cache from an earlier call.
            use_cache / output_attentions / output_hidden_states: when ``None``
                the value falls back to the config (``False`` if the config
                has no such attribute).

        Returns:
            dict with keys ``last_hidden_state``, ``past_key_values``,
            ``hidden_states`` and ``attentions`` (``None`` when not requested).

        Raises:
            ValueError: if both or neither of ``input_ids`` / ``inputs_embeds``
                are supplied.
        """
        # getattr fallbacks: the config dataclass may not define these flags.
        if output_attentions is None:
            output_attentions = getattr(self.config, "output_attentions", False)
        if output_hidden_states is None:
            output_hidden_states = getattr(self.config, "output_hidden_states", False)
        if use_cache is None:
            use_cache = getattr(self.config, "use_cache", False)

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            # Keys are (batch, kv_heads, past_len, head_dim); axis 2 is the cached length.
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Attention mask
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )

        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

            hidden_states = layer_outputs[0]

            if use_cache:
                # The cache follows the attention weights when those are returned.
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        return {
            "last_hidden_state": hidden_states,
            "past_key_values": next_cache,
            "hidden_states": all_hidden_states,
            "attentions": all_self_attns,
        }

    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
        """Combine the padding mask with a causal mask into one additive mask.

        Args:
            attention_mask: ``(batch, past_len + seq)``, non-zero where attention is allowed.
            input_shape: ``(batch, seq)`` of the new tokens.
            inputs_embeds: supplies the dtype and device of the result.
            past_key_values_length: number of cached positions; these are
                visible to every new token.

        Returns:
            ``(batch, 1, seq, past_len + seq)`` tensor: 0 where attention is
            allowed, a large negative value or ``-inf`` elsewhere.
        """
        # Create causal mask
        batch_size, seq_length = input_shape
        causal_mask = torch.full((seq_length, seq_length), fill_value=float("-inf"), device=inputs_embeds.device)
        # triu(diagonal=1) keeps -inf strictly above the diagonal and zeroes the rest.
        causal_mask = torch.triu(causal_mask, diagonal=1)
        causal_mask = causal_mask.to(inputs_embeds.dtype)

        if past_key_values_length > 0:
            causal_mask = torch.cat(
                [torch.zeros(seq_length, past_key_values_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device),
                 causal_mask], dim=-1
            )

        expanded_attn_mask = attention_mask[:, None, None, :].expand(batch_size, 1, seq_length, seq_length + past_key_values_length)
        expanded_attn_mask = expanded_attn_mask.to(inputs_embeds.dtype)
        # 1 -> 0 (keep), 0 -> most negative finite value (mask out).
        expanded_attn_mask = (1.0 - expanded_attn_mask) * torch.finfo(inputs_embeds.dtype).min

        return expanded_attn_mask + causal_mask
|
||||||
|
|
||||||
|
class DeepSeekMathForCausalLM(nn.Module):
    """``DeepSeekMathModel`` plus a bias-free linear head that maps hidden states to vocabulary logits."""

    def __init__(self, config):
        """Build the decoder and the LM head.

        When ``config.tie_word_embeddings`` is true the head shares its weight
        with the input embedding table.
        """
        super().__init__()
        self.model = DeepSeekMathModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        if getattr(config, "tie_word_embeddings", False):
            # Both weights are (vocab_size, hidden_size), so they can be shared directly.
            self.lm_head.weight = self.model.embed_tokens.weight

    def get_input_embeddings(self):
        """Return the token embedding module of the wrapped model."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module of the wrapped model."""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """Return the LM head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head."""
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        # Tuple (imported) instead of the un-imported `List`; the cache is a
        # tuple of per-layer (key, value) pairs.
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
    ):
        """Compute logits and, when ``labels`` is given, the next-token loss.

        ``labels`` has the same shape as ``input_ids``; it is shifted
        internally so position ``t`` predicts token ``t + 1``. Positions whose
        label equals ``nn.CrossEntropyLoss``'s default ``ignore_index`` (-100)
        do not contribute to the loss.

        Returns:
            dict with ``loss`` (``None`` without labels), float32 ``logits``,
            and the ``past_key_values`` / ``hidden_states`` / ``attentions``
            entries forwarded from the wrapped model.
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        hidden_states = outputs["last_hidden_state"]
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            # Use the real logit width so a replaced LM head with a different
            # vocabulary size still reshapes correctly.
            shift_logits = shift_logits.view(-1, shift_logits.size(-1))
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        return {
            "loss": loss,
            "logits": logits,
            "past_key_values": outputs["past_key_values"],
            "hidden_states": outputs["hidden_states"],
            "attentions": outputs["attentions"],
        }
|
||||||
|
|
||||||
|
|
||||||
|
# config = DeepSeekMathConfig()
|
||||||
|
# model = DeepSeekMathForCausalLM(config)
|
||||||
|
|
||||||
|
# # Print model info
|
||||||
|
# total_params = sum(p.numel() for p in model.parameters())
|
||||||
|
# trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||||
|
|
||||||
|
# print(f"Total parameters: {total_params:,}")
|
||||||
|
# print(f"Trainable parameters: {trainable_params:,}")
|
||||||
|
# print(f"Model size: ~{total_params * 4 / 1e9:.1f} GB (assuming 4 bytes per fp32 parameter)")
|
537
train/train.py
Normal file
537
train/train.py
Normal file
@ -0,0 +1,537 @@
|
|||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.distributed as dist
|
||||||
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
from torch.utils.data import DataLoader, DistributedSampler
|
||||||
|
from torch.utils.data.dataset import Dataset
|
||||||
|
import torch.optim as optim
|
||||||
|
from transformers import get_cosine_schedule_with_warmup
|
||||||
|
import wandb
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import math
|
||||||
|
import random
|
||||||
|
import numpy as np
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
import tiktoken
|
||||||
|
from model import DeepSeekMathConfig, DeepSeekMathForCausalLM
|
||||||
|
|
||||||
|
class MathDataset(Dataset):
    """Dataset for mathematical text data.

    Each item is one text sample tokenised and truncated or right-padded to
    ``max_length``. Padded positions carry ``IGNORE_INDEX`` in ``labels`` so
    that a loss using ``ignore_index=-100`` (the ``nn.CrossEntropyLoss``
    default) does not train on padding.
    """

    # Label value skipped by nn.CrossEntropyLoss by default.
    IGNORE_INDEX = -100

    def __init__(self, data_path: str, tokenizer, max_length: int = 4096, split: str = "train"):
        """Load samples from a JSON-lines file, or fall back to built-in dummy data.

        Args:
            data_path: path to a file with one JSON object per line; only
                objects that have a ``text`` key are kept. Lines that are not
                valid JSON are skipped. If the path does not exist, dummy
                samples are generated instead.
            tokenizer: object with an ``encode(text)`` method returning token
                ids; an optional ``pad_token_id`` attribute supplies the pad id.
            max_length: fixed length of every returned sequence.
            split: accepted for interface compatibility; not used here.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = []

        # Load mathematical datasets
        # This would include sources like:
        # - Mathematical papers (arXiv)
        # - Mathematical problem-solution pairs
        # - Formal mathematical proofs
        # - Mathematical textbooks
        # - Code for mathematical computations

        if os.path.exists(data_path):
            with open(data_path, 'r', encoding='utf-8') as f:
                for line in f:
                    try:
                        sample = json.loads(line)
                    except json.JSONDecodeError:
                        continue
                    # A bare JSON string or list would pass `'text' in sample`
                    # and then fail (or misbehave) on sample['text'].
                    if isinstance(sample, dict) and 'text' in sample:
                        self.data.append(sample['text'])
        else:
            # Dummy data for demonstration
            self.data = self._generate_dummy_math_data()

    def _generate_dummy_math_data(self):
        """Generate dummy mathematical data for demonstration"""
        dummy_data = [
            "Theorem: For any real numbers a and b, (a + b)² = a² + 2ab + b². Proof: (a + b)² = (a + b)(a + b) = a(a + b) + b(a + b) = a² + ab + ba + b² = a² + 2ab + b².",
            "Problem: Solve the equation 2x + 5 = 13. Solution: 2x + 5 = 13, 2x = 13 - 5, 2x = 8, x = 4.",
            "Definition: A prime number is a natural number greater than 1 that has no positive divisors other than 1 and itself.",
            "Lemma: If n is a composite number, then n has a prime divisor less than or equal to √n.",
            "Calculate the derivative of f(x) = x³ + 2x² - 5x + 1. f'(x) = 3x² + 4x - 5.",
        ] * 1000  # Repeat for more data
        return dummy_data

    def __len__(self):
        """Return the number of text samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``{'input_ids', 'labels'}`` LongTensors of length ``max_length``.

        ``input_ids`` is padded with the tokenizer's ``pad_token_id`` (0 when
        the attribute is missing or ``None``); ``labels`` equals ``input_ids``
        on real tokens and ``IGNORE_INDEX`` on padding.
        """
        text = self.data[idx]

        # Tokenize the text and truncate to max_length
        tokens = list(self.tokenizer.encode(text))[:self.max_length]
        num_padding = self.max_length - len(tokens)

        # Pad with tokenizer's pad token if available, otherwise use 0
        pad_token = getattr(self.tokenizer, 'pad_token_id', None)
        if pad_token is None:
            pad_token = 0

        input_ids = tokens + [pad_token] * num_padding
        labels = tokens + [self.IGNORE_INDEX] * num_padding

        return {
            'input_ids': torch.tensor(input_ids, dtype=torch.long),
            'labels': torch.tensor(labels, dtype=torch.long)
        }
|
||||||
|
|
||||||
|
class DeepSeekMathTrainer:
|
||||||
|
def __init__(self, config: Dict):
|
||||||
|
self.config = config
|
||||||
|
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
self.setup_distributed_training()
|
||||||
|
self.setup_model()
|
||||||
|
self.setup_tokenizer()
|
||||||
|
self.setup_datasets()
|
||||||
|
self.setup_optimizer()
|
||||||
|
self.setup_logging()
|
||||||
|
|
||||||
|
def setup_distributed_training(self):
|
||||||
|
"""Setup distributed training if available"""
|
||||||
|
self.is_distributed = False
|
||||||
|
if 'WORLD_SIZE' in os.environ:
|
||||||
|
self.is_distributed = True
|
||||||
|
self.world_size = int(os.environ['WORLD_SIZE'])
|
||||||
|
self.rank = int(os.environ['RANK'])
|
||||||
|
self.local_rank = int(os.environ['LOCAL_RANK'])
|
||||||
|
|
||||||
|
# Initialize distributed training
|
||||||
|
dist.init_process_group(backend='nccl')
|
||||||
|
torch.cuda.set_device(self.local_rank)
|
||||||
|
self.device = torch.device(f'cuda:{self.local_rank}')
|
||||||
|
else:
|
||||||
|
self.world_size = 1
|
||||||
|
self.rank = 0
|
||||||
|
self.local_rank = 0
|
||||||
|
|
||||||
|
def setup_model(self):
    """Initialize the DeepSeek-Math model.

    Builds the model config from ``self.config``, moves the model to
    ``self.device`` and wraps it in DDP in distributed mode.
    ``num_key_value_heads`` is read from the optional config key of the same
    name and otherwise defaults to ``num_attention_heads``; without this the
    model config's own default of 32 would be used no matter how many
    attention heads were requested.
    """
    num_attention_heads = self.config['num_attention_heads']
    model_config = DeepSeekMathConfig(
        vocab_size=self.config['vocab_size'],
        hidden_size=self.config['hidden_size'],
        intermediate_size=self.config['intermediate_size'],
        num_hidden_layers=self.config['num_hidden_layers'],
        num_attention_heads=num_attention_heads,
        num_key_value_heads=self.config.get('num_key_value_heads', num_attention_heads),
        max_position_embeddings=self.config['max_position_embeddings'],
        rms_norm_eps=self.config['rms_norm_eps'],
        rope_theta=self.config['rope_theta'],
    )

    self.model = DeepSeekMathForCausalLM(model_config)
    self.model = self.model.to(self.device)

    if self.is_distributed:
        if self.device.type == 'cuda':
            self.model = DDP(self.model, device_ids=[self.local_rank])
        else:
            # device_ids must not be given for CPU modules.
            self.model = DDP(self.model)

    # Print model statistics
    if self.rank == 0:
        total_params = 0
        trainable_params = 0
        for param in self.model.parameters():
            count = param.numel()
            total_params += count
            if param.requires_grad:
                trainable_params += count
        print(f"Total parameters: {total_params:,}")
        print(f"Trainable parameters: {trainable_params:,}")
|
||||||
|
|
||||||
|
def setup_tokenizer(self):
    """Setup tokenizer (using tiktoken for GPT-4 tokenizer).

    Falls back to ``SimpleTokenizer`` when the tiktoken encoding cannot be
    loaded. NOTE(review): ``SimpleTokenizer`` is not defined in the visible
    part of this file -- confirm it exists, otherwise the fallback raises
    NameError.
    """
    try:
        self.tokenizer = tiktoken.get_encoding("cl100k_base")
    except Exception as exc:
        # `except Exception` rather than a bare `except:` so that
        # KeyboardInterrupt / SystemExit still propagate.
        if getattr(self, 'rank', 0) == 0:
            print(f"Could not load tiktoken encoding 'cl100k_base' ({exc!r}); "
                  f"falling back to SimpleTokenizer")
        # Fallback to a simple character-level tokenizer
        self.tokenizer = SimpleTokenizer(vocab_size=self.config['vocab_size'])
|
||||||
|
|
||||||
|
def setup_datasets(self):
    """Setup training and validation datasets.

    Creates ``self.train_dataloader`` and ``self.val_dataloader``. In
    distributed mode each uses a ``DistributedSampler``; the validation
    sampler is created with ``shuffle=False`` so evaluation order is stable
    (the sampler shuffles by default). Memory pinning is requested only when
    the target device is CUDA.
    """
    train_dataset = MathDataset(
        data_path=self.config['train_data_path'],
        tokenizer=self.tokenizer,
        max_length=self.config['max_length'],
        split='train'
    )

    val_dataset = MathDataset(
        data_path=self.config['val_data_path'],
        tokenizer=self.tokenizer,
        max_length=self.config['max_length'],
        split='val'
    )

    # Setup distributed samplers
    train_sampler = DistributedSampler(train_dataset) if self.is_distributed else None
    val_sampler = DistributedSampler(val_dataset, shuffle=False) if self.is_distributed else None

    pin_memory = self.device.type == 'cuda'

    self.train_dataloader = DataLoader(
        train_dataset,
        batch_size=self.config['batch_size'],
        sampler=train_sampler,
        # DataLoader forbids shuffle=True together with a sampler.
        shuffle=(train_sampler is None),
        num_workers=self.config['num_workers'],
        pin_memory=pin_memory
    )

    self.val_dataloader = DataLoader(
        val_dataset,
        batch_size=self.config['batch_size'],
        sampler=val_sampler,
        shuffle=False,
        num_workers=self.config['num_workers'],
        pin_memory=pin_memory
    )
|
||||||
|
|
||||||
|
def setup_optimizer(self):
    """Setup optimizer and learning rate scheduler.

    Parameters whose name contains one of the ``no_decay`` patterns get a
    weight decay of 0; all others use ``config['weight_decay']``. The pattern
    ``'norm.weight'`` is included so that a norm parameter named plain
    ``norm.weight`` is exempt as well, not only names ending in
    ``layernorm.weight``.

    Sets ``self.optimizer``, ``self.total_steps``, ``self.warmup_steps`` and
    ``self.scheduler``.
    """
    # AdamW optimizer with weight decay
    no_decay = ('bias', 'LayerNorm.weight', 'layernorm.weight', 'norm.weight')
    decay_params = []
    no_decay_params = []
    # Single pass over the parameters instead of filtering the list twice.
    for name, param in self.model.named_parameters():
        if any(pattern in name for pattern in no_decay):
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    optimizer_grouped_parameters = [
        {'params': decay_params, 'weight_decay': self.config['weight_decay']},
        {'params': no_decay_params, 'weight_decay': 0.0},
    ]

    self.optimizer = optim.AdamW(
        optimizer_grouped_parameters,
        lr=self.config['learning_rate'],
        betas=(self.config['adam_beta1'], self.config['adam_beta2']),
        eps=self.config['adam_epsilon']
    )

    # Calculate total training steps (one scheduler step per batch)
    self.total_steps = len(self.train_dataloader) * self.config['num_epochs']
    self.warmup_steps = int(self.total_steps * self.config['warmup_ratio'])

    # Cosine learning rate scheduler with warmup
    self.scheduler = get_cosine_schedule_with_warmup(
        self.optimizer,
        num_warmup_steps=self.warmup_steps,
        num_training_steps=self.total_steps
    )
|
||||||
|
|
||||||
|
def setup_logging(self):
|
||||||
|
"""Setup logging with wandb"""
|
||||||
|
if self.rank == 0:
|
||||||
|
wandb.init(
|
||||||
|
project=self.config['project_name'],
|
||||||
|
name=self.config['run_name'],
|
||||||
|
config=self.config
|
||||||
|
)
|
||||||
|
|
||||||
|
def save_checkpoint(self, epoch: int, step: int, loss: float):
|
||||||
|
"""Save model checkpoint"""
|
||||||
|
if self.rank == 0:
|
||||||
|
checkpoint_dir = Path(self.config['checkpoint_dir'])
|
||||||
|
checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
|
||||||
|
|
||||||
|
checkpoint = {
|
||||||
|
'epoch': epoch,
|
||||||
|
'step': step,
|
||||||
|
'model_state_dict': model_to_save.state_dict(),
|
||||||
|
'optimizer_state_dict': self.optimizer.state_dict(),
|
||||||
|
'scheduler_state_dict': self.scheduler.state_dict(),
|
||||||
|
'loss': loss,
|
||||||
|
'config': self.config
|
||||||
|
}
|
||||||
|
|
||||||
|
checkpoint_path = checkpoint_dir / f'checkpoint_epoch_{epoch}_step_{step}.pt'
|
||||||
|
torch.save(checkpoint, checkpoint_path)
|
||||||
|
|
||||||
|
# Keep only the last N checkpoints
|
||||||
|
checkpoints = sorted(checkpoint_dir.glob('checkpoint_*.pt'))
|
||||||
|
if len(checkpoints) > self.config['max_checkpoints']:
|
||||||
|
for old_checkpoint in checkpoints[:-self.config['max_checkpoints']]:
|
||||||
|
old_checkpoint.unlink()
|
||||||
|
|
||||||
|
print(f"Checkpoint saved: {checkpoint_path}")
|
||||||
|
|
||||||
|
def load_checkpoint(self, checkpoint_path: str):
|
||||||
|
"""Load model checkpoint"""
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location=self.device)
|
||||||
|
|
||||||
|
model_to_load = self.model.module if hasattr(self.model, 'module') else self.model
|
||||||
|
model_to_load.load_state_dict(checkpoint['model_state_dict'])
|
||||||
|
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
||||||
|
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
||||||
|
|
||||||
|
return checkpoint['epoch'], checkpoint['step'], checkpoint['loss']
|
||||||
|
|
||||||
|
def compute_loss(self, batch):
|
||||||
|
"""Compute the loss for a batch"""
|
||||||
|
input_ids = batch['input_ids'].to(self.device)
|
||||||
|
labels = batch['labels'].to(self.device)
|
||||||
|
|
||||||
|
outputs = self.model(input_ids=input_ids, labels=labels)
|
||||||
|
return outputs['loss']
|
||||||
|
|
||||||
|
def train_epoch(self, epoch: int):
    """Train for one epoch and return the mean per-batch loss.

    For every batch this runs forward, backward, optional gradient-norm
    clipping, then optimizer and scheduler steps. Rank 0 prints and logs to
    wandb every ``log_interval`` steps. ``self.save_checkpoint`` is called
    every ``save_interval`` steps, except at step 0.

    Args:
        epoch: Epoch index. Passed to the sampler's ``set_epoch`` when
            ``self.is_distributed`` is true, and used in logs and checkpoints.

    Returns:
        Sum of the per-batch losses divided by ``len(self.train_dataloader)``,
        or ``nan`` when the dataloader has no batches (previously this raised
        ``ZeroDivisionError``).
    """
    self.model.train()
    total_loss = 0.0
    num_batches = len(self.train_dataloader)

    if self.is_distributed:
        self.train_dataloader.sampler.set_epoch(epoch)

    # Loop-invariant settings, read once.
    max_grad_norm = self.config['max_grad_norm']
    log_interval = self.config['log_interval']
    save_interval = self.config['save_interval']

    for step, batch in enumerate(self.train_dataloader):
        # Forward pass
        loss = self.compute_loss(batch)

        # Backward pass
        loss.backward()

        # Gradient clipping (disabled when max_grad_norm <= 0)
        if max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)

        # Optimizer step
        self.optimizer.step()
        self.scheduler.step()
        self.optimizer.zero_grad()

        # Convert the loss to a Python float exactly once per step; every
        # .item() call copies the value to the host, so reuse the result.
        loss_value = loss.item()
        total_loss += loss_value

        # Logging
        if step % log_interval == 0 and self.rank == 0:
            avg_loss = total_loss / (step + 1)
            lr = self.scheduler.get_last_lr()[0]

            print(f"Epoch {epoch}, Step {step}/{num_batches}, "
                  f"Loss: {loss_value:.4f}, Avg Loss: {avg_loss:.4f}, "
                  f"LR: {lr:.2e}")

            wandb.log({
                'train/loss': loss_value,
                'train/avg_loss': avg_loss,
                'train/learning_rate': lr,
                'train/epoch': epoch,
                'train/step': step
            })

        # Save checkpoint
        if step % save_interval == 0 and step > 0:
            self.save_checkpoint(epoch, step, loss_value)

    if num_batches == 0:
        # Nothing was trained on; report "no value" instead of dividing by zero.
        return float('nan')
    return total_loss / num_batches
|
||||||
|
|
||||||
|
def validate(self, epoch: int):
    """Evaluate on the validation dataloader and return the mean loss.

    The model is switched to eval mode and losses are computed under
    ``torch.no_grad()``. Rank 0 prints the result and logs it to wandb.

    Args:
        epoch: Epoch index, used only for printing and logging.

    Returns:
        Mean per-batch validation loss, or ``nan`` when the validation
        dataloader has no batches (previously ``ZeroDivisionError``).
    """
    self.model.eval()
    num_batches = len(self.val_dataloader)

    with torch.no_grad():
        total_loss = sum(
            self.compute_loss(batch).item() for batch in self.val_dataloader
        )

    avg_loss = total_loss / num_batches if num_batches else float('nan')

    # math.exp raises OverflowError for large arguments (roughly > 709), which
    # a diverged run can produce; report infinite perplexity instead of
    # crashing the training loop.
    try:
        perplexity = math.exp(avg_loss)
    except OverflowError:
        perplexity = float('inf')

    if self.rank == 0:
        print(f"Validation - Epoch {epoch}, Loss: {avg_loss:.4f}, "
              f"Perplexity: {perplexity:.2f}")

        wandb.log({
            'val/loss': avg_loss,
            'val/perplexity': perplexity,
            'val/epoch': epoch
        })

    return avg_loss
|
||||||
|
|
||||||
|
def train(self):
    """Run the full training loop over ``config['num_epochs']`` epochs.

    Each epoch trains, validates, saves ``best_model.pt`` on rank 0 when the
    validation loss improves, then calls ``self.save_checkpoint`` with the
    epoch's mean training loss. After the last epoch, rank 0 prints a
    completion message and finishes the wandb run.
    """
    print("Starting training...")

    # Only rank 0 ever updates this, because the comparison below is gated
    # on rank 0.
    best_val_loss = float('inf')
    num_epochs = self.config['num_epochs']

    for epoch in range(num_epochs):
        print(f"\n=== Epoch {epoch + 1}/{num_epochs} ===")

        # Training
        train_loss = self.train_epoch(epoch)

        # Validation
        val_loss = self.validate(epoch)

        # Save best model
        if val_loss < best_val_loss and self.rank == 0:
            best_val_loss = val_loss
            best_model_path = Path(self.config['checkpoint_dir']) / 'best_model.pt'
            # This save can be the first write into checkpoint_dir (it happens
            # before the end-of-epoch save_checkpoint call), so make sure the
            # directory exists rather than relying on an earlier save.
            best_model_path.parent.mkdir(parents=True, exist_ok=True)
            # Save the underlying network when the model is wrapped (.module).
            model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
            torch.save(model_to_save.state_dict(), best_model_path)
            print(f"Best model saved with validation loss: {best_val_loss:.4f}")

        # Save epoch checkpoint
        self.save_checkpoint(epoch, len(self.train_dataloader), train_loss)

    if self.rank == 0:
        print("Training completed!")
        wandb.finish()
|
||||||
|
|
||||||
|
class SimpleTokenizer:
    """Fallback tokenizer that maps each character to its code point."""

    def __init__(self, vocab_size: int = 50000):
        self.pad_token_id = 0
        self.vocab_size = vocab_size

    def encode(self, text: str) -> List[int]:
        """Return one id per character of ``text``.

        Each id is the character's ``ord`` value, capped at
        ``vocab_size - 1`` so that every id stays below ``vocab_size``.
        """
        highest_id = self.vocab_size - 1
        return [
            code_point if code_point < highest_id else highest_id
            for code_point in map(ord, text)
        ]
|
||||||
|
|
||||||
|
def parse_args(argv=None):
    """Parse the pretraining script's command-line options.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default), argparse reads ``sys.argv[1:]``, which is what the
            original zero-argument call did. Passing a list lets callers
            parse options programmatically.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description='DeepSeek-Math Pretraining')

    # Model configuration
    parser.add_argument('--vocab_size', type=int, default=102400)
    parser.add_argument('--hidden_size', type=int, default=4096)
    parser.add_argument('--intermediate_size', type=int, default=11008)
    parser.add_argument('--num_hidden_layers', type=int, default=30)
    parser.add_argument('--num_attention_heads', type=int, default=32)
    parser.add_argument('--max_position_embeddings', type=int, default=4096)
    parser.add_argument('--max_length', type=int, default=2048)

    # Training configuration
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--num_epochs', type=int, default=3)
    parser.add_argument('--learning_rate', type=float, default=1e-4)
    parser.add_argument('--weight_decay', type=float, default=0.1)
    parser.add_argument('--adam_beta1', type=float, default=0.9)
    parser.add_argument('--adam_beta2', type=float, default=0.95)
    parser.add_argument('--adam_epsilon', type=float, default=1e-8)
    parser.add_argument('--warmup_ratio', type=float, default=0.03)
    parser.add_argument('--max_grad_norm', type=float, default=1.0)

    # Data paths
    parser.add_argument('--train_data_path', type=str, default='data/train.jsonl')
    parser.add_argument('--val_data_path', type=str, default='data/val.jsonl')

    # Logging and checkpointing
    parser.add_argument('--checkpoint_dir', type=str, default='checkpoints')
    parser.add_argument('--log_interval', type=int, default=100)
    parser.add_argument('--save_interval', type=int, default=1000)
    parser.add_argument('--max_checkpoints', type=int, default=5)
    parser.add_argument('--project_name', type=str, default='deepseek-math')
    parser.add_argument('--run_name', type=str, default='pretraining')

    # System configuration
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--rms_norm_eps', type=float, default=1e-6)
    parser.add_argument('--rope_theta', type=float, default=10000.0)

    return parser.parse_args(argv)
|
||||||
|
|
||||||
|
def set_seed(seed: int = 42):
    """Seed every random number generator the script uses.

    Applies ``seed`` to ``random``, ``np.random``, ``torch`` and
    ``torch.cuda`` (all devices), in that order.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
|
||||||
|
|
||||||
|
def main():
    """Script entry point: parse options, seed the RNGs, then train."""
    cli_args = parse_args()
    set_seed(42)

    # The trainer takes a plain dict, so hand it the namespace's attribute
    # dictionary, then start training.
    DeepSeekMathTrainer(vars(cli_args)).train()
|
||||||
|
|
||||||
|
# NOTE: main() is deliberately not invoked here. This file's single
# `if __name__ == "__main__":` entry point is at the bottom of the file;
# a second guard at this position made the whole training run execute twice
# whenever the file was run as a script.
|
||||||
|
|
||||||
|
# Example usage with distributed training:
# torchrun --nproc_per_node=8 --nnodes=1 deepseek_math_training.py \
#     --batch_size 4 \
#     --learning_rate 1e-4 \
#     --num_epochs 3 \
#     --train_data_path /path/to/train.jsonl \
#     --val_data_path /path/to/val.jsonl \
#     --checkpoint_dir ./checkpoints \
#     --project_name deepseek-math-7b \
#     --run_name pretraining_run_1
|
||||||
|
|
||||||
|
# Data preprocessing script for mathematical datasets
|
||||||
|
def preprocess_math_data(input_dir='raw_data', output_dir='processed_data'):
    """
    Example preprocessing script for mathematical datasets.

    This would typically process:
    - ArXiv papers in mathematics
    - Mathematical problem-solution pairs
    - Formal proofs
    - Mathematical textbooks
    - Code for mathematical computations

    Every ``*.txt`` file in ``input_dir`` is split on blank lines, each chunk
    is cleaned, and the result is written to ``<output_dir>/<stem>.jsonl``
    with one ``{"text": ...}`` object per line.

    Args:
        input_dir: Directory containing the raw ``.txt`` files. Defaults to
            ``'raw_data'``, the value that was previously hard-coded.
        output_dir: Directory for the ``.jsonl`` output; created, including
            missing parents, if needed. Defaults to ``'processed_data'``.
    """
    # Imported locally like the other helpers so the function does not depend
    # on a module-level ``import json`` being present.
    import json
    import re
    from pathlib import Path

    def clean_math_text(text: str) -> str:
        """Clean and normalize mathematical text"""
        # Normalize mathematical notation
        text = re.sub(r'\\frac\{([^}]+)\}\{([^}]+)\}', r'(\1)/(\2)', text)
        text = re.sub(r'\\sqrt\{([^}]+)\}', r'sqrt(\1)', text)

        # Clean up common LaTeX commands: keep the argument of \cmd{...},
        # then drop any remaining bare \cmd.
        text = re.sub(r'\\[a-zA-Z]+\{([^}]*)\}', r'\1', text)
        text = re.sub(r'\\[a-zA-Z]+', '', text)

        # Collapse whitespace last: stripping a bare command from
        # "a \cmd b" leaves two adjacent spaces that an earlier pass misses.
        text = re.sub(r'\s+', ' ', text)

        return text.strip()

    def process_file(input_path: str, output_path: str):
        """Process a single file and save cleaned data"""
        with open(input_path, 'r', encoding='utf-8') as f:
            content = f.read()

        # Split into chunks (could be paragraphs, sections, etc.)
        chunks = content.split('\n\n')

        processed_data = []
        for chunk in chunks:
            if len(chunk.strip()) > 50:  # Filter out very short chunks
                cleaned = clean_math_text(chunk)
                if cleaned:
                    processed_data.append({'text': cleaned})

        # Save as JSONL
        with open(output_path, 'w', encoding='utf-8') as f:
            f.writelines(json.dumps(item) + '\n' for item in processed_data)

    input_dir = Path(input_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Sorted so files are processed (and reported) in a stable order.
    for file_path in sorted(input_dir.glob('*.txt')):
        output_path = output_dir / f"{file_path.stem}.jsonl"
        process_file(str(file_path), str(output_path))
        print(f"Processed {file_path} -> {output_path}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Data preprocessing is opt-in; enable the call below to run it before
    # training.
    # preprocess_math_data()

    # Launch the training run.
    main()
|
Loading…
Reference in New Issue
Block a user