Study Notes: Stanford CS336 Language Modeling from Scratch [10]
Building a Complete Training Loop
This note documents the journey of assembling the core components (optimizer, learning rate scheduling, data loading, checkpointing, and decoding) into a complete training pipeline for Transformer language models. We'll explore how each piece fits together, the design decisions behind them, and the practical considerations that make the difference between research code and production systems.
Table of Contents
- Introduction: The Big Picture
- The AdamW Optimizer: Decoupled Weight Decay Regularization
- Learning Rate Scheduling: The LLaMA Approach
- Memory-Efficient Data Loading
- Checkpoint Management
- Decoding Strategies: From Model to Text
- Putting It All Together: The Training Script
- Testing and Validation
- Key Takeaways
Introduction: The Big Picture
Training a large language model isn't just about implementing a forward pass and calling loss.backward(). A production training pipeline requires careful orchestration of multiple components, each with its own subtleties and potential pitfalls. In this note, we'll go through how to build a complete training pipeline from scratch, learning why each component matters and how they interact.
What weâll build:
- An implementation of the AdamW optimizer
- A cosine learning rate schedule with warmup, as used in LLaMA
- Memory-mapped data loading to handle datasets larger than RAM
- Robust checkpoint saving/loading for long training runs
- Multiple decoding strategies (temperature scaling, top-p sampling)
- A complete sample training script that ties everything together
The AdamW Optimizer: Decoupled Weight Decay Regularization
The first step in building our training loop is implementing the optimizer. While PyTorch provides torch.optim.AdamW, understanding the exact algorithm is crucial for debugging training issues and understanding why certain hyperparameters matter.
The Algorithm
The AdamW algorithm (from "Decoupled Weight Decay Regularization" by Loshchilov & Hutter, 2019) differs from standard Adam in how it applies weight decay. Here's Algorithm 2 from the paper:
Initialize:
- Learnable parameters: $\theta$
- First moment vector: $m \leftarrow 0$ (same shape as $\theta$)
- Second moment vector: $v \leftarrow 0$ (same shape as $\theta$)
For $t = 1, 2, \ldots, T$:
- Sample a batch of data $B_t$
- Compute the gradient:
\[g \leftarrow \nabla_\theta \ell(\theta; B_t)\]
- Update the biased first moment estimate:
\[m \leftarrow \beta_1 m + (1 - \beta_1) g\]
- Update the biased second raw moment estimate:
\[v \leftarrow \beta_2 v + (1 - \beta_2) g^2\]
- Compute the bias-corrected learning rate:
\[\alpha_t \leftarrow \alpha \cdot \frac{\sqrt{1 - \beta_2^t}}{1 - \beta_1^t}\]
- Update the parameters with the adaptive learning rate:
\[\theta \leftarrow \theta - \alpha_t \frac{m}{\sqrt{v} + \varepsilon}\]
- Apply decoupled weight decay:
\[\theta \leftarrow \theta - \alpha \lambda \theta\]
Why Decoupled Weight Decay Matters
The key innovation in AdamW is decoupling weight decay from the gradient-based update. To understand why this matters, letâs compare the two approaches:
Standard Adam with L2 Regularization:
In traditional Adam with L2 regularization, we add the weight decay term to the gradient before computing adaptive moments:
\[g \leftarrow \nabla_\theta \ell(\theta; B_t) + \lambda \theta\]
Then we proceed with the normal Adam update using this modified gradient. This means:
- Weight decay affects the adaptive moment estimates ($m$ and $v$)
- The effective weight decay depends on the adaptive learning rate
- Parameters with large gradients get less regularization (due to adaptive scaling)
AdamW with Decoupled Weight Decay:
In AdamW, we apply weight decay after the adaptive update as a separate step:
\[\theta \leftarrow \theta - \alpha_t \frac{m}{\sqrt{v} + \varepsilon} - \alpha \lambda \theta\]
This decoupling means:
- Weight decay is independent of gradient statistics
- All parameters receive consistent regularization proportional to their magnitude
- Weight decay directly shrinks parameters toward zero, regardless of gradient history
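To make the contrast concrete, here is a minimal sketch of the two update rules applied to a single parameter tensor (first-moment and bias-correction terms are omitted for brevity; the variable names are illustrative):

import torch

def adam_l2_step(theta, g, v, lr, lam, eps=1e-8):
    # L2 regularization: the decay term joins the gradient *before* adaptive scaling,
    # so its effect is divided by sqrt(v) just like the gradient itself
    g = g + lam * theta
    return theta - lr * g / (v.sqrt() + eps)

def adamw_step(theta, g, v, lr, lam, eps=1e-8):
    # AdamW: the adaptive step uses the raw gradient, and decay is applied separately
    theta = theta - lr * g / (v.sqrt() + eps)
    return theta - lr * lam * theta

theta = torch.ones(3)
g = torch.tensor([10.0, 0.1, 0.01])  # parameters with very different gradient magnitudes
v = g ** 2                           # stand-in for the second-moment estimate
print(adam_l2_step(theta, g, v, lr=0.1, lam=0.1))  # shrinkage varies with gradient size
print(adamw_step(theta, g, v, lr=0.1, lam=0.1))    # identical shrinkage for every parameter

With the coupled version, the parameter with the tiny gradient receives a far larger effective decay than the one with the large gradient; the decoupled version shrinks all three parameters by the same factor.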
Why This Improves Performance:
- Better generalization: Decoupled weight decay provides more consistent regularization across all parameters, leading to better generalization on downstream tasks.
- Works with large learning rates: In standard Adam + L2, increasing the learning rate also increases the effective regularization, creating unwanted coupling. AdamW removes this coupling.
- More interpretable: The weight decay hyperparameter $\lambda$ directly controls regularization strength, making it easier to tune.
Practical Impact:
For large language models, this difference is crucial. The original BERT used Adam with L2 regularization and achieved 84.4% on MNLI. Simply switching to AdamW with the same hyperparameters improved accuracy to 84.8% - a significant gain from this single algorithmic change. Similar improvements have been observed across many other deep learning tasks.
Complete Implementation
import math
import torch

class AdamW(torch.optim.Optimizer):
    """
    Implements the AdamW optimizer following Algorithm 2 from
    "Decoupled Weight Decay Regularization" (Loshchilov & Hutter, 2019).
    """
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()  # parameters are updated in place; no autograd tracking needed
    def step(self):
        for group in self.param_groups:
            beta1, beta2 = group['betas']
            lr = group['lr']
            weight_decay = group['weight_decay']
            eps = group['eps']
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                state = self.state[p]
                # Initialize state on first step
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                state['step'] += 1
                # Update biased first moment: m ← β1·m + (1 - β1)·g
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                # Update biased second moment: v ← β2·v + (1 - β2)·g²
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                # Compute bias correction terms
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                # Compute adjusted learning rate: α_t ← α · √(1 - β2^t) / (1 - β1^t)
                alpha_t = lr * math.sqrt(bias_correction2) / bias_correction1
                # Update parameters: θ ← θ - α_t · m / (√v + ε)
                denom = exp_avg_sq.sqrt().add_(eps)
                p.addcdiv_(exp_avg, denom, value=-alpha_t)
                # Apply decoupled weight decay: θ ← θ - α·λ·θ
                if weight_decay != 0:
                    p.add_(p, alpha=-lr * weight_decay)
Usage example:
model = TransformerLM(vocab_size=50257, d_model=768, ...)
optimizer = AdamW(model.parameters(), lr=1e-3, weight_decay=0.1)
for batch in dataloader:
    loss = compute_loss(model, batch)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
Learning Rate Scheduling: The LLaMA Approach
Modern large language models don't use a fixed learning rate. Instead, they employ sophisticated schedules that warm up the learning rate at the start and gradually decay it during training. The LLaMA paper (Touvron et al., 2023) uses a three-phase cosine schedule that has become standard.
The Three-Phase Schedule
Phase 1 - Warmup ($t < T_w$):
\[\alpha_t = \frac{t}{T_w} \cdot \alpha_{\text{max}}\]
Linear increase from 0 to $\alpha_{\text{max}}$ over $T_w$ steps.
Phase 2 - Cosine Annealing ($T_w \leq t \leq T_c$):
\[\alpha_t = \alpha_{\text{min}} + \frac{1}{2} \left(1 + \cos\left(\frac{t - T_w}{T_c - T_w} \cdot \pi\right)\right) \left(\alpha_{\text{max}} - \alpha_{\text{min}}\right)\]
Smooth cosine decay from $\alpha_{\text{max}}$ to $\alpha_{\text{min}}$.
Understanding the Smooth Cosine Decay:
The beauty of cosine annealing lies in its smoothness. The diagram below shows how the learning rate evolves during the cosine annealing phase:

Figure: Cosine annealing schedule showing the smooth decay of learning rate from Îąmax to Îąmin
Breaking down the cosine formula:
Let's denote $p = \frac{t - T_w}{T_c - T_w}$ as the progress through the cosine phase (where $p \in [0, 1]$).
The formula becomes:
\[\alpha_t = \alpha_{\text{min}} + \frac{1}{2}(1 + \cos(p \cdot \pi)) \cdot (\alpha_{\text{max}} - \alpha_{\text{min}})\]
Why cosine creates a smooth curve:
- At start ($p = 0$):
- $\cos(0) = 1$
- $\alpha_t = \alpha_{\text{min}} + 1 \cdot (\alpha_{\text{max}} - \alpha_{\text{min}}) = \alpha_{\text{max}}$
- At middle ($p = 0.5$):
- $\cos(\pi/2) = 0$
- $\alpha_t = \alpha_{\text{min}} + 0.5 \cdot (\alpha_{\text{max}} - \alpha_{\text{min}})$ (halfway point)
- At end ($p = 1$):
- $\cos(\pi) = -1$
- $\alpha_t = \alpha_{\text{min}} + 0 \cdot (\alpha_{\text{max}} - \alpha_{\text{min}}) = \alpha_{\text{min}}$
Key properties of the smooth cosine decay:
- Gentle start: Derivative is near zero at $t = T_w$, creating a smooth transition from warmup
- Steepest descent: Maximum decay rate occurs at the midpoint ($p = 0.5$)
- Gentle landing: Derivative approaches zero as $t \to T_c$, allowing fine-tuning
- No discontinuities: The function and its derivative are continuous everywhere
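These four properties can be read off the derivative of the schedule with respect to the progress variable $p$:
\[\frac{d\alpha_t}{dp} = -\frac{\pi}{2}\left(\alpha_{\text{max}} - \alpha_{\text{min}}\right)\sin(p \cdot \pi)\]
which is zero at $p = 0$ and $p = 1$ (the gentle start and landing) and largest in magnitude at $p = 0.5$ (the steepest descent).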
Phase 3 - Constant ($t > T_c$):
\[\alpha_t = \alpha_{\text{min}}\]
Maintain the minimum learning rate.
Why This Schedule Works
Warmup phase: Starting with a small learning rate prevents the model from making destructive updates when parameters are still randomly initialized. Gradients can be large and unstable early in training, and a small learning rate provides stability.
Cosine decay: The smooth decay helps the model settle into a good minimum. The cosine schedule provides:
- Fast initial decay (when model is still far from optimum)
- Slower decay later (allowing fine-tuning)
- No sharp transitions (unlike step decay schedules)
Constant minimum: Maintaining $\alpha_{\text{min}}$ instead of decaying to zero allows continued (albeit slow) learning, which can be useful for very long training runs.
Implementation
import math

def get_lr_cosine_schedule(
    it: int,
    max_learning_rate: float,
    min_learning_rate: float,
    warmup_iters: int,
    cosine_cycle_iters: int,
) -> float:
    """
    Get the learning rate at iteration `it` using a cosine schedule with warmup.

    Three phases:
    1. Warmup: Linear increase from 0 to max_learning_rate
    2. Cosine annealing: Smooth decay from max to min learning rate
    3. Constant: Maintain min_learning_rate

    Args:
        it: Current iteration (0-indexed)
        max_learning_rate: Maximum learning rate (α_max)
        min_learning_rate: Minimum learning rate (α_min)
        warmup_iters: Number of warmup iterations (T_w)
        cosine_cycle_iters: Total iterations for the cosine cycle (T_c)

    Returns:
        Learning rate for the current iteration
    """
    # Phase 1: Warmup (t < T_w)
    if it < warmup_iters:
        return (it / warmup_iters) * max_learning_rate
    # Phase 2: Cosine annealing (T_w ≤ t ≤ T_c)
    if it <= cosine_cycle_iters:
        progress = (it - warmup_iters) / (cosine_cycle_iters - warmup_iters)
        cosine_decay = 0.5 * (1 + math.cos(progress * math.pi))
        return min_learning_rate + cosine_decay * (max_learning_rate - min_learning_rate)
    # Phase 3: Constant (t > T_c)
    return min_learning_rate
Critical detail: The warmup condition is it < warmup_iters (strict inequality), not it <= warmup_iters. This ensures iteration warmup_iters is the first iteration at max_learning_rate, not the last warmup iteration.
Integration with Training Loop
for iter_num in range(max_iters):
    # Get learning rate for this iteration
    lr = get_lr_cosine_schedule(
        it=iter_num,
        max_learning_rate=1e-3,
        min_learning_rate=1e-4,
        warmup_iters=2000,
        cosine_cycle_iters=100000,
    )
    # Update optimizer learning rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    # Training step
    x, y = get_batch(...)
    loss = model(x, y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
Typical hyperparameters for large models:
- Warmup: 2,000-10,000 iterations (1-5% of total training)
- Max LR: 1e-4 to 1e-3 (depends on model size; larger models use smaller LR)
- Min LR: 10% of max LR
- Cosine cycle: Total training iterations
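Before committing to these values for a long run, it helps to plot the schedule and eyeball the curve. A minimal sketch using the get_lr_cosine_schedule function above (assumes matplotlib is installed; the output file name is illustrative):

import matplotlib.pyplot as plt

# Plot the schedule for a hypothetical 100k-iteration run
iters = list(range(0, 100_000, 100))
lrs = [
    get_lr_cosine_schedule(it, 1e-3, 1e-4, warmup_iters=2000, cosine_cycle_iters=100_000)
    for it in iters
]
plt.plot(iters, lrs)
plt.xlabel("iteration")
plt.ylabel("learning rate")
plt.title("Warmup + cosine decay schedule")
plt.savefig("lr_schedule.png")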
Memory-Efficient Data Loading
When training on large text datasets (hundreds of GBs to TBs), loading the entire dataset into RAM is impossible. The solution is memory-mapped arrays using the Unix mmap system call.
The Problem
Consider training GPT-3-scale models:
- Common Crawl: ~570GB tokenized
- Books: ~150GB tokenized
- Total: ~800GB of tokens
Your machine might have 64-128GB of RAM. Loading this data is impossible.
The Solution: Memory Mapping
Memory mapping lets you "pretend" the entire dataset is in memory, but the OS only loads the portion you actually access.
# Memory-mapped loading
dataset = np.load('train_tokens.npy', mmap_mode='r')
# This doesn't load the file into RAM!
# It creates a memory map to the file on disk
# When you access dataset[1000000:1000512],
# the OS loads just that small portion into RAM
Understanding Virtual Memory
Before diving into how memory mapping works, it's important to understand virtual memory, the foundation that makes memory mapping possible.
Virtual memory is a way for your computer to make it look like you have more memory (RAM) than you actually do. It does this by using part of your disk (storage) to act as an extension of RAM. Every program "thinks" it has access to a large, continuous block of memory, but behind the scenes the operating system (OS) is moving chunks of data between RAM and disk as needed.
How Memory Mapping Works: Step by Step
- Mapping the file to memory: The mmap() system call creates a link between a file on disk and an area of virtual memory. You can then access it like a normal array, even if the file is huge (e.g., 800GB).
- Page fault (on first access): When your code accesses something like dataset[i], the OS sees that the data isn't in RAM yet. It triggers a page fault, a signal that tells the OS to fetch that data page from disk.
- Loading data into RAM: The OS loads the specific page (a small chunk, usually 4KB) from disk into physical RAM. Now dataset[i] can be read directly from fast memory.
- Caching nearby elements: The OS often loads neighboring pages too (since they'll likely be accessed soon). So if you later access dataset[i+1], it's already in RAM and fast to read.
- Eviction when RAM is full: When RAM gets full, the OS automatically evicts less-used pages (writing them back to disk if modified). This keeps the system running smoothly without running out of memory.
Key insight: Memory mapping leverages the OS's virtual memory system to handle datasets much larger than available RAM, loading only the data you need on demand and caching intelligently based on access patterns.
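A minimal sketch of the pattern (file name and sizes are illustrative): write a token array to disk once during preprocessing, then map it read-only at training time and slice it; only the pages backing the slice are read from disk.

import numpy as np

# One-time preprocessing: save tokenized data as a flat array
tokens = np.random.randint(0, 50257, size=1_000_000, dtype=np.int64)  # stand-in for real tokens
np.save("train_tokens.npy", tokens)

# Training time: map the file instead of loading it
dataset = np.load("train_tokens.npy", mmap_mode="r")
print(type(dataset))               # <class 'numpy.memmap'>
window = dataset[500_000:500_512]  # only the pages backing this slice are faulted in
print(window.shape, window.dtype)  # (512,) int64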
Implementation
import numpy as np
import torch

def get_batch(
    dataset: np.ndarray,  # Can be memory-mapped!
    batch_size: int,
    context_length: int,
    device: str = "cpu",
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Sample a batch of sequences from the dataset.

    Supports both regular arrays and memory-mapped arrays transparently.
    Memory-mapped arrays use the Unix mmap system call to map files to virtual
    memory, allowing you to "pretend" you have the entire dataset in memory
    while only loading accessed portions on demand.

    Args:
        dataset: Token array (regular or memory-mapped)
        batch_size: Number of sequences to sample
        context_length: Length of each sequence
        device: Device to place tensors on

    Returns:
        x: Input sequences [batch_size, context_length]
        y: Target sequences [batch_size, context_length] (shifted by 1)
    """
    # Sample random start positions
    max_start = len(dataset) - context_length - 1
    start_indices = np.random.randint(0, max_start, size=batch_size)
    # Extract sequences (this triggers page faults for memory-mapped arrays)
    x = np.stack([dataset[i:i + context_length] for i in start_indices])
    y = np.stack([dataset[i + 1:i + context_length + 1] for i in start_indices])
    # Convert to PyTorch tensors (cast to int64 so any integer dtype on disk works)
    x = torch.from_numpy(x.astype(np.int64)).to(device)
    y = torch.from_numpy(y.astype(np.int64)).to(device)
    return x, y

def load_dataset(data_path: str, vocab_size: int) -> np.ndarray:
    """
    Load a dataset using memory-mapped mode for memory efficiency.

    Args:
        data_path: Path to the .npy file containing tokenized data
        vocab_size: Expected vocabulary size for validation

    Returns:
        Memory-mapped numpy array
    """
    print(f"Loading dataset from {data_path}...")
    # Load with memory mapping for large datasets
    dataset = np.load(data_path, mmap_mode="r")
    print(f"  Loaded {len(dataset):,} tokens")
    print(f"  Data type: {dataset.dtype}")
    print(f"  Memory-mapped: {isinstance(dataset, np.memmap)}")
    # Verify data integrity
    max_token = dataset.max()
    min_token = dataset.min()
    print(f"  Token range: [{min_token}, {max_token}]")
    if max_token >= vocab_size:
        raise ValueError(
            f"Data contains token {max_token} >= vocab_size {vocab_size}. "
            f"Data may be corrupted or vocab_size is incorrect."
        )
    if min_token < 0:
        raise ValueError(f"Data contains negative token {min_token}")
    print(f"  ✓ Data integrity verified")
    return dataset
Important Considerations
1. Data type matching:
# Ensure dtype matches your vocabulary size
dataset = np.memmap('tokens.dat', dtype='int32', mode='r') # For vocab < 2^31
# or
dataset = np.memmap('tokens.dat', dtype='int64', mode='r') # For safety
2. Data integrity: Always verify that token values are within valid range:
assert dataset.max() < vocab_size, "Invalid token values!"
assert dataset.min() >= 0, "Negative token values!"
3. Performance tips:
- Access data sequentially when possible (better cache locality)
- Use larger batch sizes to amortize page fault overhead
- Store data on fast SSD rather than HDD
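Related to the dtype consideration above: with a GPT-2-sized vocabulary (50,257 < 2^16), tokens also fit in uint16, which quarters the on-disk size relative to int64. A small sketch (file name illustrative); note the cast back to int64 before building tensors, since uint16 is not a standard PyTorch dtype:

import numpy as np
import torch

tokens = np.array([123, 50256, 7], dtype=np.uint16)  # fits because vocab_size < 65,536
np.save("tokens_u16.npy", tokens)

data = np.load("tokens_u16.npy", mmap_mode="r")
batch = torch.from_numpy(data[:3].astype(np.int64))  # cast before handing indices to PyTorch
print(batch.dtype)  # torch.int64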
Checkpoint Management
Training large models can take days or weeks. Checkpoint management is crucial for:
- Resuming after crashes or preemption
- Evaluating models at different training stages
- Storing model configurations for reproducibility
What to Save
A complete checkpoint includes:
1. Model state: All parameter values
2. Optimizer state: Momentum buffers, learning rate, etc.
3. Iteration count: For resuming at the exact position
4. Model configuration: For reconstructing the architecture
Many implementations forget #4, making it hard to load models for inference later.
Why Model Configuration Matters
Think of it this way:
- Model configuration = The model's recipe (layer sizes, dropout rates, architecture choices)
- Model state = The model's learned ingredients (weights and biases)
Without the configuration, you wouldn't know how to rebuild the same model structure later.
Example: A Simple Neural Network
Let's say you built this model in PyTorch:
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self, input_size=784, hidden_size=256, output_size=10, dropout=0.2):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_size, output_size)
When you train it, you'll want to save not only the weights, but also the model configuration:
config = {
    "input_size": 784,
    "hidden_size": 256,
    "output_size": 10,
    "dropout": 0.2,
}
checkpoint = {
    "model_state": model.state_dict(),
    "optimizer_state": optimizer.state_dict(),
    "iteration": step,
    "config": config,
}
torch.save(checkpoint, "checkpoint.pth")
Later (for inference or resume training):
You can rebuild the model exactly the same way:
checkpoint = torch.load("checkpoint.pth")
config = checkpoint["config"]
# Rebuild model using saved configuration
model = MyModel(**config)
model.load_state_dict(checkpoint["model_state"])
This same principle applies to Transformer language models, where the configuration includes vocab_size, d_model, num_layers, num_heads, d_ff, context_length, etc.
Implementation
def save_checkpoint(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    iteration: int,
    out: str,
    model_config: dict = None,
) -> None:
    """
    Save the complete training state to a checkpoint file.

    Args:
        model: Model to save
        optimizer: Optimizer to save
        iteration: Current training iteration
        out: Output path for the checkpoint
        model_config: Optional model architecture configuration
    """
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'iteration': iteration,
    }
    # Save the model config for easy loading during inference
    if model_config is not None:
        checkpoint['model_config'] = model_config
    torch.save(checkpoint, out)

def load_checkpoint(
    src: str,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
) -> int:
    """
    Load training state from a checkpoint file.

    Args:
        src: Path to the checkpoint file
        model: Model to load state into
        optimizer: Optimizer to load state into

    Returns:
        Iteration number from the checkpoint
    """
    checkpoint = torch.load(src, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['iteration']
Checkpoint Strategy
During training:
# Save periodically during training
if iter_num % checkpoint_interval == 0 and iter_num > 0:
    checkpoint_path = f"checkpoints/checkpoint_iter_{iter_num}.pt"
    save_checkpoint(model, optimizer, iter_num, checkpoint_path, model_config)

# Save final checkpoint with both iteration number and "final" name
final_checkpoint_iter = f"checkpoints/checkpoint_iter_{max_iters}.pt"
final_checkpoint = "checkpoints/checkpoint_final.pt"
save_checkpoint(model, optimizer, max_iters, final_checkpoint_iter, model_config)
save_checkpoint(model, optimizer, max_iters, final_checkpoint, model_config)
Resuming from checkpoint:
if resume_from is not None:
    start_iter = load_checkpoint(resume_from, model, optimizer)
    print(f"Resumed from iteration {start_iter}")
else:
    start_iter = 0

for iter_num in range(start_iter, max_iters):
    # Training continues from where it left off
    ...
For inference (loading model configuration):
checkpoint = torch.load("checkpoint.pt")
config = checkpoint['model_config']
model = TransformerLM(
    vocab_size=config['vocab_size'],
    d_model=config['d_model'],
    num_layers=config['num_layers'],
    num_heads=config['num_heads'],
    d_ff=config['d_ff'],
    context_length=config['context_length'],
)
model.load_state_dict(checkpoint['model_state_dict'])
Decoding Strategies: From Model to Text
After training, your model can predict the next word given the previous ones. But you need a method to:
- Turn those predictions into probabilities
- Pick the next word/token from that probability distribution
That process is called decoding. The decoding strategy significantly impacts generation quality; it's the difference between coherent text and random gibberish.
Step 1: Softmax - Turning Logits into Probabilities
The model outputs a vector of logits: raw scores for every possible token in the vocabulary. We turn these into probabilities using the softmax formula:
\[P(x_{t+1} = i \mid x_{1..t}) = \frac{e^{v_i}}{\sum_{j} e^{v_j}}\]
Where:
- $v_i$ is the model's score (logit) for token $i$
- The numerator $e^{v_i}$ makes higher scores more likely
- The denominator $\sum_{j} e^{v_j}$ normalizes everything so probabilities sum to 1
This gives us a probability distribution over all words in the vocabulary.
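As a quick numeric illustration (values rounded), three logits turn into a normalized distribution:

import torch

logits = torch.tensor([2.0, 1.0, 0.0])
probs = torch.softmax(logits, dim=-1)
print(probs)        # ~[0.665, 0.245, 0.090]
print(probs.sum())  # 1.0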
Step 2: Decoding - Picking the Next Token
Now that we have probabilities, we need to choose one token to continue the text. We can:
- Pick the highest-probability token (greedy decoding) → safe but repetitive
- Randomly sample from the probabilities → makes text more creative
- Use other tricks to balance randomness and coherence → the strategies below
Let's explore two powerful techniques for controlling this balance.
Temperature Scaling
Problem: Raw softmax outputs can be too peaked (always choosing the most likely token) or too flat (generating random nonsense).
Solution: Temperature scaling modifies the softmax distribution:
\[\text{softmax}(v, \tau)_i = \frac{\exp(v_i/\tau)}{\sum_{j} \exp(v_j/\tau)}\]
Effects:
- $\tau < 1$: makes the distribution sharper (model becomes more confident, deterministic, greedy)
- $\tau = 1$: Standard softmax (the model's original distribution)
- $\tau > 1$: makes the distribution flatter (model becomes more random, creative, diverse)
Implementation:
def softmax_with_temperature(
    logits: torch.Tensor,
    temperature: float = 1.0,
    dim: int = -1,
) -> torch.Tensor:
    """
    Apply softmax with temperature scaling.

    Args:
        logits: Model output logits
        temperature: Temperature parameter τ
        dim: Dimension to apply softmax over

    Returns:
        Temperature-scaled probability distribution
    """
    if temperature <= 0:
        raise ValueError(f"Temperature must be positive, got {temperature}")
    # Scale logits by temperature
    scaled_logits = logits / temperature
    # Apply softmax (numerically stable)
    probs = torch.nn.functional.softmax(scaled_logits, dim=dim)
    return probs
Usage:
logits = model(x, apply_softmax=False)[:, -1, :] # Get next-token logits
# Deterministic (greedy)
probs = softmax_with_temperature(logits, temperature=0.01)
# Balanced
probs = softmax_with_temperature(logits, temperature=0.8)
# Creative
probs = softmax_with_temperature(logits, temperature=1.5)
Concrete Example:
Let's say the model predicts the next word with these raw logits and the corresponding probabilities (rounded to two decimals):
| Token | Raw Logit | $\tau=1.0$ (standard) | $\tau=0.5$ (sharper) | $\tau=2.0$ (flatter) |
|---|---|---|---|---|
| "cat" | 1.8 | 0.60 | 0.83 | 0.43 |
| "dog" | 0.9 | 0.25 | 0.14 | 0.27 |
| "banana" | 0.0 | 0.10 | 0.02 | 0.17 |
| "spaceship" | -0.7 | 0.05 | 0.01 | 0.12 |
Observations:
- With $\tau = 0.5$ (sharper): "cat" becomes dominant (0.83), squeezing out the other options. The model is very confident and predictable.
- With $\tau = 1.0$ (standard): Uses the model's original learned distribution. Balanced between confidence and diversity.
- With $\tau = 2.0$ (flatter): Probabilities become more uniform. "dog" roughly maintains its probability, "banana" nearly doubles (0.10 → 0.17), and even "spaceship" becomes viable (0.05 → 0.12). The model is more creative and exploratory.
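The table can be reproduced from its raw logits with the softmax_with_temperature helper defined above (values rounded to two decimals):

import torch

logits = torch.tensor([[1.8, 0.9, 0.0, -0.7]])  # "cat", "dog", "banana", "spaceship"
for tau in (1.0, 0.5, 2.0):
    probs = softmax_with_temperature(logits, temperature=tau)
    print(tau, [round(p, 2) for p in probs[0].tolist()])
# 1.0 [0.6, 0.25, 0.1, 0.05]
# 0.5 [0.83, 0.14, 0.02, 0.01]
# 2.0 [0.43, 0.27, 0.17, 0.12]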
Top-p (Nucleus) Sampling
Problem: Even with temperature scaling, the model might assign non-zero probability to thousands of tokens, many of which are nonsensical in context.
Solution: Top-p sampling (Holtzman et al., 2020) truncates the distribution to the smallest set of tokens whose cumulative probability exceeds threshold p.
Algorithm:
Define the nucleus $V(p)$ as the smallest set such that:
\[\sum_{i \in V(p)} P(i) \geq p\]
Then the filtered probability distribution is:
\[P_{\text{filtered}}(i) = \begin{cases} \frac{P(i)}{\sum_{j \in V(p)} P(j)} & \text{if } i \in V(p) \\ 0 & \text{otherwise} \end{cases}\]
Implementation:
def top_p_sampling(probs: torch.Tensor, p: float = 0.9) -> torch.Tensor:
    """
    Apply top-p (nucleus) sampling to a probability distribution.

    Args:
        probs: Probability distribution [batch_size, vocab_size]
        p: Cumulative probability threshold (typical: 0.9, 0.95)

    Returns:
        Filtered and renormalized probability distribution
    """
    # Sort probabilities in descending order
    sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
    # Compute cumulative probabilities
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
    # Keep token i if the tokens ranked above it sum to less than p;
    # this yields the smallest set whose cumulative probability reaches p
    mask = (cumulative_probs - sorted_probs) < p
    # Always keep at least the top token
    mask[..., 0] = True
    # Zero out probabilities not in the nucleus
    filtered_sorted_probs = sorted_probs * mask.float()
    # Scatter back to original positions
    filtered_probs = torch.zeros_like(probs)
    filtered_probs.scatter_(dim=-1, index=sorted_indices, src=filtered_sorted_probs)
    # Renormalize
    filtered_probs = filtered_probs / filtered_probs.sum(dim=-1, keepdim=True)
    return filtered_probs
Example:
# Original distribution
probs = torch.tensor([0.5, 0.3, 0.1, 0.05, 0.05])
# p=0.8: Keep top 2 tokens (0.5 + 0.3 = 0.8)
filtered = top_p_sampling(probs, p=0.8)
# Result: [0.625, 0.375, 0, 0, 0]
Applying to our earlier example:
With our "cat", "dog", "banana", "spaceship" example, if we use $p = 0.9$:
- Sort by probability: ["cat" (0.60), "dog" (0.25), "banana" (0.10), "spaceship" (0.05)]
- Cumulative sum: 0.60, 0.85, 0.95, 1.00
- Keep tokens until the cumulative probability is ≥ 0.9: keep {"cat", "dog", "banana"}
- Remove "spaceship" (too low probability)
- Renormalize and sample from the remaining three tokens
Result: The model only samples from {"cat", "dog", "banana"}, avoiding the extremely unlikely "spaceship".
Summary: Putting It All Together
| Step | Purpose | Key Parameter |
|---|---|---|
| Softmax | Turns model logits into probabilities | None |
| Temperature | Controls confidence vs. creativity | $\tau$ (typical: 0.7-1.5) |
| Top-p Sampling | Limits randomness to most probable tokens | $p$ (typical: 0.9-0.95) |
Recommended combinations:
- Factual tasks: $\tau = 0.1$ (nearly greedy)
- Balanced generation: $\tau = 0.8$, $p = 0.9$
- Creative writing: $\tau = 1.2$, $p = 0.95$
Autoregressive Decoding
Putting it together for text generation:
def decode(
    model: nn.Module,
    prompt_tokens: torch.Tensor,
    max_new_tokens: int = 50,
    temperature: float = 1.0,
    top_p: float = None,
    eos_token_id: int = None,
    device: str = "cpu",
) -> torch.Tensor:
    """
    Generate text autoregressively from a prompt.

    Args:
        model: Trained TransformerLM
        prompt_tokens: Initial prompt [batch_size, seq_len]
        max_new_tokens: Maximum number of tokens to generate
        temperature: Sampling temperature
        top_p: Nucleus sampling threshold (None to disable)
        eos_token_id: End-of-sequence token for early stopping
        device: Device to run on

    Returns:
        Generated sequence [batch_size, seq_len + num_generated]
    """
    model.eval()
    if prompt_tokens.dim() == 1:
        prompt_tokens = prompt_tokens.unsqueeze(0)
    generated = prompt_tokens.to(device)
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Get logits for the next token
            logits = model(generated, apply_softmax=False)
            next_token_logits = logits[:, -1, :]
            # Apply temperature scaling
            next_token_probs = softmax_with_temperature(
                next_token_logits,
                temperature=temperature,
            )
            # Apply top-p filtering if requested
            if top_p is not None:
                next_token_probs = top_p_sampling(next_token_probs, p=top_p)
            # Sample the next token
            next_token = torch.multinomial(next_token_probs, num_samples=1)
            # Append to the sequence
            generated = torch.cat([generated, next_token], dim=1)
            # Check for EOS
            if eos_token_id is not None and (next_token == eos_token_id).all():
                break
    return generated
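A typical call, assuming a trained model and a tokenizer object with encode/decode methods (the tokenizer API here is illustrative, not part of the code above):

# Hypothetical usage sketch: `model` and `tokenizer` come from earlier steps
prompt_ids = torch.tensor(tokenizer.encode("Once upon a time"), dtype=torch.long)
output_ids = decode(
    model,
    prompt_ids,
    max_new_tokens=100,
    temperature=0.8,
    top_p=0.9,
    eos_token_id=None,  # or the tokenizer's end-of-text id, if it has one
)
print(tokenizer.decode(output_ids[0].tolist()))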
Putting It All Together: The Training Script
Now we assemble all components into a production training script. The key is making everything configurable via command-line arguments.
Command-Line Interface
import argparse

def parse_args():
    parser = argparse.ArgumentParser(description="Train a Transformer language model")
    # Data
    parser.add_argument("--train_data", type=str, required=True)
    parser.add_argument("--val_data", type=str, required=True)
    parser.add_argument("--vocab_size", type=int, required=True)
    # Model architecture
    parser.add_argument("--d_model", type=int, default=768)
    parser.add_argument("--num_layers", type=int, default=12)
    parser.add_argument("--num_heads", type=int, default=12)
    parser.add_argument("--d_ff", type=int, default=3072)
    parser.add_argument("--context_length", type=int, default=512)
    # Training
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--max_iters", type=int, default=100000)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--device", type=str, default="cpu")  # args.device is used by get_batch and model placement below
    # Optimizer
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--min_lr", type=float, default=1e-4)
    parser.add_argument("--weight_decay", type=float, default=0.1)
    parser.add_argument("--grad_clip", type=float, default=1.0)
    # Learning rate schedule
    parser.add_argument("--warmup_iters", type=int, default=2000)
    parser.add_argument("--lr_decay_iters", type=int, default=100000)
    # Logging and checkpointing
    parser.add_argument("--eval_interval", type=int, default=500)
    parser.add_argument("--log_interval", type=int, default=10)
    parser.add_argument("--checkpoint_interval", type=int, default=5000)
    parser.add_argument("--checkpoint_dir", type=str, default="checkpoints")
    # Resume
    parser.add_argument("--resume_from", type=str, default=None)
    return parser.parse_args()
Understanding Key Training Parameters
Before diving into the training loop, let's clarify two important hyperparameters that control how training progresses:
max_iters (Maximum Iterations)
The total number of training steps (iterations) to run.
One iteration = one forward pass + one backward pass + one optimizer step
for iter_num in range(max_iters):  # e.g., 100,000 steps
    x, y = get_batch(...)
    loss = model(x, y)
    loss.backward()
    optimizer.step()
Example:
- If max_iters = 100,000 and batch_size = 32:
  - The model will train for 100,000 steps
  - Each step processes 32 examples
  - Total examples seen = 100,000 × 32 = 3,200,000 (with repetition if the dataset is smaller)
gradient_accumulation_steps (Gradient Accumulation)
The number of mini-batches to accumulate gradients over before updating weights.
Why use it? To simulate larger batch sizes when GPU memory is limited.
Without gradient accumulation (standard training):
# Effective batch size = 32
x, y = get_batch(batch_size=32)
loss = model(x, y)
loss.backward() # Compute gradients
optimizer.step() # Update weights immediately
With gradient accumulation (e.g., gradient_accumulation_steps = 4):
# Effective batch size = 32 × 4 = 128
total_loss = 0.0
for _ in range(4):  # Accumulate over 4 mini-batches
    x, y = get_batch(batch_size=32)
    loss = model(x, y)
    loss = loss / 4  # Scale loss to average over accumulation
    loss.backward()  # Accumulate gradients (don't update yet!)
    total_loss += loss.item()
optimizer.step()  # Now update with accumulated gradients
optimizer.zero_grad()
Key benefits:
- Simulate larger batches: Want batch_size=128 but only have memory for 32? Use gradient_accumulation_steps=4
- Effective batch size = batch_size × gradient_accumulation_steps
- Smoother gradients: Larger effective batches lead to more stable training
Main Training Loop
def main():
    args = parse_args()
    # Create checkpoint directory
    os.makedirs(args.checkpoint_dir, exist_ok=True)
    # Load datasets with memory mapping
    train_data = load_dataset(args.train_data, args.vocab_size)
    val_data = load_dataset(args.val_data, args.vocab_size)
    # Initialize model
    model = TransformerLM(
        vocab_size=args.vocab_size,
        context_length=args.context_length,
        d_model=args.d_model,
        num_layers=args.num_layers,
        num_heads=args.num_heads,
        d_ff=args.d_ff,
    ).to(args.device)
    # Store model configuration for checkpoints
    model_config = {
        'vocab_size': args.vocab_size,
        'd_model': args.d_model,
        'num_layers': args.num_layers,
        'num_heads': args.num_heads,
        'd_ff': args.d_ff,
        'context_length': args.context_length,
    }
    # Initialize optimizer
    optimizer = AdamW(
        model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.95),
        weight_decay=args.weight_decay,
    )
    # Resume from checkpoint if specified
    start_iter = 0
    if args.resume_from:
        start_iter = load_checkpoint(args.resume_from, model, optimizer)
        print(f"Resumed from iteration {start_iter}")
    # Training loop
    model.train()
    for iter_num in range(start_iter, args.max_iters):
        # Get learning rate for this iteration
        lr = get_lr_cosine_schedule(
            iter_num,
            max_learning_rate=args.lr,
            min_learning_rate=args.min_lr,
            warmup_iters=args.warmup_iters,
            cosine_cycle_iters=args.lr_decay_iters,
        )
        # Update learning rate in optimizer
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        # Training step with gradient accumulation
        total_loss = 0.0
        for _ in range(args.gradient_accumulation_steps):
            x, y = get_batch(train_data, args.batch_size, args.context_length, args.device)
            logits = model(x, apply_softmax=False)
            loss = cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))
            loss = loss / args.gradient_accumulation_steps
            loss.backward()
            total_loss += loss.item()
        # Gradient clipping
        if args.grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        # Optimizer step
        optimizer.step()
        optimizer.zero_grad()
        # Logging
        if iter_num % args.log_interval == 0:
            print(f"[{iter_num}/{args.max_iters}] loss: {total_loss:.4f} | lr: {lr:.2e}")
        # Evaluation
        if iter_num % args.eval_interval == 0:
            val_loss = evaluate(model, val_data, args)
            print(f"[{iter_num}] val_loss: {val_loss:.4f}")
        # Save checkpoint
        if iter_num % args.checkpoint_interval == 0 and iter_num > 0:
            checkpoint_path = os.path.join(args.checkpoint_dir, f"checkpoint_iter_{iter_num}.pt")
            save_checkpoint(model, optimizer, iter_num, checkpoint_path, model_config)
    # Save final checkpoint
    final_checkpoint = os.path.join(args.checkpoint_dir, "checkpoint_final.pt")
    save_checkpoint(model, optimizer, args.max_iters, final_checkpoint, model_config)
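The loop above calls an evaluate helper that is not shown in this note; a minimal sketch, assuming the same get_batch and cross_entropy helpers used in the training step and an arbitrary number of evaluation batches:

def evaluate(model, val_data, args, eval_batches: int = 20) -> float:
    """Estimate validation loss by averaging over a few random batches."""
    model.eval()
    losses = []
    with torch.no_grad():
        for _ in range(eval_batches):
            x, y = get_batch(val_data, args.batch_size, args.context_length, args.device)
            logits = model(x, apply_softmax=False)
            loss = cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))
            losses.append(loss.item())
    model.train()  # restore training mode before continuing the loop
    return sum(losses) / len(losses)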
Usage
# Train from scratch
python -m cs336_basics.train \
--train_data data/train.npy \
--val_data data/val.npy \
--vocab_size 50257 \
--d_model 768 \
--num_layers 12 \
--num_heads 12 \
--d_ff 3072 \
--batch_size 32 \
--max_iters 100000 \
--lr 1e-3 \
--warmup_iters 2000
# Resume from checkpoint
python -m cs336_basics.train \
--train_data data/train.npy \
--val_data data/val.npy \
--vocab_size 50257 \
--resume_from checkpoints/checkpoint_iter_50000.pt
Testing and Validation
Production systems require comprehensive testing. Here's how we can validate our training pipeline to ensure correctness before launching expensive, multi-day training runs.
Unit Tests for Components
Each component should have its own unit tests to verify correctness in isolation.
Test AdamW Optimizer:
def test_adamw():
    """Test that AdamW updates parameters."""
    import torch
    import torch.nn as nn
    # Create simple model
    model = nn.Linear(10, 5)
    optimizer = AdamW(model.parameters(), lr=0.01, weight_decay=0.1)
    # Create dummy data
    x = torch.randn(4, 10)
    y = torch.randn(4, 5)
    # Training step
    weights_before = model.weight.detach().clone()
    loss = ((model(x) - y) ** 2).mean()
    loss.backward()
    optimizer.step()
    # Verify weights were updated
    assert not torch.allclose(model.weight, weights_before), "Weights should change after a step"
    # Compare against PyTorch's implementation for an exact match
Test Learning Rate Schedule:
def test_learning_rate_schedule():
    """Test that the learning rate schedule matches the specification."""
    max_lr = 1.0
    min_lr = 0.1
    warmup_iters = 100
    cosine_cycle_iters = 1000
    # Test warmup phase
    lr_start = get_lr_cosine_schedule(0, max_lr, min_lr, warmup_iters, cosine_cycle_iters)
    assert lr_start == 0.0, "LR should start at 0"
    lr_mid_warmup = get_lr_cosine_schedule(50, max_lr, min_lr, warmup_iters, cosine_cycle_iters)
    assert abs(lr_mid_warmup - 0.5 * max_lr) < 1e-6, "LR should be halfway at warmup midpoint"
    lr_end_warmup = get_lr_cosine_schedule(100, max_lr, min_lr, warmup_iters, cosine_cycle_iters)
    assert abs(lr_end_warmup - max_lr) < 1e-6, "LR should be max at end of warmup"
    # Test cosine phase
    lr_mid_cosine = get_lr_cosine_schedule(550, max_lr, min_lr, warmup_iters, cosine_cycle_iters)
    assert min_lr < lr_mid_cosine < max_lr, "LR should be decaying in cosine phase"
    lr_end_cosine = get_lr_cosine_schedule(1000, max_lr, min_lr, warmup_iters, cosine_cycle_iters)
    assert abs(lr_end_cosine - min_lr) < 1e-6, "LR should be min at end of cosine"
    # Test constant phase
    lr_after = get_lr_cosine_schedule(1500, max_lr, min_lr, warmup_iters, cosine_cycle_iters)
    assert lr_after == min_lr, "LR should remain at min after cosine phase"
Test Top-p Sampling:
def test_top_p_sampling():
    """Test that top-p sampling filters correctly."""
    import torch
    probs = torch.tensor([[0.5, 0.3, 0.1, 0.05, 0.05]])
    filtered = top_p_sampling(probs, p=0.8)
    # Should keep only the top 2 tokens (0.5 + 0.3 = 0.8)
    assert (filtered[0, :2] > 0).all(), "Top 2 tokens should have non-zero probability"
    assert (filtered[0, 2:] == 0).all(), "Remaining tokens should be filtered out"
    # Should be renormalized
    assert torch.allclose(filtered.sum(), torch.tensor(1.0)), "Probabilities should sum to 1"
    # Check that the renormalization is correct
    expected = torch.tensor([[0.625, 0.375, 0.0, 0.0, 0.0]])
    assert torch.allclose(filtered, expected, atol=1e-3), "Renormalization should be correct"
Test Temperature Scaling:
def test_temperature_scaling():
    """Test that temperature scaling affects the distribution correctly."""
    import torch
    logits = torch.tensor([[2.0, 1.0, 0.0]])
    # Standard softmax
    probs_normal = softmax_with_temperature(logits, temperature=1.0)
    # Low temperature (sharper)
    probs_sharp = softmax_with_temperature(logits, temperature=0.5)
    assert probs_sharp[0, 0] > probs_normal[0, 0], "Low temp should increase max probability"
    # High temperature (flatter)
    probs_flat = softmax_with_temperature(logits, temperature=2.0)
    assert probs_flat[0, 0] < probs_normal[0, 0], "High temp should decrease max probability"
Integration Test
Test the entire training pipeline end-to-end with a small synthetic dataset.
def test_training_integration():
    """End-to-end test of the training pipeline."""
    import tempfile
    import subprocess
    import os
    import numpy as np
    import torch
    # Create a small synthetic dataset
    vocab_size = 1000
    train_data = np.random.randint(0, vocab_size, size=10000, dtype=np.int64)
    val_data = np.random.randint(0, vocab_size, size=2000, dtype=np.int64)
    # Save to temporary files
    with tempfile.TemporaryDirectory() as tmpdir:
        train_path = os.path.join(tmpdir, "train.npy")
        val_path = os.path.join(tmpdir, "val.npy")
        checkpoint_dir = os.path.join(tmpdir, "checkpoints")
        np.save(train_path, train_data)
        np.save(val_path, val_data)
        os.makedirs(checkpoint_dir)
        # Run training for 10 iterations
        result = subprocess.run([
            "python", "-m", "cs336_basics.train",
            "--train_data", train_path,
            "--val_data", val_path,
            "--vocab_size", str(vocab_size),
            "--d_model", "128",
            "--num_layers", "2",
            "--num_heads", "4",
            "--d_ff", "512",
            "--max_iters", "10",
            "--checkpoint_interval", "10",
            "--checkpoint_dir", checkpoint_dir,
        ], check=True, capture_output=True, text=True)
        # Verify that the checkpoint was created
        checkpoint_path = os.path.join(checkpoint_dir, "checkpoint_final.pt")
        assert os.path.exists(checkpoint_path), "Final checkpoint should be created"
        # Test checkpoint loading
        checkpoint = torch.load(checkpoint_path)
        assert "model_state_dict" in checkpoint
        assert "optimizer_state_dict" in checkpoint
        assert "iteration" in checkpoint
        assert checkpoint["iteration"] == 10
        print("✓ Training ran successfully and created checkpoint")
        # Test resumption from the checkpoint (same architecture flags so the state dict fits)
        result = subprocess.run([
            "python", "-m", "cs336_basics.train",
            "--train_data", train_path,
            "--val_data", val_path,
            "--vocab_size", str(vocab_size),
            "--d_model", "128",
            "--num_layers", "2",
            "--num_heads", "4",
            "--d_ff", "512",
            "--max_iters", "15",
            "--checkpoint_dir", checkpoint_dir,
            "--resume_from", checkpoint_path,
        ], check=True, capture_output=True, text=True)
        # Verify training continued from iteration 10
        assert "Resumed from iteration 10" in result.stdout
        print("✓ Training resumed successfully from checkpoint")
Pre-Training Validation Checklist
Before launching a long training run, verify:
- Loss decreases on small data: Train for 100 iterations on a tiny dataset and verify loss goes down
- Checkpoints save/load correctly: Save and load a checkpoint, verify iteration count and loss match
- Learning rate schedule looks correct: Plot the LR over iterations and verify the curve matches expectations
- Memory usage is reasonable: Monitor GPU memory and ensure it doesn't exceed available capacity
- Data loading works: Verify data batches have correct shape and token values are in valid range
- Gradient norms are stable: Log gradient norms during warmup, verify they decrease and don't explode (a minimal logging sketch follows this list)
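For the gradient-norm check, here is a minimal logging sketch (call it after loss.backward() and before optimizer.step()). Note that torch.nn.utils.clip_grad_norm_ already returns the pre-clipping total norm, so logging its return value is an equally good option.

def global_grad_norm(model: torch.nn.Module) -> float:
    """Return the global L2 norm of all parameter gradients."""
    total_sq = 0.0
    for p in model.parameters():
        if p.grad is not None:
            total_sq += p.grad.detach().float().norm(2).item() ** 2
    return total_sq ** 0.5

# Inside the training loop:
# grad_norm = global_grad_norm(model)
# print(f"[{iter_num}] grad_norm: {grad_norm:.2f}")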
Quick validation script:
# Quick 100-iteration validation run
python -m cs336_basics.train \
--train_data data/train.npy \
--val_data data/val.npy \
--vocab_size 50257 \
--max_iters 100 \
--log_interval 10 \
--checkpoint_interval 50
# Expected output:
# [0/100] loss: 10.8234 | lr: 0.00e+00 (high initial loss)
# [10/100] loss: 9.2156 | lr: 5.00e-05 (loss decreasing)
# [50/100] loss: 7.8901 | lr: 2.50e-04 (loss continuing to decrease)
# [100/100] loss: 6.5432 | lr: 5.00e-04 (loss still decreasing)
If loss doesn't decrease in 100 iterations, something is wrong; debug before launching a long run!
Key Takeaways
Building a production training pipeline requires attention to many details beyond the core model architecture. Here are the essential lessons:
1. Correctness Over Convenience
Follow paper specifications exactly, especially for:
- Optimizer algorithms (AdamW's decoupled weight decay)
- Learning rate schedules (strict inequalities matter)
- Bias correction formulas
Small deviations can cause subtle training instabilities that only appear after days of training.
2. Memory Efficiency Is Critical
For large-scale training:
- Use memory-mapped arrays for datasets larger than RAM
- Monitor peak memory usage during training
- Consider gradient checkpointing for very large models
3. Checkpoint Everything
A complete checkpoint includes:
- Model parameters
- Optimizer state (momentum buffers!)
- Iteration count
- Model configuration
- Random seeds (for reproducibility)
Don't learn this lesson the hard way after losing a week of training.
4. Make Everything Configurable
Use command-line arguments for all hyperparameters:
- Enables systematic hyperparameter sweeps
- Makes it easy to resume with different settings
- Documents what settings were used
5. Test Before Long Training Runs
- Run integration tests on small synthetic data
- Train for 100 iterations and verify:
- Loss decreases
- Checkpoints save/load correctly
- Learning rate schedule looks correct
- Memory usage is reasonable
A 10-minute test can save days of wasted compute.
6. Generation Quality Depends on Decoding
Even a well-trained model can produce poor text with bad decoding settings:
- Start with temperature=0.8, top_p=0.9
- Always use some form of sampling (greedy decoding produces repetitive text)
7. Monitor Training Actively
Log frequently and watch for:
- Loss spikes (may indicate learning rate too high)
- Loss plateaus (may need more data or capacity)
- Gradient norms (should decrease during warmup)
- Generation samples (qualitative assessment)
8. Production Code Is Different
Research code can get away with:
- Hardcoded hyperparameters
- No checkpointing
- Single-file scripts
Production code needs:
- Configuration management
- Robust error handling
- Comprehensive logging
- Restart/resume capability
This note covered the engineering necessary to turn research ideas into a working system. The components we built (AdamW optimizer, cosine schedule, memory-mapped data loading, checkpointing, and decoding strategies) form the foundation of modern LLM training pipelines. These same patterns appear in systems like GPT-3, LLaMA, and other large language models.
The next step is scaling: distributed training across multiple GPUs, larger datasets, and bigger models. But the fundamentals remain the same: correct implementations of proven algorithms, careful attention to numerical stability, and robust engineering practices.