Training Pipeline - From Random Weights to Intelligence
"Training transforms random weights into intelligent systems through three stages: pre-training, fine-tuning, and alignment."
Modern LLMs are not trained end-to-end in a single pass. They undergo a multi-stage pipeline where each stage builds upon the previous one: pre-training creates the base model, SFT (Supervised Fine-Tuning) teaches instruction-following, and RLHF/DPO aligns the model with human preferences. This document covers each stage in detail, including the algorithms, data requirements, and implementation considerations for production-scale LLM training.
Training Stages Overview
Stage Comparison
| Stage | Data | Objective | Tokens | Cost | Result |
|---|---|---|---|---|---|
| Pre-training | Web text | Next-token prediction | 1T-3T | ~$2M | Base model |
| SFT | Instructions | Instruction-following | 10M-100M | ~$10K | Chat-capable |
| RLHF/DPO | Comparisons | Preference alignment | 1M-10M | ~$5K | Aligned behavior |
Pre-Training
Next-Token Prediction Objective
The fundamental training objective for all modern LLMs:
L = -sum_{t=1}^{T} log P(x_t | x_1, x_2, ..., x_{t-1})
For a sequence of tokens, the model learns to maximize the probability of each token given all previous tokens. This simple objective, applied at scale, enables the emergence of complex reasoning, world knowledge, and linguistic capabilities.
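As a concrete illustration, the objective can be computed by hand for a toy sequence (the logits here are random and purely illustrative):

```python
import torch
import torch.nn.functional as F

# Toy setup: vocabulary of 5 tokens, a sequence with 4 positions to predict.
# logits[t] holds the model's scores for the token at position t+1.
torch.manual_seed(0)
logits = torch.randn(4, 5)            # (seq_len, vocab_size)
targets = torch.tensor([1, 3, 0, 2])  # the actual next tokens

# L = -sum_t log P(x_t | x_1, ..., x_{t-1})
log_probs = F.log_softmax(logits, dim=-1)
total_loss = -log_probs[torch.arange(4), targets].sum()

# Identical (up to averaging over positions) to the standard cross-entropy
assert torch.allclose(total_loss / 4, F.cross_entropy(logits, targets))
```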
Why Next-Token Prediction Works
The power of next-token prediction comes from:
- Scale: Training on trillions of tokens exposes the model to diverse patterns
- Context: Predicting the next token requires understanding the full context
- Compression: The model learns efficient representations of language structure
- Generalization: Patterns learned generalize to unseen combinations
Data Processing Pipeline
import re
from typing import List, Tuple
from collections import defaultdict
class TextPreprocessor:
"""
Preprocess text data for LLM training.
Handles deduplication, quality filtering, and privacy removal.
"""
def __init__(self, min_length: int = 128, max_length: int = 4096):
self.min_length = min_length
self.max_length = max_length
# Patterns for privacy removal
self.email_pattern = re.compile(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b')
self.phone_pattern = re.compile(r'\b\d{3}[-.]?\d{3}[-.]?\d{4}\b')
self.ssn_pattern = re.compile(r'\b\d{3}-\d{2}-\d{4}\b')
self.ip_pattern = re.compile(r'\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b')
def remove_pii(self, text: str) -> str:
"""Remove personally identifiable information."""
text = self.email_pattern.sub('[EMAIL]', text)
text = self.phone_pattern.sub('[PHONE]', text)
text = self.ssn_pattern.sub('[SSN]', text)
text = self.ip_pattern.sub('[IP]', text)
return text
def check_quality(self, text: str) -> bool:
"""
Check text quality based on heuristics.
Returns True if text passes quality checks.
"""
# Length check
words = text.split()
if len(words) < self.min_length or len(words) > self.max_length:
return False
# Mean word length (reject very short/long words)
mean_word_len = sum(len(w) for w in words) / len(words)
if mean_word_len < 3 or mean_word_len > 10:
return False
# Special character ratio (too many = garbage)
special_ratio = sum(1 for c in text if not c.isalnum() and not c.isspace()) / len(text)
if special_ratio > 0.3:
return False
# Repetition check (detect "aaaaaaa..." patterns)
if len(set(words)) / len(words) < 0.2:
return False
return True
def deduplicate_by_ngram(self, texts: List[str], n: int = 13) -> List[str]:
"""
Remove near-duplicates using n-gram overlap.
Uses MinHash-like approach for efficiency.
"""
seen_ngrams = set()
unique_texts = []
for text in texts:
words = text.split()
            if len(words) < n:
                # Too short to form an n-gram; keep the text as-is
                unique_texts.append(text)
                continue
            # Sample non-overlapping n-grams (including the final one)
            ngrams = [' '.join(words[i:i+n]) for i in range(0, len(words) - n + 1, n)]
# Check if any n-gram seen before
if not any(ngram in seen_ngrams for ngram in ngrams[:5]):
seen_ngrams.update(ngrams)
unique_texts.append(text)
return unique_texts
# Usage
preprocessor = TextPreprocessor(min_length=128, max_length=4096)
# Process a batch of texts
raw_texts = [
"This is a sample document with some contact@example.com content...",
    "Another document with quality issues " * 10,  # Too repetitive (trailing space keeps words separated)
]
clean_texts = []
for text in raw_texts:
text = preprocessor.remove_pii(text)
if preprocessor.check_quality(text):
clean_texts.append(text)
print(f"Processed {len(clean_texts)}/{len(raw_texts)} texts")
Curriculum Learning Strategy
Training proceeds through carefully designed curriculum stages:
from enum import Enum
from dataclasses import dataclass
class TrainingStage(Enum):
"""Curriculum learning stages for pre-training."""
FOUNDATION = "foundation" # High-quality, diverse text
KNOWLEDGE = "knowledge" # Focused on factual content
REASONING = "reasoning" # Logical reasoning patterns
SYNTHESIS = "synthesis" # Multi-step problem solving
@dataclass
class StageConfig:
"""Configuration for a training stage."""
stage: TrainingStage
data_proportion: float # Proportion of total data
learning_rate: float
batch_size: int
duration_steps: int
description: str
CURRICULUM = [
StageConfig(
stage=TrainingStage.FOUNDATION,
data_proportion=0.40,
learning_rate=3e-4,
batch_size=512,
duration_steps=400000,
description="Build foundational language understanding"
),
StageConfig(
stage=TrainingStage.KNOWLEDGE,
data_proportion=0.30,
learning_rate=2e-4,
batch_size=512,
duration_steps=300000,
description="Acquire world knowledge and facts"
),
StageConfig(
stage=TrainingStage.REASONING,
data_proportion=0.20,
learning_rate=1.5e-4,
batch_size=512,
duration_steps=200000,
description="Develop reasoning capabilities"
),
StageConfig(
stage=TrainingStage.SYNTHESIS,
data_proportion=0.10,
learning_rate=1e-4,
batch_size=512,
duration_steps=100000,
description="Integrate skills for complex tasks"
),
]
def get_curriculum_lr(step: int, curriculum: list[StageConfig]) -> float:
"""Get learning rate based on curriculum stage."""
    total_steps = sum(c.duration_steps for c in curriculum)
    current_step = min(step, total_steps - 1)  # clamp once past the final stage
cumulative_steps = 0
for config in curriculum:
if cumulative_steps <= current_step < cumulative_steps + config.duration_steps:
return config.learning_rate
cumulative_steps += config.duration_steps
return curriculum[-1].learning_rate
Python Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
def compute_language_model_loss(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
"""
Compute cross-entropy loss for language modeling.
Args:
logits: Model output of shape (batch, seq_len, vocab_size)
targets: Target token IDs of shape (batch, seq_len)
"""
# Flatten for cross-entropy
batch_size, seq_len, vocab_size = logits.shape
logits_flat = logits.view(-1, vocab_size)
targets_flat = targets.view(-1)
# Compute cross-entropy loss
loss = F.cross_entropy(logits_flat, targets_flat, ignore_index=-100)
return loss
# Example
batch_size, seq_len, vocab_size = 2, 128, 50000
logits = torch.randn(batch_size, seq_len, vocab_size)
targets = torch.randint(0, vocab_size, (batch_size, seq_len))
loss = compute_language_model_loss(logits, targets)
print(f"Loss: {loss.item():.4f}")
# At initialization, loss is roughly ln(vocab_size) (~10.8 for a 50K vocabulary)
# Converges to ~1.8-2.5 for well-trained models
Mixed Precision Training
Mixed precision training uses FP16/BF16 for computations while maintaining FP32 master weights:
import torch
from torch.cuda.amp import autocast, GradScaler
class MixedPrecisionTrainer:
"""
Trainer with mixed precision support.
Uses BF16 when available for better numerical stability.
"""
def __init__(self, model, optimizer, device='cuda'):
self.model = model.to(device)
self.optimizer = optimizer
self.device = device
# Check BF16 support
if torch.cuda.is_bf16_supported():
self.dtype = torch.bfloat16
print("Using BF16 for training")
else:
self.dtype = torch.float16
self.scaler = GradScaler()
print("Using FP16 with GradScaler for training")
def training_step(self, batch):
"""Single training step with mixed precision."""
input_ids = batch['input_ids'].to(self.device)
attention_mask = batch['attention_mask'].to(self.device)
targets = input_ids[:, 1:].contiguous()
self.optimizer.zero_grad()
if self.dtype == torch.bfloat16:
# BF16 doesn't need gradient scaling
with autocast(dtype=torch.bfloat16):
logits = self.model(input_ids[:, :-1], attention_mask[:, :-1])
loss = compute_language_model_loss(logits, targets)
loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
self.optimizer.step()
else:
# FP16 with gradient scaling
with autocast(dtype=torch.float16):
logits = self.model(input_ids[:, :-1], attention_mask[:, :-1])
loss = compute_language_model_loss(logits, targets)
self.scaler.scale(loss).backward()
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
self.scaler.step(self.optimizer)
self.scaler.update()
return loss.item()
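The FP32-master-weight behavior can be seen in a minimal, CPU-runnable sketch (on a GPU you would pass device_type='cuda'; the model and tensor names here are illustrative, not part of any library API):

```python
import torch
import torch.nn as nn

# Inside an autocast region, eligible ops (e.g. linear layers) run in bfloat16,
# while the module's parameters remain FP32 master weights.
model = nn.Linear(16, 16)
x = torch.randn(2, 16)

with torch.autocast(device_type='cpu', dtype=torch.bfloat16):
    y = model(x)

print(y.dtype)             # low-precision activations inside the region
print(model.weight.dtype)  # parameters stay torch.float32
```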
FlashAttention Integration
FlashAttention is a memory-efficient attention mechanism that's essential for large-scale training:
try:
from flash_attn import flash_attn_func
FLASH_ATTENTION_AVAILABLE = True
except ImportError:
FLASH_ATTENTION_AVAILABLE = False
print("FlashAttention not available, using standard attention")
class FlashAttentionBlock(nn.Module):
"""
Transformer block with FlashAttention.
Memory efficient: O(N) instead of O(N^2) for attention.
"""
def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
super().__init__()
self.num_heads = num_heads
self.head_dim = d_model // num_heads
self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=False)
self.out_proj = nn.Linear(d_model, d_model, bias=False)
self.ffn = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(d_ff, d_model),
nn.Dropout(dropout)
)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
def forward(self, x, attention_mask=None):
"""
Args:
x: (batch, seq_len, d_model)
attention_mask: (batch, seq_len) or None
"""
# Self-attention with FlashAttention
residual = x
x = self.norm1(x)
if FLASH_ATTENTION_AVAILABLE:
# FlashAttention expects (batch, seq_len, 3, heads, head_dim)
batch_size, seq_len, _ = x.shape
qkv = self.qkv_proj(x).reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
q, k, v = qkv.unbind(dim=2)
# FlashAttention requires causal mask to be handled separately
x = flash_attn_func(q, k, v, causal=True)
x = self.out_proj(x.reshape(batch_size, seq_len, -1))
        else:
            # Fallback to standard attention
            batch_size, seq_len, _ = x.shape
            qkv = self.qkv_proj(x)
            q, k, v = qkv.chunk(3, dim=-1)
q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
attn = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
# Causal mask
causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
attn = attn.masked_fill(causal_mask.to(x.device), float('-inf'))
attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(batch_size, seq_len, -1)
x = self.out_proj(x)
x = x + residual
# FFN
residual = x
x = self.norm2(x)
x = self.ffn(x) + residual
return x
Advanced Optimizer Configuration
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
def create_optimizer_and_scheduler(model, config):
"""
Create optimizer and learning rate scheduler with warmup.
Uses AdamW with cosine decay and linear warmup.
"""
# Separate parameters for weight decay
    # Don't apply weight decay to biases, normalization weights, or the output head
    no_decay = ['bias', 'norm', 'lm_head.weight']
optimizer_grouped_parameters = [
{
'params': [p for n, p in model.named_parameters()
if not any(nd in n for nd in no_decay)],
'weight_decay': config.weight_decay,
},
{
'params': [p for n, p in model.named_parameters()
if any(nd in n for nd in no_decay)],
'weight_decay': 0.0,
},
]
optimizer = AdamW(
optimizer_grouped_parameters,
lr=config.learning_rate,
betas=(config.beta1, config.beta2),
eps=1e-8,
)
# Warmup scheduler
warmup_scheduler = LinearLR(
optimizer,
start_factor=0.0,
end_factor=1.0,
total_iters=config.warmup_steps
)
# Cosine decay scheduler
cosine_scheduler = CosineAnnealingLR(
optimizer,
T_max=config.max_steps - config.warmup_steps,
eta_min=config.learning_rate * config.min_lr_ratio
)
# Sequential scheduler: warmup then cosine decay
scheduler = SequentialLR(
optimizer,
schedulers=[warmup_scheduler, cosine_scheduler],
milestones=[config.warmup_steps]
)
return optimizer, scheduler
# Usage
optimizer, scheduler = create_optimizer_and_scheduler(model, config)
for step, batch in enumerate(dataloader):
loss = pretrain_step(model, optimizer, config, batch)
scheduler.step()
if step % 100 == 0:
current_lr = scheduler.get_last_lr()[0]
print(f"Step {step}: loss={loss:.4f}, lr={current_lr:.2e}")
Scaling Laws
The Chinchilla scaling laws (Hoffmann et al., 2022) established optimal compute allocation:
| Parameter Count | Optimal Training Tokens | Compute (FLOPs) |
|---|---|---|
| 1B | 20B | ~1.2e20 |
| 7B | 140B | ~5.9e21 |
| 70B | 1.4T | ~5.9e23 |
| 400B | 8T | ~1.9e25 |
Key insight: Models should be trained on ~20 tokens per parameter for optimal performance.
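The compute column above follows the widely used rule of thumb C ≈ 6·N·D FLOPs for training (a rough approximation that ignores attention and activation overheads):

```python
def training_flops(params: float, tokens: float) -> float:
    """Approximate training compute via C ~= 6 * N * D."""
    return 6 * params * tokens

# 7B parameters at the Chinchilla-optimal ~20 tokens/parameter (140B tokens)
print(f"{training_flops(7e9, 1.4e11):.1e}")  # ~5.9e+21 FLOPs
```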
Scaling Law Formula
L(N, D) = E + A/N^alpha + B/D^beta
Where:
- L is the loss
- N is the parameter count
- D is the data size (tokens)
- E, A, B, alpha, beta are fitted constants
def chinchilla_loss(params: float, tokens: float) -> float:
    """
    Approximate Chinchilla scaling law (Hoffmann et al., 2022 fit).
    Args:
        params: Number of parameters (raw count, e.g. 7e9)
        tokens: Training tokens (raw count, e.g. 1.4e11)
    """
    E = 1.69  # Irreducible loss
    A = 406.4
    B = 410.7
    alpha = 0.34
    beta = 0.28
    return E + A / (params ** alpha) + B / (tokens ** beta)

# Find optimal data for a 7B model
params_7b = 7e9
optimal_tokens_7b = 20 * params_7b  # Chinchilla: ~20 tokens per parameter
loss_7b = chinchilla_loss(params_7b, optimal_tokens_7b)
print(f"Optimal tokens for 7B: {optimal_tokens_7b / 1e9:.0f}B, Loss: {loss_7b:.3f}")
Training Configuration Example
import math
from dataclasses import dataclass
@dataclass
class PreTrainingConfig:
"""Configuration for LLM pre-training."""
# Model architecture
d_model: int = 4096
num_heads: int = 32
num_layers: int = 32
    d_ff: int = 11008  # ~8/3 * d_model for SwiGLU, rounded to a multiple of 256
# Training hyperparameters
batch_size: int = 512 # Global batch size
micro_batch_size: int = 4 # Per-GPU batch size
learning_rate: float = 3e-4
weight_decay: float = 0.1
beta1: float = 0.9
beta2: float = 0.95
# Learning rate schedule
warmup_steps: int = 2000
max_steps: int = 1000000
min_lr_ratio: float = 0.1
# Data
vocab_size: int = 128000
max_seq_len: int = 4096
def get_lr(self, step: int) -> float:
"""Cosine learning rate schedule with warmup."""
if step < self.warmup_steps:
return self.learning_rate * step / self.warmup_steps
progress = (step - self.warmup_steps) / (self.max_steps - self.warmup_steps)
cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
return self.min_lr_ratio * self.learning_rate + (1 - self.min_lr_ratio) * self.learning_rate * cosine_decay
# Training loop skeleton
def pretrain_step(model, optimizer, config, batch):
"""Single pre-training step."""
input_ids = batch['input_ids'] # (batch, seq_len)
attention_mask = batch['attention_mask']
# Forward pass
logits = model(input_ids, attention_mask)
# Compute loss (shift for next-token prediction)
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = input_ids[..., 1:].contiguous()
loss = compute_language_model_loss(shift_logits, shift_labels)
# Backward pass
loss.backward()
# Gradient clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# Optimizer step
optimizer.step()
optimizer.zero_grad()
return loss.item()
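The gap between the global batch size (512) and the per-GPU micro batch size (4) in the config above is typically bridged by gradient accumulation. A minimal single-device sketch (the model and data here are placeholders, not the document's training code):

```python
import torch
import torch.nn as nn

# Accumulate gradients over several micro-batches before one optimizer step,
# so the effective batch is accum_steps * micro_batch_size (* num_gpus).
model = nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
accum_steps = 4

optimizer.zero_grad()
for _ in range(accum_steps):
    x = torch.randn(4, 8)            # one micro-batch
    loss = model(x).pow(2).mean()    # stand-in for the LM loss
    (loss / accum_steps).backward()  # scale so the summed gradients average correctly
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
```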
Emergent Abilities
Capabilities that emerge at scale without explicit training:
| Ability | Emerges At | Description |
|---|---|---|
| In-context learning | ~10B+ | Learn from examples in prompt |
| Chain-of-thought | ~30B+ | Multi-step reasoning |
| Instruction-following | ~7B+ (with SFT) | Understand and follow directions |
| Code generation | ~7B+ | Write and debug code |
| Multi-lingual | ~7B+ | Cross-lingual transfer |
Important: Emergent abilities are not guaranteed; they depend on training data and architecture choices.