Study Notes: Stanford CS336 Language Modeling from Scratch [3]
Building a BPE Tokenizer from Scratch: Training the Tokenizer on the TinyStories Dataset
Ever wondered how modern language models like GPT break down text into tokens? In this note, I will share how to build a Byte Pair Encoding (BPE) tokenizer from scratch and train it on the TinyStories Dataset. We will see how BPE achieves impressive compression ratios.
What is BPE Tokenization?
Byte Pair Encoding (BPE) is a compression algorithm that’s become the backbone of modern tokenization. Here’s how it works:
- Start with bytes: Every character becomes its byte representation (0-255)
- Find frequent pairs: Look for the most common pair of adjacent tokens
- Merge and repeat: Replace the most frequent pair with a new token, then repeat
A Simple Example
Let’s say we have the word “hello” appearing many times in our text:
- Initially: h-e-l-l-o (5 tokens)
- If “l-l” is the most frequent pair, merge it: h-e-ll-o (4 tokens)
- If “e-ll” becomes frequent, merge it: h-ell-o (3 tokens)
This process creates a vocabulary that efficiently represents common patterns in your text. Check out my previous post for a brief introduction.
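To make the mechanics concrete, here is a minimal toy sketch of the merge loop in Python (the word list and helper names are my own, not the trainer we build later):
from collections import Counter
def count_pairs(words):
    """Count adjacent token pairs across a list of tokenized words."""
    pairs = Counter()
    for word in words:
        pairs.update(zip(word, word[1:]))
    return pairs
def merge_pair(word, pair):
    """Replace each non-overlapping occurrence of `pair` with one merged token."""
    out, i = [], 0
    while i < len(word):
        if i + 1 < len(word) and (word[i], word[i + 1]) == pair:
            out.append(word[i] + word[i + 1])
            i += 2
        else:
            out.append(word[i])
            i += 1
    return out
words = [list(w) for w in ["hello", "bell", "ball", "all"]]
for step in range(2):
    best = count_pairs(words).most_common(1)[0][0]
    words = [merge_pair(w, best) for w in words]
    print(f"merge {step + 1}: {best} -> {words}")
# Merge 1 picks ('l', 'l'), so 'hello' becomes ['h', 'e', 'll', 'o'];
# merge 2 then merges 'e' + 'll', giving ['h', 'ell', 'o'].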
The TinyStories Dataset
We’ll train our tokenizer on TinyStories, a fascinating dataset of short stories written using only words that 3-4 year olds typically understand. These stories were generated by GPT-3.5 and GPT-4, making them perfect for experimenting with tokenization.
Downloading the Data
First, let’s download the TinyStories dataset (along with the CS336 OpenWebText sample) from Hugging Face:
!mkdir -p data
%cd data
# use %cd (not !cd): a !cd runs in its own subshell and the directory change would not persist
!wget https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStoriesV2-GPT4-train.txt
!wget https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStoriesV2-GPT4-valid.txt
!wget https://huggingface.co/datasets/stanford-cs336/owt-sample/resolve/main/owt_train.txt.gz
!gunzip owt_train.txt.gz
!wget https://huggingface.co/datasets/stanford-cs336/owt-sample/resolve/main/owt_valid.txt.gz
!gunzip owt_valid.txt.gz
%cd ..
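A quick sanity check that the files landed under data/ (file names assumed from the commands above):
import os
for name in ["TinyStoriesV2-GPT4-train.txt", "TinyStoriesV2-GPT4-valid.txt", "owt_train.txt", "owt_valid.txt"]:
    path = os.path.join("data", name)
    print(path, f"{os.path.getsize(path) / 1e6:.1f} MB" if os.path.exists(path) else "missing")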
Challenge: Parallelizing Pre-tokenization
The TinyStories dataset is big (over 2GB), which presents a challenge for tokenizer training. We need to:
- Process the file in parallel for speed
- Ensure we don’t split tokens incorrectly at chunk boundaries
Solution: Smart Chunking with Special Tokens
Our solution uses special tokens (like <|endoftext|>) as natural boundaries for splitting the file.
Simple Example: Let’s say we have a text file containing “Hello<SPLIT>World<SPLIT>How<SPLIT>Are<SPLIT>You”, the special split token is <SPLIT>, and we want to divide the text into 3 chunks.
Here’s an implementation of intelligent file chunking, as shared in the CS336 lecture notes:
import os
from typing import BinaryIO
def find_chunk_boundaries( file: BinaryIO,
desired_num_chunks: int,
split_special_token: bytes)->list[int]:
"""
Chunk the file into parts that can be counted independently.
May return fewer chunks if the boundaries end up overlapping.
"""
assert isinstance(split_special_token, bytes),(
"Must represent special token as a bytestring"
)
# Get total file size in bytes
file.seek(0, os.SEEK_END)
file_size = file.tell()
file.seek(0)
chunk_size = file_size // desired_num_chunks
# Initial guesses for chunk boundary locations, uniformly spaced
# Chunks start on previous index, don't include last index
chunk_boundaries = [i * chunk_size for i in range(desired_num_chunks + 1)]
print(f"Initial guess of the chunk boundaries: {chunk_boundaries}")
chunk_boundaries[-1] = file_size
mini_chunk_size = 4096 # Read ahead by 4k bytes at a time
for bi in range(1, len(chunk_boundaries) - 1):
initial_position = chunk_boundaries[bi]
file.seek(initial_position) # Start at boundary guess
while True:
mini_chunk = file.read(mini_chunk_size) # Read a mini chunk
# If EOF, this boundary should be at the end of the file
if mini_chunk == b"":
chunk_boundaries[bi] = file_size
break
# Find the special token in the mini chunk
found_at = mini_chunk.find(split_special_token)
if found_at != -1:
chunk_boundaries[bi] = initial_position + found_at
break
initial_position += mini_chunk_size
# Make sure all boundaries are unique, but might be fewer than desired_num_chunks
return sorted(set(chunk_boundaries))
Testing Our Chunking Algorithm
Let’s see how this works with a concrete example:
import io
def demonstrate_chunk_boundaries():
"""Demonstrate how to use find_chunk_boundaries with a practical example."""
# Create sample data - our example text
sample_text = "Hello<SPLIT>World<SPLIT>How<SPLIT>Are<SPLIT>You"
sample_bytes = sample_text.encode('utf-8')
print("=== Original Data ===")
print(f"Text: {sample_text}")
print(f"Bytes: {sample_bytes}")
print(f"Total size: {len(sample_bytes)} bytes")
print()
# Create a file-like object from our sample data
file_obj = io.BytesIO(sample_bytes)
# Define our split token
split_token = b"<SPLIT>"
desired_chunks = 3
print("=== Finding Chunk Boundaries ===")
print(f"Desired number of chunks: {desired_chunks}")
print(f"Split token: {split_token}")
print()
# Find the chunk boundaries
boundaries = find_chunk_boundaries(file_obj, desired_chunks, split_token)
print(f"Final boundaries: {boundaries}")
print(f"Number of chunks created: {len(boundaries) - 1}")
print()
# Demonstrate how to use the boundaries to read chunks
print("=== Reading Chunks ===")
file_obj.seek(0) # Reset file pointer
for i in range(len(boundaries) - 1):
start_pos = boundaries[i]
end_pos = boundaries[i + 1]
chunk_size = end_pos - start_pos
# Read the chunk
file_obj.seek(start_pos)
chunk_data = file_obj.read(chunk_size)
chunk_text = chunk_data.decode('utf-8')
print(f"Chunk {i + 1}:")
print(f" Position: bytes {start_pos}-{end_pos-1}")
print(f" Size: {chunk_size} bytes")
print(f" Content: '{chunk_text}'")
print(f" Raw bytes: {chunk_data}")
print()
Running this demonstration:
demonstrate_chunk_boundaries()
Output:
=== Original Data ===
Text: Hello<SPLIT>World<SPLIT>How<SPLIT>Are<SPLIT>You
Bytes: b'Hello<SPLIT>World<SPLIT>How<SPLIT>Are<SPLIT>You'
Total size: 47 bytes
=== Finding Chunk Boundaries ===
Desired number of chunks: 3
Split token: b'<SPLIT>'
Initial guess of the chunk boundaries: [0, 15, 30, 45]
Final boundaries: [0, 17, 37, 47]
Number of chunks created: 3
=== Reading Chunks ===
Chunk 1:
Position: bytes 0-16
Size: 17 bytes
Content: 'Hello<SPLIT>World'
Raw bytes: b'Hello<SPLIT>World'
Chunk 2:
Position: bytes 17-36
Size: 20 bytes
Content: '<SPLIT>How<SPLIT>Are'
Raw bytes: b'<SPLIT>How<SPLIT>Are'
Chunk 3:
Position: bytes 37-46
Size: 10 bytes
Content: '<SPLIT>You'
Raw bytes: b'<SPLIT>You'
Notice how the algorithm automatically adjusted the boundaries to align with <SPLIT>
tokens, ensuring clean chunk separation.
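To try this on the real dataset, we can point the same function at the training file, with <|endoftext|> as the boundary token. A small sketch (the path and the choice of 8 chunks are assumptions, matching the download step above):
# Sketch: chunk the actual TinyStories training file on <|endoftext|> boundaries.
with open("data/TinyStoriesV2-GPT4-train.txt", "rb") as f:
    boundaries = find_chunk_boundaries(f, 8, b"<|endoftext|>")
print(boundaries)  # consecutive pairs of offsets delimit independently processable chunks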
BPE Training Implementation
Now let’s implement the core BPE training algorithm. The implementation shared here handles parallel processing, special tokens, and efficient pair counting.
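Before the full function, here is a minimal sketch of the pair-counting idea on toy data (the counts are made up for illustration): each pre-token is stored as a tuple of byte values together with its corpus frequency, and adjacent-pair counts are accumulated weighted by that frequency.
from collections import defaultdict
# Toy pre-token frequencies: key = tuple of byte values, value = count in the corpus.
word_freq = {
    tuple("low".encode("utf-8")): 5,
    tuple("lower".encode("utf-8")): 2,
}
pair_counts = defaultdict(int)
for word, freq in word_freq.items():
    for left, right in zip(word, word[1:]):
        pair_counts[(left, right)] += freq
print(dict(pair_counts))
# The pairs (108, 111) = (b'l', b'o') and (111, 119) = (b'o', b'w') both reach 7 here;
# the real trainer below breaks such ties explicitly before merging.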
Core Training Function
Here is my complete BPE training implementation:
import re
import os
import multiprocessing as mp
from collections import defaultdict, Counter
from typing import Dict, List, Tuple, BinaryIO
# Simplified GPT-2-style regex pattern for pre-tokenization (using standard re module)
GPT2_SPLIT_PATTERN = r"""'(?:[sdmt]|ll|ve|re)| ?[a-zA-ZÀ-ÿ]+| ?[0-9]+| ?[^\s\w]+|\s+(?!\S)|\s+"""
def process_chunk(args):
"""Process a chunk of the file and return word counts."""
start, end, input_path, special_tokens = args
word_counts = defaultdict(int)
with open(input_path, 'rb') as f:
f.seek(start)
chunk = f.read(end - start).decode('utf-8', errors='ignore')
# Split on special tokens to prevent merging across boundaries
if special_tokens:
pattern = '|'.join(re.escape(token) for token in special_tokens)
text_segments = re.split(f'({pattern})', chunk)
else:
text_segments = [chunk]
for segment in text_segments:
if segment in special_tokens:
continue # Skip special tokens during counting
# Apply GPT-2 regex pattern
for match in re.finditer(GPT2_SPLIT_PATTERN, segment):
token_text = match.group()
token_bytes = tuple(token_text.encode('utf-8'))
word_counts[token_bytes] += 1
return word_counts
def train_bpe_tokenizer(input_path: str, vocab_size: int, special_tokens: list[str], verbose: bool = True) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
"""
Train a byte-level Byte Pair Encoding (BPE) tokenizer from a text file.
Args:
input_path: Path to the input text file containing training data
vocab_size: Maximum size of the final vocabulary (includes initial bytes + special tokens + merges)
special_tokens: List of special token strings to include in vocabulary
verbose: Whether to print training progress information
Returns:
vocab: Complete tokenizer vocabulary mapping token IDs to byte sequences
merges: Ordered list of BPE merge operations performed during training
"""
import time
# Initialize vocabulary with bytes 0-255
vocab = {i: bytes([i]) for i in range(256)}
next_id = 256
# Add special tokens to vocabulary
for token in special_tokens:
token_bytes = token.encode('utf-8')
vocab[next_id] = token_bytes
next_id += 1
if verbose:
print("Step 1: Setting up parallel processing...")
# Get chunk boundaries for multiprocessing
num_processes = mp.cpu_count()
if verbose:
print(f"Using {num_processes} processes for parallel tokenization")
with open(input_path, 'rb') as f:
if special_tokens:
# Use first special token for chunking boundaries
if verbose:
print(f"Finding chunk boundaries aligned with special token: {special_tokens[0]}")
boundaries = find_chunk_boundaries(f, num_processes, special_tokens[0].encode('utf-8'))
else:
# Use simple chunking without special token alignment
f.seek(0, os.SEEK_END)
file_size = f.tell()
chunk_size = file_size // num_processes
boundaries = [i * chunk_size for i in range(num_processes + 1)]
boundaries[-1] = file_size
if verbose:
print(f"File size: {file_size:,} bytes, chunk size: {chunk_size:,} bytes")
if verbose:
print(f"Created {len(boundaries)-1} chunks for processing")
print("\nStep 2: Pre-tokenizing text corpus...")
# Process chunks in parallel
chunk_args = []
for start, end in zip(boundaries[:-1], boundaries[1:]):
chunk_args.append((start, end, input_path, special_tokens))
start_time = time.time()
with mp.Pool(processes=num_processes) as pool:
chunk_results = pool.map(process_chunk, chunk_args)
tokenization_time = time.time() - start_time
if verbose:
print(f"Pre-tokenization completed in {tokenization_time:.2f} seconds")
# Merge results from all chunks
word_counts = defaultdict(int)
total_tokens = 0
for chunk_result in chunk_results:
for word, count in chunk_result.items():
word_counts[word] += count
total_tokens += count
if verbose:
print(f"Found {len(word_counts):,} unique word types")
print(f"Total token count: {total_tokens:,}")
print(f"Most common words:")
sorted_words = sorted(word_counts.items(), key=lambda x: x[1], reverse=True)
for i, (word_bytes, count) in enumerate(sorted_words[:5]):
try:
word_str = bytes(word_bytes).decode('utf-8', errors='replace')
print(f" {i+1}. '{word_str}' -> {count:,} times")
except:
print(f" {i+1}. {word_bytes} -> {count:,} times")
# Convert to working format for BPE (list of byte values)
word_freq = {}
for word_bytes, freq in word_counts.items():
word_tokens = list(word_bytes) # Convert to list of ints
word_freq[tuple(word_tokens)] = freq
merges = []
pair_index = {} # Efficient indexing for pair counting
def update_pair_index(word_freq, pair_index):
"""Update the pair index for efficient counting."""
pair_index.clear()
for word, freq in word_freq.items():
for i in range(len(word) - 1):
pair = (word[i], word[i + 1])
if pair not in pair_index:
pair_index[pair] = []
pair_index[pair].append((word, i, freq))
def count_pairs(pair_index):
"""Count pair frequencies efficiently using the index."""
pair_counts = defaultdict(int)
for pair, occurrences in pair_index.items():
total_count = sum(freq for _, _, freq in occurrences)
pair_counts[pair] = total_count
return pair_counts
# BPE training loop
target_merges = vocab_size - len(vocab)
if verbose:
print(f"\nStep 3: Training BPE with {target_merges:,} merges...")
print(f"Initial vocabulary size: {len(vocab)} (256 bytes + {len(special_tokens)} special tokens)")
print("=" * 60)
bpe_start_time = time.time()
for merge_num in range(target_merges):
merge_step_start = time.time()
# Update pair index
update_pair_index(word_freq, pair_index)
# Count pairs efficiently
pair_counts = count_pairs(pair_index)
if not pair_counts:
if verbose:
print(f"No more pairs to merge at step {merge_num + 1}")
break
# Find most frequent pair (with lexicographic tiebreaking)
best_pair = max(pair_counts.items(), key=lambda x: (x[1], x[0]))[0]
best_count = pair_counts[best_pair]
# Create new token for merge
new_token_id = next_id
next_id += 1
# Get the byte sequences for the tokens being merged
left_bytes = vocab[best_pair[0]]
right_bytes = vocab[best_pair[1]]
# Record merge as byte sequences
merges.append((left_bytes, right_bytes))
# Update vocabulary - merge the two byte sequences
vocab[new_token_id] = left_bytes + right_bytes
# Update word frequencies by applying merge
new_word_freq = {}
for word, freq in word_freq.items():
new_word = []
i = 0
while i < len(word):
if (i < len(word) - 1 and
word[i] == best_pair[0] and
word[i + 1] == best_pair[1]):
new_word.append(new_token_id)
i += 2
else:
new_word.append(word[i])
i += 1
new_word_tuple = tuple(new_word)
if new_word_tuple in new_word_freq:
new_word_freq[new_word_tuple] += freq
else:
new_word_freq[new_word_tuple] = freq
word_freq = new_word_freq
merge_step_time = time.time() - merge_step_start
# Progress logging
if verbose:
if (merge_num + 1) % 100 == 0 or merge_num < 10 or (merge_num + 1) % 1000 == 0:
try:
left_str = left_bytes.decode('utf-8', errors='replace')
right_str = right_bytes.decode('utf-8', errors='replace')
merged_str = (left_bytes + right_bytes).decode('utf-8', errors='replace')
print(f"Merge {merge_num + 1:4d}/{target_merges}: "
f"'{left_str}' + '{right_str}' -> '{merged_str}' "
f"(freq: {best_count:,}, time: {merge_step_time:.3f}s)")
except:
print(f"Merge {merge_num + 1:4d}/{target_merges}: "
f"{left_bytes} + {right_bytes} -> {left_bytes + right_bytes} "
f"(freq: {best_count:,}, time: {merge_step_time:.3f}s)")
bpe_time = time.time() - bpe_start_time
if verbose:
print("=" * 60)
print(f"BPE training completed in {bpe_time:.2f} seconds")
print(f"Final vocabulary size: {len(vocab)}")
print(f"Total merges performed: {len(merges)}")
# Show compression statistics
if word_counts:
original_tokens = sum(len(bytes(word_bytes)) for word_bytes, count in word_counts.items() for _ in range(count))
compressed_tokens = sum(len(word) for word, count in word_freq.items() for _ in range(count))
compression_ratio = original_tokens / compressed_tokens if compressed_tokens > 0 else 1.0
print(f"Compression ratio: {compression_ratio:.2f}x (from {original_tokens:,} to {compressed_tokens:,} tokens)")
return vocab, merges
def save_tokenizer(vocab: dict[int, bytes], merges: list[tuple[bytes, bytes]],
vocab_path: str, merges_path: str):
"""Save vocabulary and merges to disk files."""
import json
import pickle
# Save vocabulary
with open(vocab_path, 'wb') as f:
pickle.dump(vocab, f)
# Save merges
with open(merges_path, 'wb') as f:
pickle.dump(merges, f)
def load_tokenizer(vocab_path: str, merges_path: str) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
"""Load vocabulary and merges from disk files."""
import pickle
with open(vocab_path, 'rb') as f:
vocab = pickle.load(f)
with open(merges_path, 'rb') as f:
merges = pickle.load(f)
return vocab, merges
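To see what the pre-tokenization regex actually produces, here is a quick check on an arbitrary sentence:
sample = "Once upon a time, there were 3 little cats."
print([m.group() for m in re.finditer(GPT2_SPLIT_PATTERN, sample)])
# ['Once', ' upon', ' a', ' time', ',', ' there', ' were', ' 3', ' little', ' cats', '.']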
Training on TinyStories Dataset
Now let’s use our implementation to train a tokenizer on the TinyStories dataset. Here is a driver function that walks through all the steps:
import time
import os
def train_bpe_tokenizer_via_dataset(input_path: str):
print("=" * 80)
print("BPE TOKENIZER TRAINING ON TINYSTORIES DATASET")
print("=" * 80)
# Configuration
vocab_size = 10000
special_tokens = ["<|endoftext|>"]
# Check if input file exists
if not os.path.exists(input_path):
print(f"Error: Input file '{input_path}' not found!")
print("Please ensure the TinyStories dataset is in the data/ directory.")
return
# Display configuration
file_size = os.path.getsize(input_path)
print(f"Configuration:")
print(f" Input file: {input_path}")
print(f" File size: {file_size:,} bytes ({file_size / 1024 / 1024:.1f} MB)")
print(f" Target vocabulary size: {vocab_size:,}")
print(f" Special tokens: {special_tokens}")
print(f" Verbose logging: Enabled")
print()
# Train the tokenizer with verbose output
overall_start_time = time.time()
vocab, merges = train_bpe_tokenizer(
input_path=input_path,
vocab_size=vocab_size,
special_tokens=special_tokens,
verbose=True # Enable detailed logging
)
overall_end_time = time.time()
print("\n" + "=" * 80)
print("TRAINING SUMMARY")
print("=" * 80)
print(f"Total training time: {overall_end_time - overall_start_time:.2f} seconds")
print(f"Final vocabulary size: {len(vocab):,}")
print(f"Number of merges performed: {len(merges):,}")
print(f"Actual vocab size vs target: {len(vocab)} / {vocab_size}")
# Save the tokenizer
vocab_path = "tinystories_vocab.pkl"
merges_path = "tinystories_merges.pkl"
print(f"\nSaving tokenizer to disk...")
save_tokenizer(vocab, merges, vocab_path, merges_path)
print(f" ✓ Vocabulary saved to: {vocab_path}")
print(f" ✓ Merges saved to: {merges_path}")
# Detailed vocabulary analysis
print("\n" + "=" * 80)
print("VOCABULARY ANALYSIS")
print("=" * 80)
# Count different types of tokens
byte_tokens = sum(1 for token_id in vocab.keys() if token_id < 256)
special_token_count = len(special_tokens)
merged_tokens = len(vocab) - byte_tokens - special_token_count
print(f"Token type breakdown:")
print(f" Byte tokens (0-255): {byte_tokens}")
print(f" Special tokens: {special_token_count}")
print(f" Merged tokens: {merged_tokens}")
print(f" Total: {len(vocab)}")
# Show some vocabulary examples
print(f"\nByte tokens (first 10):")
for i in range(10):
if i in vocab:
char = vocab[i].decode('utf-8', errors='replace')
if char.isprintable() and char != ' ':
print(f" Token {i:3d}: {vocab[i]} -> '{char}'")
else:
print(f" Token {i:3d}: {vocab[i]} -> {repr(char)}")
print(f"\nSpecial tokens:")
for token_str in special_tokens:
token_bytes = token_str.encode('utf-8')
for token_id, vocab_bytes in vocab.items():
if vocab_bytes == token_bytes:
print(f" Token {token_id:3d}: {vocab_bytes} -> '{token_str}'")
break
print(f"\nMost recently merged tokens (last 10):")
merged_token_ids = [tid for tid in sorted(vocab.keys()) if tid >= 256 + len(special_tokens)]
for token_id in merged_token_ids[-10:]:
try:
decoded = vocab[token_id].decode('utf-8', errors='replace')
print(f" Token {token_id:4d}: {vocab[token_id]} -> '{decoded}'")
except:
print(f" Token {token_id:4d}: {vocab[token_id]} -> (non-UTF8)")
print(f"\nFirst 10 merge operations:")
for i, (left, right) in enumerate(merges[:10]):
try:
left_str = left.decode('utf-8', errors='replace')
right_str = right.decode('utf-8', errors='replace')
merged_str = (left + right).decode('utf-8', errors='replace')
print(f" Merge {i+1:2d}: '{left_str}' + '{right_str}' -> '{merged_str}'")
except:
print(f" Merge {i+1:2d}: {left} + {right} -> (binary)")
print(f"\nLast 10 merge operations:")
for i, (left, right) in enumerate(merges[-10:], len(merges) - 9):
try:
left_str = left.decode('utf-8', errors='replace')
right_str = right.decode('utf-8', errors='replace')
merged_str = (left + right).decode('utf-8', errors='replace')
print(f" Merge {i:2d}: '{left_str}' + '{right_str}' -> '{merged_str}'")
except:
print(f" Merge {i:2d}: {left} + {right} -> (binary)")
# Show file sizes
vocab_size_bytes = os.path.getsize(vocab_path)
merges_size_bytes = os.path.getsize(merges_path)
print(f"\nOutput file sizes:")
print(f" Vocabulary file: {vocab_size_bytes:,} bytes ({vocab_size_bytes / 1024:.1f} KB)")
print(f" Merges file: {merges_size_bytes:,} bytes ({merges_size_bytes / 1024:.1f} KB)")
print(f" Total: {vocab_size_bytes + merges_size_bytes:,} bytes ({(vocab_size_bytes + merges_size_bytes) / 1024:.1f} KB)")
print("\n" + "=" * 80)
print("TRAINING COMPLETED SUCCESSFULLY!")
print("=" * 80)
print("You can now use the trained tokenizer for encoding/decoding text.")
print(f"Load with: vocab, merges = load_tokenizer('{vocab_path}', '{merges_path}')")
To run the training, point the function at the TinyStories training file:
train_bpe_tokenizer_via_dataset(input_path="data/TinyStoriesV2-GPT4-train.txt")
It will print the following information during training (abridged):
================================================================================
BPE TOKENIZER TRAINING ON TINYSTORIES DATASET
================================================================================
Configuration:
Input file: /content/TinyStoriesV2-GPT4-train.txt
File size: 2,227,753,162 bytes (2124.6 MB)
Target vocabulary size: 10,000
Special tokens: ['<|endoftext|>']
Verbose logging: Enabled
Step 1: Setting up parallel processing...
Using 12 processes for parallel tokenization
Finding chunk boundaries aligned with special token: <|endoftext|>
Initial guess of the chunk boundaries: [0, 185646096, 371292192, 556938288, 742584384, 928230480, 1113876576, 1299522672, 1485168768, 1670814864, 1856460960, 2042107056, 2227753152]
Created 12 chunks for processing
Step 2: Pre-tokenizing text corpus...
Pre-tokenization completed in 66.52 seconds
Found 59,904 unique word types
Total token count: 536,592,162
Most common words:
1. '.' -> 41,764,519 times
2. ',' -> 23,284,331 times
3. ' the' -> 20,828,576 times
4. ' and' -> 19,475,966 times
5. ' a' -> 15,063,529 times
Step 3: Training BPE with 9,743 merges...
Initial vocabulary size: 257 (256 bytes + 1 special tokens)
============================================================
Merge 1/9743: ' ' + 't' -> ' t' (freq: 63,482,199, time: 0.273s)
Merge 2/9743: 'h' + 'e' -> 'he' (freq: 63,341,860, time: 0.318s)
Merge 3/9743: ' ' + 'a' -> ' a' (freq: 47,465,635, time: 0.340s)
Merge 4/9743: ' ' + 's' -> ' s' (freq: 32,362,158, time: 0.340s)
Merge 5/9743: ' ' + 'w' -> ' w' (freq: 31,485,643, time: 0.327s)
Merge 6/9743: 'n' + 'd' -> 'nd' (freq: 28,922,386, time: 0.332s)
Merge 7/9743: ' t' + 'he' -> ' the' (freq: 28,915,024, time: 0.320s)
Merge 8/9743: 'e' + 'd' -> 'ed' (freq: 24,836,456, time: 0.317s)
Merge 9/9743: ' ' + 'b' -> ' b' (freq: 22,147,488, time: 0.326s)
Merge 10/9743: ' t' + 'o' -> ' to' (freq: 20,892,273, time: 0.322s)
Merge 100/9743: ' ha' + 'pp' -> ' happ' (freq: 3,147,884, time: 0.251s)
Merge 200/9743: ' s' + 'e' -> ' se' (freq: 1,410,130, time: 0.343s)
Merge 300/9743: ' s' + 'omet' -> ' somet' (freq: 790,510, time: 0.245s)
Merge 400/9743: ' g' + 'ot' -> ' got' (freq: 524,776, time: 0.338s)
Merge 500/9743: ' e' + 'ach' -> ' each' (freq: 369,637, time: 0.321s)
Merge 600/9743: 'l' + 'f' -> 'lf' (freq: 279,566, time: 0.230s)
Merge 700/9743: ' wal' + 'k' -> ' walk' (freq: 221,114, time: 0.237s)
Merge 800/9743: ' do' + 'll' -> ' doll' (freq: 177,602, time: 0.324s)
Merge 900/9743: ' ' + 'G' -> ' G' (freq: 147,699, time: 0.214s)
Merge 1000/9743: 'ec' + 't' -> 'ect' (freq: 127,288, time: 0.233s)
Merge 1100/9743: ' l' + 'ight' -> ' light' (freq: 108,006, time: 0.208s)
Merge 1200/9743: ' d' + 'in' -> ' din' (freq: 92,211, time: 0.225s)
Merge 1300/9743: ' picture' + 's' -> ' pictures' (freq: 80,416, time: 0.318s)
Merge 1400/9743: 'itt' + 'en' -> 'itten' (freq: 68,466, time: 0.235s)
Merge 1500/9743: 'A' + 'my' -> 'Amy' (freq: 59,829, time: 0.306s)
Merge 1600/9743: ' tal' + 'king' -> ' talking' (freq: 53,781, time: 0.330s)
Merge 1700/9743: 'b' + 'all' -> 'ball' (freq: 48,005, time: 0.309s)
Merge 1800/9743: ' k' + 'iss' -> ' kiss' (freq: 43,477, time: 0.318s)
...
Merge 8000/9743: ' mom' + 'mies' -> ' mommies' (freq: 879, time: 0.205s)
Merge 8100/9743: ' cryst' + 'als' -> ' crystals' (freq: 840, time: 0.299s)
Merge 8200/9743: ' playd' + 'ate' -> ' playdate' (freq: 809, time: 0.283s)
Merge 8300/9743: ' support' + 'ing' -> ' supporting' (freq: 778, time: 0.200s)
Merge 8400/9743: ' activ' + 'ity' -> ' activity' (freq: 747, time: 0.300s)
Merge 8500/9743: 'L' + 'izzy' -> 'Lizzy' (freq: 716, time: 0.284s)
Merge 8600/9743: 'er' + 'ing' -> 'ering' (freq: 691, time: 0.311s)
Merge 8700/9743: ' tid' + 'ied' -> ' tidied' (freq: 660, time: 0.308s)
Merge 8800/9743: 'f' + 'lowers' -> 'flowers' (freq: 633, time: 0.295s)
Merge 8900/9743: ' Gra' + 'nd' -> ' Grand' (freq: 609, time: 0.299s)
Merge 9000/9743: ' frustr' + 'ation' -> ' frustration' (freq: 584, time: 0.301s)
Merge 9100/9743: 'amil' + 'iar' -> 'amiliar' (freq: 561, time: 0.205s)
Merge 9200/9743: ' P' + 'retty' -> ' Pretty' (freq: 542, time: 0.310s)
Merge 9300/9743: ' sal' + 'on' -> ' salon' (freq: 521, time: 0.292s)
Merge 9400/9743: ' p' + 'ounced' -> ' pounced' (freq: 502, time: 0.196s)
Merge 9500/9743: ' pops' + 'ic' -> ' popsic' (freq: 485, time: 0.185s)
Merge 9600/9743: ' pain' + 'ful' -> ' painful' (freq: 469, time: 0.298s)
Merge 9700/9743: 'solut' + 'ely' -> 'solutely' (freq: 454, time: 0.308s)
============================================================
BPE training completed in 2731.72 seconds
Final vocabulary size: 10000
Total merges performed: 9743
Compression ratio: 4.07x (from 2,192,422,648 to 538,511,097 tokens)
================================================================================
TRAINING SUMMARY
================================================================================
Total training time: 2898.45 seconds
Final vocabulary size: 10,000
Number of merges performed: 9,743
Actual vocab size vs target: 10000 / 10000
Saving tokenizer to disk...
✓ Vocabulary saved to: tinystories_vocab.pkl
✓ Merges saved to: tinystories_merges.pkl
================================================================================
VOCABULARY ANALYSIS
================================================================================
Token type breakdown:
Byte tokens (0-255): 256
Special tokens: 1
Merged tokens: 9743
Total: 10000
Byte tokens (first 10):
Token 0: b'\x00' -> '\x00'
Token 1: b'\x01' -> '\x01'
Token 2: b'\x02' -> '\x02'
Token 3: b'\x03' -> '\x03'
Token 4: b'\x04' -> '\x04'
Token 5: b'\x05' -> '\x05'
Token 6: b'\x06' -> '\x06'
Token 7: b'\x07' -> '\x07'
Token 8: b'\x08' -> '\x08'
Token 9: b'\t' -> '\t'
Special tokens:
Token 256: b'<|endoftext|>' -> '<|endoftext|>'
Most recently merged tokens (last 10):
Token 9990: b' improving' -> ' improving'
Token 9991: b' nicest' -> ' nicest'
Token 9992: b' whiskers' -> ' whiskers'
Token 9993: b' booth' -> ' booth'
Token 9994: b' Land' -> ' Land'
Token 9995: b'Rocky' -> 'Rocky'
Token 9996: b' meadows' -> ' meadows'
Token 9997: b' Starry' -> ' Starry'
Token 9998: b' imaginary' -> ' imaginary'
Token 9999: b' bold' -> ' bold'
First 10 merge operations:
Merge 1: ' ' + 't' -> ' t'
Merge 2: 'h' + 'e' -> 'he'
Merge 3: ' ' + 'a' -> ' a'
Merge 4: ' ' + 's' -> ' s'
Merge 5: ' ' + 'w' -> ' w'
Merge 6: 'n' + 'd' -> 'nd'
Merge 7: ' t' + 'he' -> ' the'
Merge 8: 'e' + 'd' -> 'ed'
Merge 9: ' ' + 'b' -> ' b'
Merge 10: ' t' + 'o' -> ' to'
Last 10 merge operations:
Merge 9734: ' impro' + 'ving' -> ' improving'
Merge 9735: ' nice' + 'st' -> ' nicest'
Merge 9736: ' wh' + 'iskers' -> ' whiskers'
Merge 9737: ' bo' + 'oth' -> ' booth'
Merge 9738: ' L' + 'and' -> ' Land'
Merge 9739: 'Rock' + 'y' -> 'Rocky'
Merge 9740: ' meadow' + 's' -> ' meadows'
Merge 9741: ' St' + 'arry' -> ' Starry'
Merge 9742: ' imag' + 'inary' -> ' imaginary'
Merge 9743: ' bo' + 'ld' -> ' bold'
Output file sizes:
Vocabulary file: 117,701 bytes (114.9 KB)
Merges file: 109,714 bytes (107.1 KB)
Total: 227,415 bytes (222.1 KB)
================================================================================
TRAINING COMPLETED SUCCESSFULLY!
================================================================================
You can now use the trained tokenizer for encoding/decoding text.
Load with: vocab, merges = load_tokenizer('tinystories_vocab.pkl', 'tinystories_merges.pkl')
Using the Trained Tokenizer
Once we have a trained tokenizer, we need a class to encode and decode text. Here’s one complete implementation:
class SimpleBPETokenizer:
"""Simple BPE tokenizer for encoding/decoding text."""
def __init__(self, vocab, merges, special_tokens=None):
self.vocab = vocab # {token_id: bytes}
self.merges = merges # [(left_bytes, right_bytes), ...]
self.special_tokens = special_tokens or ["<|endoftext|>"]
# Create reverse mapping for decoding
self.id_to_bytes = vocab
self.bytes_to_id = {v: k for k, v in vocab.items()}
# GPT-2 style regex pattern
self.pattern = r"""'(?:[sdmt]|ll|ve|re)| ?[a-zA-ZÀ-ÿ]+| ?[0-9]+| ?[^\s\w]+|\s+(?!\S)|\s+"""
# Build merge rules for faster encoding
self.merge_rules = {}
for i, (left_bytes, right_bytes) in enumerate(merges):
# Find what tokens these bytes correspond to
left_id = self.bytes_to_id.get(left_bytes)
right_id = self.bytes_to_id.get(right_bytes)
merged_bytes = left_bytes + right_bytes
merged_id = self.bytes_to_id.get(merged_bytes)
if left_id is not None and right_id is not None and merged_id is not None:
self.merge_rules[(left_id, right_id)] = merged_id
def encode(self, text: str) -> list[int]:
"""Encode text to token IDs."""
if not text:
return []
# Handle special tokens
token_ids = []
remaining_text = text
# Split on special tokens first
for special_token in self.special_tokens:
if special_token in remaining_text:
parts = remaining_text.split(special_token)
new_parts = []
for i, part in enumerate(parts):
if i > 0:
# Add special token
special_bytes = special_token.encode('utf-8')
special_id = self.bytes_to_id.get(special_bytes)
if special_id is not None:
token_ids.append(special_id)
if part:
new_parts.append(part)
remaining_text = ''.join(new_parts)
# Apply regex tokenization
for match in re.finditer(self.pattern, remaining_text):
word = match.group()
word_tokens = self._encode_word(word)
token_ids.extend(word_tokens)
return token_ids
def _encode_word(self, word: str) -> list[int]:
"""Encode a single word using BPE merges."""
# Start with individual bytes
word_bytes = word.encode('utf-8')
tokens = []
# Convert each byte to its token ID
for byte_val in word_bytes:
tokens.append(byte_val) # Byte token IDs are 0-255
# Apply BPE merges
while len(tokens) > 1:
# Find the best merge to apply
best_merge = None
best_pos = -1
best_merge_priority = float('inf')
for i in range(len(tokens) - 1):
pair = (tokens[i], tokens[i + 1])
if pair in self.merge_rules:
# Find merge priority (earlier merges have higher priority)
merged_bytes = self.id_to_bytes[tokens[i]] + self.id_to_bytes[tokens[i + 1]]
for j, (left_bytes, right_bytes) in enumerate(self.merges):
if left_bytes + right_bytes == merged_bytes:
if j < best_merge_priority:
best_merge = pair
best_pos = i
best_merge_priority = j
break
if best_merge is None:
break
# Apply the best merge
new_tokens = tokens[:best_pos]
new_tokens.append(self.merge_rules[best_merge])
new_tokens.extend(tokens[best_pos + 2:])
tokens = new_tokens
return tokens
def decode(self, token_ids: list[int]) -> str:
"""Decode token IDs back to text."""
if not token_ids:
return ""
# Convert token IDs to bytes
byte_sequences = []
for token_id in token_ids:
if token_id in self.id_to_bytes:
byte_sequences.append(self.id_to_bytes[token_id])
else:
# Handle unknown tokens
byte_sequences.append(b'<UNK>')
# Concatenate all bytes and decode
all_bytes = b''.join(byte_sequences)
try:
return all_bytes.decode('utf-8', errors='replace')
except:
return all_bytes.decode('utf-8', errors='ignore')
def tokenize_with_details(self, text: str):
"""Tokenize text and show detailed breakdown."""
token_ids = self.encode(text)
print(f"Original text: '{text}'")
print(f"Length: {len(text)} characters")
print(f"UTF-8 bytes: {len(text.encode('utf-8'))} bytes")
print(f"Token count: {len(token_ids)} tokens")
print(f"Compression ratio: {len(text.encode('utf-8')) / len(token_ids):.2f}x")
print()
print("Token breakdown:")
for i, token_id in enumerate(token_ids):
token_bytes = self.id_to_bytes[token_id]
try:
token_str = token_bytes.decode('utf-8', errors='replace')
if token_str.isprintable():
print(f" {i+1:2d}. Token {token_id:4d}: '{token_str}' ({len(token_bytes)} bytes)")
else:
print(f" {i+1:2d}. Token {token_id:4d}: {repr(token_str)} ({len(token_bytes)} bytes)")
except:
print(f" {i+1:2d}. Token {token_id:4d}: {token_bytes} (binary)")
# Verify round-trip
decoded = self.decode(token_ids)
print(f"\nDecoded text: '{decoded}'")
print(f"Round-trip successful: {text == decoded}")
return token_ids
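Before running the fuller test suite, here is a minimal usage sketch (it assumes the .pkl files produced by the training run above are present):
vocab, merges = load_tokenizer('tinystories_vocab.pkl', 'tinystories_merges.pkl')
tokenizer = SimpleBPETokenizer(vocab, merges)
ids = tokenizer.encode("Once upon a time, there was a little cat.")
print(ids)
print(tokenizer.decode(ids))  # should round-trip back to the original sentence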
Let’s write some simple test cases:
def test_bpe_tokenizer():
print("=" * 60)
print("BPE TOKENIZER SAMPLE TESTS")
print("=" * 60)
# Load the trained tokenizer
try:
vocab, merges = load_tokenizer('tinystories_vocab.pkl', 'tinystories_merges.pkl')
print(f"✓ Loaded tokenizer with {len(vocab)} vocab entries and {len(merges)} merges")
except FileNotFoundError:
print("Error: Tokenizer files not found!")
print("Please run the training script first to create 'tinystories_vocab.pkl' and 'tinystories_merges.pkl'")
return
# Create tokenizer instance
tokenizer = SimpleBPETokenizer(vocab, merges)
print()
# Example 1: Simple sentence
print("EXAMPLE 1: Simple sentence")
print("-" * 30)
text1 = "Hello world! How are you today?"
tokenizer.tokenize_with_details(text1)
print()
# Example 2: Text with special token
print("EXAMPLE 2: Text with special token")
print("-" * 30)
text2 = "Once upon a time<|endoftext|>The end."
tokenizer.tokenize_with_details(text2)
print()
# Example 3: Repeated words (should compress well)
print("EXAMPLE 3: Repeated words")
print("-" * 30)
text3 = "the the the cat cat sat sat on on the the mat mat"
tokenizer.tokenize_with_details(text3)
print()
# Example 4: Numbers and punctuation
print("EXAMPLE 4: Numbers and punctuation")
print("-" * 30)
text4 = "I have 123 apples, 456 oranges, and 789 bananas!"
tokenizer.tokenize_with_details(text4)
print()
# Example 5: Just encoding/decoding
print("EXAMPLE 5: Simple encode/decode")
print("-" * 30)
text5 = "This is a test."
token_ids = tokenizer.encode(text5)
decoded_text = tokenizer.decode(token_ids)
print(f"Original: '{text5}'")
print(f"Token IDs: {token_ids}")
print(f"Decoded: '{decoded_text}'")
print(f"Match: {text5 == decoded_text}")
print()
# Show some vocabulary statistics
print("VOCABULARY STATISTICS")
print("-" * 30)
byte_tokens = sum(1 for tid in vocab.keys() if tid < 256)
special_tokens = sum(1 for tid, token_bytes in vocab.items() if b'<|' in token_bytes)
merged_tokens = len(vocab) - byte_tokens - special_tokens
print(f"Byte tokens (0-255): {byte_tokens}")
print(f"Special tokens: {special_tokens}")
print(f"Merged tokens: {merged_tokens}")
print(f"Total vocabulary: {len(vocab)}")
# Show some example merged tokens
print(f"\nSample merged tokens:")
merged_token_ids = [tid for tid in sorted(vocab.keys()) if tid >= 257]
for i, token_id in enumerate(merged_token_ids[:10]):
token_bytes = vocab[token_id]
try:
decoded = token_bytes.decode('utf-8', errors='replace')
print(f" Token {token_id}: '{decoded}' ({len(token_bytes)} bytes)")
except:
print(f" Token {token_id}: {token_bytes} (binary)")
print("\n" + "=" * 60)
print("All examples completed successfully!")
BPE Tokenizer Sample Tests
Now run our complete test suite:
test_bpe_tokenizer()
Running these tests against the tokenizer trained on TinyStories above gives the following results:
✓ Loaded tokenizer with 10000 vocab entries and 9743 merges
Example 1: Simple sentence
Original text: ‘Hello world! How are you today?’
Length: 31 characters
UTF-8 bytes: 31 bytes
Token count: 8 tokens
Compression ratio: 3.88x
Token breakdown:
- Token 1183: ‘Hello’ (5 bytes)
- Token 1569: ‘ world’ (6 bytes)
- Token 33: ‘!’ (1 bytes)
- Token 2683: ‘ How’ (4 bytes)
- Token 483: ‘ are’ (4 bytes)
- Token 349: ‘ you’ (4 bytes)
- Token 1709: ‘ today’ (6 bytes)
- Token 63: ‘?’ (1 bytes)
Decoded text: ‘Hello world! How are you today?’
Round-trip successful: True
Example 2: Text with special token
Original text: ‘Once upon a time<|endoftext|>The end.’
Length: 37 characters
UTF-8 bytes: 37 bytes
Token count: 8 tokens
Compression ratio: 4.62x
Token breakdown:
- Token 256: ‘<|endoftext|>’ (13 bytes)
- Token 430: ‘Once’ (4 bytes)
- Token 439: ‘ upon’ (5 bytes)
- Token 259: ‘ a’ (2 bytes)
- Token 398: ‘ time’ (5 bytes)
- Token 410: ‘The’ (3 bytes)
- Token 870: ‘ end’ (4 bytes)
- Token 46: ‘.’ (1 bytes)
Decoded text: ‘<|endoftext|>Once upon a timeThe end.’
Round-trip successful: False
(The round-trip fails here because this simple encode() emits special-token IDs before the rest of the text instead of keeping them in their original position, which is why <|endoftext|> ends up at the front of the decoded string.)
Example 3: Repeated words
Original text: ‘the the the cat cat sat sat on on the the mat mat’
Length: 49 characters
UTF-8 bytes: 49 bytes
Token count: 13 tokens
Compression ratio: 3.77x
Token breakdown:
- Token 7199: ‘the’ (3 bytes)
- Token 263: ‘ the’ (4 bytes)
- Token 263: ‘ the’ (4 bytes)
- Token 459: ‘ cat’ (4 bytes)
- Token 459: ‘ cat’ (4 bytes)
- Token 1091: ‘ sat’ (4 bytes)
- Token 1091: ‘ sat’ (4 bytes)
- Token 354: ‘ on’ (3 bytes)
- Token 354: ‘ on’ (3 bytes)
- Token 263: ‘ the’ (4 bytes)
- Token 263: ‘ the’ (4 bytes)
- Token 1492: ‘ mat’ (4 bytes)
- Token 1492: ‘ mat’ (4 bytes)
Decoded text: ‘the the the cat cat sat sat on on the the mat mat’
Round-trip successful: True
Example 4: Numbers and punctuation
Original text: ‘I have 123 apples, 456 oranges, and 789 bananas!’
Length: 48 characters
UTF-8 bytes: 48 bytes
Token count: 19 tokens
Compression ratio: 2.53x
Token breakdown:
- Token 73: ‘I’ (1 bytes)
- Token 499: ‘ have’ (5 bytes)
- Token 6314: ‘ 1’ (2 bytes)
- Token 50: ‘2’ (1 bytes)
- Token 51: ‘3’ (1 bytes)
- Token 1836: ‘ apples’ (7 bytes)
- Token 44: ‘,’ (1 bytes)
- Token 9079: ‘ 4’ (2 bytes)
- Token 53: ‘5’ (1 bytes)
- Token 54: ‘6’ (1 bytes)
- Token 5193: ‘ oranges’ (8 bytes)
- Token 44: ‘,’ (1 bytes)
- Token 267: ‘ and’ (4 bytes)
- Token 32: ‘ ’ (1 bytes)
- Token 55: ‘7’ (1 bytes)
- Token 56: ‘8’ (1 bytes)
- Token 57: ‘9’ (1 bytes)
- Token 3898: ‘ bananas’ (8 bytes)
- Token 33: ‘!’ (1 bytes)
Decoded text: ‘I have 123 apples, 456 oranges, and 789 bananas!’
Round-trip successful: True
Example 5: Simple encode/decode
Original: ‘This is a test.’
Token IDs: [1531, 431, 259, 2569, 46]
Decoded: ‘This is a test.’
Match: True
Vocabulary Statistics
Byte tokens (0-255): 256
Special tokens: 1
Merged tokens: 9743
Total vocabulary: 10000
Sample merged tokens:
- Token 257: ‘ t’ (2 bytes)
- Token 258: ‘he’ (2 bytes)
- Token 259: ‘ a’ (2 bytes)
- Token 260: ‘ s’ (2 bytes)
- Token 261: ‘ w’ (2 bytes)
- Token 262: ‘nd’ (2 bytes)
- Token 263: ‘ the’ (4 bytes)
- Token 264: ‘ed’ (2 bytes)
- Token 265: ‘ b’ (2 bytes)
- Token 266: ‘ to’ (3 bytes)
All examples completed successfully!
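As a final sanity check, we can estimate compression on held-out text. A rough sketch (it assumes the validation file downloaded earlier and the SimpleBPETokenizer defined above; this simple implementation encodes slowly, so only a small slice is used):
vocab, merges = load_tokenizer('tinystories_vocab.pkl', 'tinystories_merges.pkl')
tokenizer = SimpleBPETokenizer(vocab, merges)
with open("data/TinyStoriesV2-GPT4-valid.txt", "r", encoding="utf-8") as f:
    sample = f.read(100_000)  # a small slice is enough for a rough estimate
ids = tokenizer.encode(sample)
print(f"Approximate compression: {len(sample.encode('utf-8')) / len(ids):.2f} bytes per token")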