Building a BPE Tokenizer from Scratch: Training the Tokenizer on the TinyStories Dataset

Ever wondered how modern language models like GPT break down text into tokens? In this note, I will share how to build a Byte Pair Encoding (BPE) tokenizer from scratch and train it on the TinyStories dataset. Along the way, we will see how BPE achieves impressive compression ratios.

What is BPE Tokenization?

Byte Pair Encoding (BPE) is a compression algorithm that’s become the backbone of modern tokenization. Here’s how it works:

  1. Start with bytes: Every character becomes its byte representation (0-255)
  2. Find frequent pairs: Look for the most common pair of adjacent tokens
  3. Merge and repeat: Replace the most frequent pair with a new token, then repeat

A Simple Example

Let’s say we have the word “hello” appearing many times in our text:

  • Initially: h-e-l-l-o (5 tokens)
  • If “l-l” is the most frequent pair, merge it: h-e-ll-o (4 tokens)
  • If “e-ll” becomes frequent, merge it: h-ell-o (3 tokens)

This process creates a vocabulary that efficiently represents common patterns in your text. Check out my previous post for a brief introduction.
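
To make the merge step concrete, here is a minimal sketch of the pair-counting and merging described above. It works on characters instead of raw bytes to keep it readable, and the helper names count_pairs and merge_pair are just illustrative:

from collections import Counter

def count_pairs(tokens):
    """Count how often each adjacent pair of tokens occurs."""
    return Counter(zip(tokens, tokens[1:]))

def merge_pair(tokens, pair, new_token):
    """Replace every non-overlapping occurrence of `pair` with `new_token`."""
    merged, i = [], 0
    while i < len(tokens):
        if i < len(tokens) - 1 and (tokens[i], tokens[i + 1]) == pair:
            merged.append(new_token)
            i += 2
        else:
            merged.append(tokens[i])
            i += 1
    return merged

tokens = list("hello")                      # ['h', 'e', 'l', 'l', 'o']
print(count_pairs(tokens))                  # each adjacent pair occurs once in this tiny example
tokens = merge_pair(tokens, ("l", "l"), "ll")
print(tokens)                               # ['h', 'e', 'll', 'o']

A full BPE trainer simply repeats these two steps over the whole corpus until the vocabulary reaches the desired size, which is exactly what we will build below.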

The TinyStories Dataset

We’ll train our tokenizer on TinyStories, a fascinating dataset of short stories written using only words that 3-4 year olds typically understand. These stories were generated by GPT-3.5 and GPT-4, making them perfect for experimenting with tokenization.

Downloading the Data

First, let’s download the TinyStories dataset from Hugging Face (the last four commands also fetch the OpenWebText sample used in the CS336 course, which is not needed for this note):

!mkdir -p data
# Note: use the %cd magic rather than !cd; each ! command runs in its own
# subshell, so !cd alone would not change the notebook's working directory.
%cd data

!wget https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStoriesV2-GPT4-train.txt
!wget https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStoriesV2-GPT4-valid.txt

!wget https://huggingface.co/datasets/stanford-cs336/owt-sample/resolve/main/owt_train.txt.gz
!gunzip owt_train.txt.gz
!wget https://huggingface.co/datasets/stanford-cs336/owt-sample/resolve/main/owt_valid.txt.gz
!gunzip owt_valid.txt.gz

%cd ..
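
After downloading, it helps to peek at the raw training file to see how the stories are laid out; stories are separated by an <|endoftext|> marker, which is exactly what we will use for chunking later. A quick check, assuming the file ended up in data/:

with open("data/TinyStoriesV2-GPT4-train.txt", "r", encoding="utf-8") as f:
    print(f.read(500))  # print the first 500 characters of the training file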

Challenge: Parallelizing Pre-tokenization

The TinyStories dataset is big (over 2GB), which presents a challenge for tokenizer training. We need to:

  1. Process the file in parallel for speed
  2. Ensure we don’t split tokens incorrectly at chunk boundaries

Solution: Smart Chunking with Special Tokens

Our solution uses special tokens (like <|endoftext|>) as natural boundaries for splitting the file.

Simple Example: Let’s say we have a text file containing “Hello<SPLIT>World<SPLIT>How<SPLIT>Are<SPLIT>You”, the special split token is <SPLIT>, and we want to divide the text into 3 chunks.

Here’s one implementation of intelligent file chunking, as shared in the CS336 lecture notes:

import os
from typing import BinaryIO

def find_chunk_boundaries(
    file: BinaryIO,
    desired_num_chunks: int,
    split_special_token: bytes,
) -> list[int]:
    """
    Chunk the file into parts that can be counted independently.
    May return fewer chunks if the boundaries end up overlapping.
    """
    assert isinstance(split_special_token, bytes), (
        "Must represent special token as a bytestring"
    )

    # Get total file size in bytes
    file.seek(0, os.SEEK_END)
    file_size = file.tell()
    file.seek(0)

    chunk_size = file_size // desired_num_chunks

    # Initial guesses for chunk boundary locations, uniformly spaced
    # Chunks start on previous index, don't include last index
    chunk_boundaries = [i * chunk_size for i in range(desired_num_chunks + 1)]
    print(f"Initial guess of the chunk boundaries: {chunk_boundaries}")
    chunk_boundaries[-1] = file_size

    mini_chunk_size = 4096  # Read ahead by 4k bytes at a time

    for bi in range(1, len(chunk_boundaries) - 1):
        initial_position = chunk_boundaries[bi]
        file.seek(initial_position)  # Start at boundary guess
        while True:
            mini_chunk = file.read(mini_chunk_size)  # Read a mini chunk

            # If EOF, this boundary should be at the end of the file
            if mini_chunk == b"":
                chunk_boundaries[bi] = file_size
                break

            # Find the special token in the mini chunk
            found_at = mini_chunk.find(split_special_token)
            if found_at != -1:
                chunk_boundaries[bi] = initial_position + found_at
                break
            initial_position += mini_chunk_size

    # Make sure all boundaries are unique, but might be fewer than desired_num_chunks
    return sorted(set(chunk_boundaries))

Testing Our Chunking Algorithm

Let’s see how this works with a concrete example:

import io
def demonstrate_chunk_boundaries():
    """Demonstrate how to use find_chunk_boundaries with a practical example."""

    # Create sample data - our example text
    sample_text = "Hello<SPLIT>World<SPLIT>How<SPLIT>Are<SPLIT>You"
    sample_bytes = sample_text.encode('utf-8')

    print("=== Original Data ===")
    print(f"Text: {sample_text}")
    print(f"Bytes: {sample_bytes}")
    print(f"Total size: {len(sample_bytes)} bytes")
    print()

    # Create a file-like object from our sample data
    file_obj = io.BytesIO(sample_bytes)

    # Define our split token
    split_token = b"<SPLIT>"
    desired_chunks = 3

    print("=== Finding Chunk Boundaries ===")
    print(f"Desired number of chunks: {desired_chunks}")
    print(f"Split token: {split_token}")
    print()

    # Find the chunk boundaries
    boundaries = find_chunk_boundaries(file_obj, desired_chunks, split_token)

    print(f"Final boundaries: {boundaries}")
    print(f"Number of chunks created: {len(boundaries) - 1}")
    print()

    # Demonstrate how to use the boundaries to read chunks
    print("=== Reading Chunks ===")
    file_obj.seek(0)  # Reset file pointer

    for i in range(len(boundaries) - 1):
        start_pos = boundaries[i]
        end_pos = boundaries[i + 1]
        chunk_size = end_pos - start_pos

        # Read the chunk
        file_obj.seek(start_pos)
        chunk_data = file_obj.read(chunk_size)
        chunk_text = chunk_data.decode('utf-8')

        print(f"Chunk {i + 1}:")
        print(f"  Position: bytes {start_pos}-{end_pos-1}")
        print(f"  Size: {chunk_size} bytes")
        print(f"  Content: '{chunk_text}'")
        print(f"  Raw bytes: {chunk_data}")
        print()

Running this demonstration:

demonstrate_chunk_boundaries()

Output:

=== Original Data ===
Text: Hello<SPLIT>World<SPLIT>How<SPLIT>Are<SPLIT>You
Bytes: b'Hello<SPLIT>World<SPLIT>How<SPLIT>Are<SPLIT>You'
Total size: 47 bytes

=== Finding Chunk Boundaries ===
Desired number of chunks: 3
Split token: b'<SPLIT>'

Initial guess of the chunk boundaries: [0, 15, 30, 45]
Final boundaries: [0, 17, 37, 47]
Number of chunks created: 3

=== Reading Chunks ===
Chunk 1:
  Position: bytes 0-16
  Size: 17 bytes
  Content: 'Hello<SPLIT>World'
  Raw bytes: b'Hello<SPLIT>World'

Chunk 2:
  Position: bytes 17-36
  Size: 20 bytes
  Content: '<SPLIT>How<SPLIT>Are'
  Raw bytes: b'<SPLIT>How<SPLIT>Are'

Chunk 3:
  Position: bytes 37-46
  Size: 10 bytes
  Content: '<SPLIT>You'
  Raw bytes: b'<SPLIT>You'

Notice how the algorithm automatically adjusted the boundaries to align with <SPLIT> tokens, ensuring clean chunk separation.
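
Once the boundaries are known, each (start, end) pair can be handed to its own worker process, so no two workers ever read the same bytes. Here is a minimal sketch of that fan-out, using a hypothetical count_bytes_in_chunk worker just to show the pattern (the real pre-tokenization worker is defined in the next section):

import multiprocessing as mp

def count_bytes_in_chunk(args):
    """Hypothetical worker: read one [start, end) byte range and return its length."""
    path, start, end = args
    with open(path, "rb") as f:
        f.seek(start)
        return len(f.read(end - start))

def fan_out(path, boundaries):
    """Map each (start, end) chunk to a worker process and combine the results."""
    jobs = [(path, start, end) for start, end in zip(boundaries[:-1], boundaries[1:])]
    with mp.Pool() as pool:
        return sum(pool.map(count_bytes_in_chunk, jobs))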

BPE Training Implementation

Now let’s implement the core BPE training algorithm. The implementation shared here handles parallel pre-tokenization, special tokens, and pair counting with a simple index.

Core Training Function

Here is my complete BPE training implementation:

import re
import os
import multiprocessing as mp
from collections import defaultdict, Counter
from typing import Dict, List, Tuple, BinaryIO

# Simplified GPT-2-style regex pattern for pre-tokenization (using standard re module)
GPT2_SPLIT_PATTERN = r"""'(?:[sdmt]|ll|ve|re)| ?[a-zA-ZÀ-ÿ]+| ?[0-9]+| ?[^\s\w]+|\s+(?!\S)|\s+"""
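
# Illustrative example of how the pattern pre-tokenizes a sentence, keeping
# each pre-token's leading space attached to it:
#   [m.group() for m in re.finditer(GPT2_SPLIT_PATTERN, "Tom's cat ran 3 times!")]
#   -> ["Tom", "'s", " cat", " ran", " 3", " times", "!"]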

def process_chunk(args):
    """Process a chunk of the file and return word counts."""
    start, end, input_path, special_tokens = args
    word_counts = defaultdict(int)

    with open(input_path, 'rb') as f:
        f.seek(start)
        chunk = f.read(end - start).decode('utf-8', errors='ignore')

        # Split on special tokens to prevent merging across boundaries
        if special_tokens:
            pattern = '|'.join(re.escape(token) for token in special_tokens)
            text_segments = re.split(f'({pattern})', chunk)
        else:
            text_segments = [chunk]

        for segment in text_segments:
            if segment in special_tokens:
                continue  # Skip special tokens during counting

            # Apply GPT-2 regex pattern
            for match in re.finditer(GPT2_SPLIT_PATTERN, segment):
                token_text = match.group()
                token_bytes = tuple(token_text.encode('utf-8'))
                word_counts[token_bytes] += 1

    return word_counts


def train_bpe_tokenizer(input_path: str, vocab_size: int, special_tokens: list[str], verbose: bool = True) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
    """
    Train a byte-level Byte Pair Encoding (BPE) tokenizer from a text file.

    Args:
        input_path: Path to the input text file containing training data
        vocab_size: Maximum size of the final vocabulary (includes initial bytes + special tokens + merges)
        special_tokens: List of special token strings to include in vocabulary
        verbose: Whether to print training progress information

    Returns:
        vocab: Complete tokenizer vocabulary mapping token IDs to byte sequences
        merges: Ordered list of BPE merge operations performed during training
    """
    import time

    # Initialize vocabulary with bytes 0-255
    vocab = {i: bytes([i]) for i in range(256)}
    next_id = 256

    # Add special tokens to vocabulary
    for token in special_tokens:
        token_bytes = token.encode('utf-8')
        vocab[next_id] = token_bytes
        next_id += 1

    if verbose:
        print("Step 1: Setting up parallel processing...")

    # Get chunk boundaries for multiprocessing
    num_processes = mp.cpu_count()
    if verbose:
        print(f"Using {num_processes} processes for parallel tokenization")

    with open(input_path, 'rb') as f:
        if special_tokens:
            # Use first special token for chunking boundaries
            if verbose:
                print(f"Finding chunk boundaries aligned with special token: {special_tokens[0]}")
            boundaries = find_chunk_boundaries(f, num_processes, special_tokens[0].encode('utf-8'))
        else:
            # Use simple chunking without special token alignment
            f.seek(0, os.SEEK_END)
            file_size = f.tell()
            chunk_size = file_size // num_processes
            boundaries = [i * chunk_size for i in range(num_processes + 1)]
            boundaries[-1] = file_size
            if verbose:
                print(f"File size: {file_size:,} bytes, chunk size: {chunk_size:,} bytes")

    if verbose:
        print(f"Created {len(boundaries)-1} chunks for processing")
        print("\nStep 2: Pre-tokenizing text corpus...")

    # Process chunks in parallel
    chunk_args = []
    for start, end in zip(boundaries[:-1], boundaries[1:]):
        chunk_args.append((start, end, input_path, special_tokens))

    start_time = time.time()
    with mp.Pool(processes=num_processes) as pool:
        chunk_results = pool.map(process_chunk, chunk_args)
    tokenization_time = time.time() - start_time

    if verbose:
        print(f"Pre-tokenization completed in {tokenization_time:.2f} seconds")

    # Merge results from all chunks
    word_counts = defaultdict(int)
    total_tokens = 0
    for chunk_result in chunk_results:
        for word, count in chunk_result.items():
            word_counts[word] += count
            total_tokens += count

    if verbose:
        print(f"Found {len(word_counts):,} unique word types")
        print(f"Total token count: {total_tokens:,}")
        print(f"Most common words:")
        sorted_words = sorted(word_counts.items(), key=lambda x: x[1], reverse=True)
        for i, (word_bytes, count) in enumerate(sorted_words[:5]):
            try:
                word_str = bytes(word_bytes).decode('utf-8', errors='replace')
                print(f"  {i+1}. '{word_str}' -> {count:,} times")
            except:
                print(f"  {i+1}. {word_bytes} -> {count:,} times")

    # Convert to working format for BPE (list of byte values)
    word_freq = {}
    for word_bytes, freq in word_counts.items():
        word_tokens = list(word_bytes)  # Convert to list of ints
        word_freq[tuple(word_tokens)] = freq

    merges = []
    pair_index = {}  # Efficient indexing for pair counting

    def update_pair_index(word_freq, pair_index):
        """Update the pair index for efficient counting."""
        pair_index.clear()
        for word, freq in word_freq.items():
            for i in range(len(word) - 1):
                pair = (word[i], word[i + 1])
                if pair not in pair_index:
                    pair_index[pair] = []
                pair_index[pair].append((word, i, freq))

    def count_pairs(pair_index):
        """Count pair frequencies efficiently using the index."""
        pair_counts = defaultdict(int)
        for pair, occurrences in pair_index.items():
            total_count = sum(freq for _, _, freq in occurrences)
            pair_counts[pair] = total_count
        return pair_counts

    # BPE training loop
    target_merges = vocab_size - len(vocab)

    if verbose:
        print(f"\nStep 3: Training BPE with {target_merges:,} merges...")
        print(f"Initial vocabulary size: {len(vocab)} (256 bytes + {len(special_tokens)} special tokens)")
        print("=" * 60)

    bpe_start_time = time.time()

    for merge_num in range(target_merges):
        merge_step_start = time.time()

        # Update pair index
        update_pair_index(word_freq, pair_index)

        # Count pairs efficiently
        pair_counts = count_pairs(pair_index)

        if not pair_counts:
            if verbose:
                print(f"No more pairs to merge at step {merge_num + 1}")
            break

        # Find most frequent pair (ties broken in favor of the numerically larger token-ID pair)
        best_pair = max(pair_counts.items(), key=lambda x: (x[1], x[0]))[0]
        best_count = pair_counts[best_pair]

        # Create new token for merge
        new_token_id = next_id
        next_id += 1

        # Get the byte sequences for the tokens being merged
        left_bytes = vocab[best_pair[0]]
        right_bytes = vocab[best_pair[1]]

        # Record merge as byte sequences
        merges.append((left_bytes, right_bytes))

        # Update vocabulary - merge the two byte sequences
        vocab[new_token_id] = left_bytes + right_bytes

        # Update word frequencies by applying merge
        new_word_freq = {}

        for word, freq in word_freq.items():
            new_word = []
            i = 0
            while i < len(word):
                if (i < len(word) - 1 and
                    word[i] == best_pair[0] and
                    word[i + 1] == best_pair[1]):
                    new_word.append(new_token_id)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1

            new_word_tuple = tuple(new_word)
            if new_word_tuple in new_word_freq:
                new_word_freq[new_word_tuple] += freq
            else:
                new_word_freq[new_word_tuple] = freq

        word_freq = new_word_freq
        merge_step_time = time.time() - merge_step_start

        # Progress logging
        if verbose:
            if (merge_num + 1) % 100 == 0 or merge_num < 10 or (merge_num + 1) % 1000 == 0:
                try:
                    left_str = left_bytes.decode('utf-8', errors='replace')
                    right_str = right_bytes.decode('utf-8', errors='replace')
                    merged_str = (left_bytes + right_bytes).decode('utf-8', errors='replace')
                    print(f"Merge {merge_num + 1:4d}/{target_merges}: "
                          f"'{left_str}' + '{right_str}' -> '{merged_str}' "
                          f"(freq: {best_count:,}, time: {merge_step_time:.3f}s)")
                except:
                    print(f"Merge {merge_num + 1:4d}/{target_merges}: "
                          f"{left_bytes} + {right_bytes} -> {left_bytes + right_bytes} "
                          f"(freq: {best_count:,}, time: {merge_step_time:.3f}s)")

    bpe_time = time.time() - bpe_start_time

    if verbose:
        print("=" * 60)
        print(f"BPE training completed in {bpe_time:.2f} seconds")
        print(f"Final vocabulary size: {len(vocab)}")
        print(f"Total merges performed: {len(merges)}")

        # Show compression statistics
        if word_counts:
            original_tokens = sum(len(bytes(word_bytes)) * count for word_bytes, count in word_counts.items())
            compressed_tokens = sum(len(word) * count for word, count in word_freq.items())
            compression_ratio = original_tokens / compressed_tokens if compressed_tokens > 0 else 1.0
            print(f"Compression ratio: {compression_ratio:.2f}x (from {original_tokens:,} to {compressed_tokens:,} tokens)")

    return vocab, merges


def save_tokenizer(vocab: dict[int, bytes], merges: list[tuple[bytes, bytes]],
                  vocab_path: str, merges_path: str):
    """Save vocabulary and merges to disk files."""
    import pickle

    # Save vocabulary
    with open(vocab_path, 'wb') as f:
        pickle.dump(vocab, f)

    # Save merges
    with open(merges_path, 'wb') as f:
        pickle.dump(merges, f)


def load_tokenizer(vocab_path: str, merges_path: str) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
    """Load vocabulary and merges from disk files."""
    import pickle

    with open(vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    with open(merges_path, 'rb') as f:
        merges = pickle.load(f)

    return vocab, merges
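
These helpers are used by the training driver in the next section; a quick round trip, assuming vocab and merges come from train_bpe_tokenizer and with purely illustrative file paths, looks like this:

save_tokenizer(vocab, merges, "vocab.pkl", "merges.pkl")
vocab_loaded, merges_loaded = load_tokenizer("vocab.pkl", "merges.pkl")
assert vocab_loaded == vocab and merges_loaded == merges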

Training on TinyStories Dataset

Now let’s use our implementation to train a tokenizer on the TinyStories dataset. Here is a training function that demonstrates all the steps:

import time
import os

def train_bpe_tokenizer_via_dataset(input_path: str):
    print("=" * 80)
    print("BPE TOKENIZER TRAINING ON TINYSTORIES DATASET")
    print("=" * 80)

    # Configuration
    vocab_size = 10000
    special_tokens = ["<|endoftext|>"]

    # Check if input file exists
    if not os.path.exists(input_path):
        print(f"Error: Input file '{input_path}' not found!")
        print("Please ensure the TinyStories dataset is in the data/ directory.")
        return

    # Display configuration
    file_size = os.path.getsize(input_path)
    print(f"Configuration:")
    print(f"  Input file: {input_path}")
    print(f"  File size: {file_size:,} bytes ({file_size / 1024 / 1024:.1f} MB)")
    print(f"  Target vocabulary size: {vocab_size:,}")
    print(f"  Special tokens: {special_tokens}")
    print(f"  Verbose logging: Enabled")
    print()

    # Train the tokenizer with verbose output
    overall_start_time = time.time()
    vocab, merges = train_bpe_tokenizer(
        input_path=input_path,
        vocab_size=vocab_size,
        special_tokens=special_tokens,
        verbose=True  # Enable detailed logging
    )
    overall_end_time = time.time()

    print("\n" + "=" * 80)
    print("TRAINING SUMMARY")
    print("=" * 80)
    print(f"Total training time: {overall_end_time - overall_start_time:.2f} seconds")
    print(f"Final vocabulary size: {len(vocab):,}")
    print(f"Number of merges performed: {len(merges):,}")
    print(f"Actual vocab size vs target: {len(vocab)} / {vocab_size}")

    # Save the tokenizer
    vocab_path = "tinystories_vocab.pkl"
    merges_path = "tinystories_merges.pkl"

    print(f"\nSaving tokenizer to disk...")
    save_tokenizer(vocab, merges, vocab_path, merges_path)
    print(f"  ✓ Vocabulary saved to: {vocab_path}")
    print(f"  ✓ Merges saved to: {merges_path}")

    # Detailed vocabulary analysis
    print("\n" + "=" * 80)
    print("VOCABULARY ANALYSIS")
    print("=" * 80)

    # Count different types of tokens
    byte_tokens = sum(1 for token_id in vocab.keys() if token_id < 256)
    special_token_count = len(special_tokens)
    merged_tokens = len(vocab) - byte_tokens - special_token_count

    print(f"Token type breakdown:")
    print(f"  Byte tokens (0-255): {byte_tokens}")
    print(f"  Special tokens: {special_token_count}")
    print(f"  Merged tokens: {merged_tokens}")
    print(f"  Total: {len(vocab)}")

    # Show some vocabulary examples
    print(f"\nByte tokens (first 10):")
    for i in range(10):
        if i in vocab:
            char = vocab[i].decode('utf-8', errors='replace')
            if char.isprintable() and char != ' ':
                print(f"  Token {i:3d}: {vocab[i]} -> '{char}'")
            else:
                print(f"  Token {i:3d}: {vocab[i]} -> {repr(char)}")

    print(f"\nSpecial tokens:")
    for token_str in special_tokens:
        token_bytes = token_str.encode('utf-8')
        for token_id, vocab_bytes in vocab.items():
            if vocab_bytes == token_bytes:
                print(f"  Token {token_id:3d}: {vocab_bytes} -> '{token_str}'")
                break

    print(f"\nMost recently merged tokens (last 10):")
    merged_token_ids = [tid for tid in sorted(vocab.keys()) if tid >= 256 + len(special_tokens)]
    for token_id in merged_token_ids[-10:]:
        try:
            decoded = vocab[token_id].decode('utf-8', errors='replace')
            print(f"  Token {token_id:4d}: {vocab[token_id]} -> '{decoded}'")
        except:
            print(f"  Token {token_id:4d}: {vocab[token_id]} -> (non-UTF8)")

    print(f"\nFirst 10 merge operations:")
    for i, (left, right) in enumerate(merges[:10]):
        try:
            left_str = left.decode('utf-8', errors='replace')
            right_str = right.decode('utf-8', errors='replace')
            merged_str = (left + right).decode('utf-8', errors='replace')
            print(f"  Merge {i+1:2d}: '{left_str}' + '{right_str}' -> '{merged_str}'")
        except:
            print(f"  Merge {i+1:2d}: {left} + {right} -> (binary)")

    print(f"\nLast 10 merge operations:")
    for i, (left, right) in enumerate(merges[-10:], len(merges) - 9):
        try:
            left_str = left.decode('utf-8', errors='replace')
            right_str = right.decode('utf-8', errors='replace')
            merged_str = (left + right).decode('utf-8', errors='replace')
            print(f"  Merge {i:2d}: '{left_str}' + '{right_str}' -> '{merged_str}'")
        except:
            print(f"  Merge {i:2d}: {left} + {right} -> (binary)")

    # Show file sizes
    vocab_size_bytes = os.path.getsize(vocab_path)
    merges_size_bytes = os.path.getsize(merges_path)
    print(f"\nOutput file sizes:")
    print(f"  Vocabulary file: {vocab_size_bytes:,} bytes ({vocab_size_bytes / 1024:.1f} KB)")
    print(f"  Merges file: {merges_size_bytes:,} bytes ({merges_size_bytes / 1024:.1f} KB)")
    print(f"  Total: {vocab_size_bytes + merges_size_bytes:,} bytes ({(vocab_size_bytes + merges_size_bytes) / 1024:.1f} KB)")

    print("\n" + "=" * 80)
    print("TRAINING COMPLETED SUCCESSFULLY!")
    print("=" * 80)
    print("You can now use the trained tokenizer for encoding/decoding text.")
    print(f"Load with: vocab, merges = load_tokenizer('{vocab_path}', '{merges_path}')")

To run the training, call the function with the path to the training file:

train_bpe_tokenizer_via_dataset(input_path="data/TinyStoriesV2-GPT4-train.txt")

It produces output like the following during the training process:

================================================================================
BPE TOKENIZER TRAINING ON TINYSTORIES DATASET
================================================================================
Configuration:
  Input file: /content/TinyStoriesV2-GPT4-train.txt
  File size: 2,227,753,162 bytes (2124.6 MB)
  Target vocabulary size: 10,000
  Special tokens: ['<|endoftext|>']
  Verbose logging: Enabled

Step 1: Setting up parallel processing...
Using 12 processes for parallel tokenization
Finding chunk boundaries aligned with special token: <|endoftext|>
Initial guess of the chunk boundaries: [0, 185646096, 371292192, 556938288, 742584384, 928230480, 1113876576, 1299522672, 1485168768, 1670814864, 1856460960, 2042107056, 2227753152]
Created 12 chunks for processing

Step 2: Pre-tokenizing text corpus...
Pre-tokenization completed in 66.52 seconds
Found 59,904 unique word types
Total token count: 536,592,162
Most common words:
  1. '.' -> 41,764,519 times
  2. ',' -> 23,284,331 times
  3. ' the' -> 20,828,576 times
  4. ' and' -> 19,475,966 times
  5. ' a' -> 15,063,529 times

Step 3: Training BPE with 9,743 merges...
Initial vocabulary size: 257 (256 bytes + 1 special tokens)
============================================================
Merge    1/9743: ' ' + 't' -> ' t' (freq: 63,482,199, time: 0.273s)
Merge    2/9743: 'h' + 'e' -> 'he' (freq: 63,341,860, time: 0.318s)
Merge    3/9743: ' ' + 'a' -> ' a' (freq: 47,465,635, time: 0.340s)
Merge    4/9743: ' ' + 's' -> ' s' (freq: 32,362,158, time: 0.340s)
Merge    5/9743: ' ' + 'w' -> ' w' (freq: 31,485,643, time: 0.327s)
Merge    6/9743: 'n' + 'd' -> 'nd' (freq: 28,922,386, time: 0.332s)
Merge    7/9743: ' t' + 'he' -> ' the' (freq: 28,915,024, time: 0.320s)
Merge    8/9743: 'e' + 'd' -> 'ed' (freq: 24,836,456, time: 0.317s)
Merge    9/9743: ' ' + 'b' -> ' b' (freq: 22,147,488, time: 0.326s)
Merge   10/9743: ' t' + 'o' -> ' to' (freq: 20,892,273, time: 0.322s)
Merge  100/9743: ' ha' + 'pp' -> ' happ' (freq: 3,147,884, time: 0.251s)
Merge  200/9743: ' s' + 'e' -> ' se' (freq: 1,410,130, time: 0.343s)
Merge  300/9743: ' s' + 'omet' -> ' somet' (freq: 790,510, time: 0.245s)
Merge  400/9743: ' g' + 'ot' -> ' got' (freq: 524,776, time: 0.338s)
Merge  500/9743: ' e' + 'ach' -> ' each' (freq: 369,637, time: 0.321s)
Merge  600/9743: 'l' + 'f' -> 'lf' (freq: 279,566, time: 0.230s)
Merge  700/9743: ' wal' + 'k' -> ' walk' (freq: 221,114, time: 0.237s)
Merge  800/9743: ' do' + 'll' -> ' doll' (freq: 177,602, time: 0.324s)
Merge  900/9743: ' ' + 'G' -> ' G' (freq: 147,699, time: 0.214s)
Merge 1000/9743: 'ec' + 't' -> 'ect' (freq: 127,288, time: 0.233s)
Merge 1100/9743: ' l' + 'ight' -> ' light' (freq: 108,006, time: 0.208s)
Merge 1200/9743: ' d' + 'in' -> ' din' (freq: 92,211, time: 0.225s)
Merge 1300/9743: ' picture' + 's' -> ' pictures' (freq: 80,416, time: 0.318s)
Merge 1400/9743: 'itt' + 'en' -> 'itten' (freq: 68,466, time: 0.235s)
Merge 1500/9743: 'A' + 'my' -> 'Amy' (freq: 59,829, time: 0.306s)
Merge 1600/9743: ' tal' + 'king' -> ' talking' (freq: 53,781, time: 0.330s)
Merge 1700/9743: 'b' + 'all' -> 'ball' (freq: 48,005, time: 0.309s)
Merge 1800/9743: ' k' + 'iss' -> ' kiss' (freq: 43,477, time: 0.318s)
...
Merge 8000/9743: ' mom' + 'mies' -> ' mommies' (freq: 879, time: 0.205s)
Merge 8100/9743: ' cryst' + 'als' -> ' crystals' (freq: 840, time: 0.299s)
Merge 8200/9743: ' playd' + 'ate' -> ' playdate' (freq: 809, time: 0.283s)
Merge 8300/9743: ' support' + 'ing' -> ' supporting' (freq: 778, time: 0.200s)
Merge 8400/9743: ' activ' + 'ity' -> ' activity' (freq: 747, time: 0.300s)
Merge 8500/9743: 'L' + 'izzy' -> 'Lizzy' (freq: 716, time: 0.284s)
Merge 8600/9743: 'er' + 'ing' -> 'ering' (freq: 691, time: 0.311s)
Merge 8700/9743: ' tid' + 'ied' -> ' tidied' (freq: 660, time: 0.308s)
Merge 8800/9743: 'f' + 'lowers' -> 'flowers' (freq: 633, time: 0.295s)
Merge 8900/9743: ' Gra' + 'nd' -> ' Grand' (freq: 609, time: 0.299s)
Merge 9000/9743: ' frustr' + 'ation' -> ' frustration' (freq: 584, time: 0.301s)
Merge 9100/9743: 'amil' + 'iar' -> 'amiliar' (freq: 561, time: 0.205s)
Merge 9200/9743: ' P' + 'retty' -> ' Pretty' (freq: 542, time: 0.310s)
Merge 9300/9743: ' sal' + 'on' -> ' salon' (freq: 521, time: 0.292s)
Merge 9400/9743: ' p' + 'ounced' -> ' pounced' (freq: 502, time: 0.196s)
Merge 9500/9743: ' pops' + 'ic' -> ' popsic' (freq: 485, time: 0.185s)
Merge 9600/9743: ' pain' + 'ful' -> ' painful' (freq: 469, time: 0.298s)
Merge 9700/9743: 'solut' + 'ely' -> 'solutely' (freq: 454, time: 0.308s)
============================================================
BPE training completed in 2731.72 seconds
Final vocabulary size: 10000
Total merges performed: 9743
Compression ratio: 4.07x (from 2,192,422,648 to 538,511,097 tokens)

================================================================================
TRAINING SUMMARY
================================================================================
Total training time: 2898.45 seconds
Final vocabulary size: 10,000
Number of merges performed: 9,743
Actual vocab size vs target: 10000 / 10000

Saving tokenizer to disk...
  ✓ Vocabulary saved to: tinystories_vocab.pkl
  ✓ Merges saved to: tinystories_merges.pkl

================================================================================
VOCABULARY ANALYSIS
================================================================================
Token type breakdown:
  Byte tokens (0-255): 256
  Special tokens: 1
  Merged tokens: 9743
  Total: 10000

Byte tokens (first 10):
  Token   0: b'\x00' -> '\x00'
  Token   1: b'\x01' -> '\x01'
  Token   2: b'\x02' -> '\x02'
  Token   3: b'\x03' -> '\x03'
  Token   4: b'\x04' -> '\x04'
  Token   5: b'\x05' -> '\x05'
  Token   6: b'\x06' -> '\x06'
  Token   7: b'\x07' -> '\x07'
  Token   8: b'\x08' -> '\x08'
  Token   9: b'\t' -> '\t'

Special tokens:
  Token 256: b'<|endoftext|>' -> '<|endoftext|>'

Most recently merged tokens (last 10):
  Token 9990: b' improving' -> ' improving'
  Token 9991: b' nicest' -> ' nicest'
  Token 9992: b' whiskers' -> ' whiskers'
  Token 9993: b' booth' -> ' booth'
  Token 9994: b' Land' -> ' Land'
  Token 9995: b'Rocky' -> 'Rocky'
  Token 9996: b' meadows' -> ' meadows'
  Token 9997: b' Starry' -> ' Starry'
  Token 9998: b' imaginary' -> ' imaginary'
  Token 9999: b' bold' -> ' bold'

First 10 merge operations:
  Merge  1: ' ' + 't' -> ' t'
  Merge  2: 'h' + 'e' -> 'he'
  Merge  3: ' ' + 'a' -> ' a'
  Merge  4: ' ' + 's' -> ' s'
  Merge  5: ' ' + 'w' -> ' w'
  Merge  6: 'n' + 'd' -> 'nd'
  Merge  7: ' t' + 'he' -> ' the'
  Merge  8: 'e' + 'd' -> 'ed'
  Merge  9: ' ' + 'b' -> ' b'
  Merge 10: ' t' + 'o' -> ' to'

Last 10 merge operations:
  Merge 9734: ' impro' + 'ving' -> ' improving'
  Merge 9735: ' nice' + 'st' -> ' nicest'
  Merge 9736: ' wh' + 'iskers' -> ' whiskers'
  Merge 9737: ' bo' + 'oth' -> ' booth'
  Merge 9738: ' L' + 'and' -> ' Land'
  Merge 9739: 'Rock' + 'y' -> 'Rocky'
  Merge 9740: ' meadow' + 's' -> ' meadows'
  Merge 9741: ' St' + 'arry' -> ' Starry'
  Merge 9742: ' imag' + 'inary' -> ' imaginary'
  Merge 9743: ' bo' + 'ld' -> ' bold'

Output file sizes:
  Vocabulary file: 117,701 bytes (114.9 KB)
  Merges file: 109,714 bytes (107.1 KB)
  Total: 227,415 bytes (222.1 KB)

================================================================================
TRAINING COMPLETED SUCCESSFULLY!
================================================================================
You can now use the trained tokenizer for encoding/decoding text.
Load with: vocab, merges = load_tokenizer('tinystories_vocab.pkl', 'tinystories_merges.pkl')

Using the Trained Tokenizer

Once we have a trained tokenizer, we need a class to encode and decode text. Here’s one straightforward implementation:

class SimpleBPETokenizer:
    """Simple BPE tokenizer for encoding/decoding text."""

    def __init__(self, vocab, merges, special_tokens=None):
        self.vocab = vocab  # {token_id: bytes}
        self.merges = merges  # [(left_bytes, right_bytes), ...]
        self.special_tokens = special_tokens or ["<|endoftext|>"]

        # Create reverse mapping for decoding
        self.id_to_bytes = vocab
        self.bytes_to_id = {v: k for k, v in vocab.items()}

        # GPT-2 style regex pattern
        self.pattern = r"""'(?:[sdmt]|ll|ve|re)| ?[a-zA-ZÀ-ÿ]+| ?[0-9]+| ?[^\s\w]+|\s+(?!\S)|\s+"""

        # Build merge rules for faster encoding
        self.merge_rules = {}
        for i, (left_bytes, right_bytes) in enumerate(merges):
            # Find what tokens these bytes correspond to
            left_id = self.bytes_to_id.get(left_bytes)
            right_id = self.bytes_to_id.get(right_bytes)
            merged_bytes = left_bytes + right_bytes
            merged_id = self.bytes_to_id.get(merged_bytes)

            if left_id is not None and right_id is not None and merged_id is not None:
                self.merge_rules[(left_id, right_id)] = merged_id

    def encode(self, text: str) -> list[int]:
        """Encode text to token IDs."""
        if not text:
            return []

        # Handle special tokens
        token_ids = []
        remaining_text = text

        # Split on special tokens first
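        # NOTE: this simple approach appends the special token's ID as soon as the
        # token is stripped out and only tokenizes the remaining text afterwards,
        # so a special token's original position in the string is not preserved
        # (see Example 2 in the sample tests below).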
        for special_token in self.special_tokens:
            if special_token in remaining_text:
                parts = remaining_text.split(special_token)
                new_parts = []
                for i, part in enumerate(parts):
                    if i > 0:
                        # Add special token
                        special_bytes = special_token.encode('utf-8')
                        special_id = self.bytes_to_id.get(special_bytes)
                        if special_id is not None:
                            token_ids.append(special_id)
                    if part:
                        new_parts.append(part)
                remaining_text = ''.join(new_parts)

        # Apply regex tokenization
        for match in re.finditer(self.pattern, remaining_text):
            word = match.group()
            word_tokens = self._encode_word(word)
            token_ids.extend(word_tokens)

        return token_ids

    def _encode_word(self, word: str) -> list[int]:
        """Encode a single word using BPE merges."""
        # Start with individual bytes
        word_bytes = word.encode('utf-8')
        tokens = []

        # Convert each byte to its token ID
        for byte_val in word_bytes:
            tokens.append(byte_val)  # Byte token IDs are 0-255

        # Apply BPE merges
        while len(tokens) > 1:
            # Find the best merge to apply
            best_merge = None
            best_pos = -1
            best_merge_priority = float('inf')

            for i in range(len(tokens) - 1):
                pair = (tokens[i], tokens[i + 1])
                if pair in self.merge_rules:
                    # Find merge priority (earlier merges have higher priority)
                    merged_bytes = self.id_to_bytes[tokens[i]] + self.id_to_bytes[tokens[i + 1]]
                    for j, (left_bytes, right_bytes) in enumerate(self.merges):
                        if left_bytes + right_bytes == merged_bytes:
                            if j < best_merge_priority:
                                best_merge = pair
                                best_pos = i
                                best_merge_priority = j
                            break

            if best_merge is None:
                break

            # Apply the best merge
            new_tokens = tokens[:best_pos]
            new_tokens.append(self.merge_rules[best_merge])
            new_tokens.extend(tokens[best_pos + 2:])
            tokens = new_tokens

        return tokens

    def decode(self, token_ids: list[int]) -> str:
        """Decode token IDs back to text."""
        if not token_ids:
            return ""

        # Convert token IDs to bytes
        byte_sequences = []
        for token_id in token_ids:
            if token_id in self.id_to_bytes:
                byte_sequences.append(self.id_to_bytes[token_id])
            else:
                # Handle unknown tokens
                byte_sequences.append(b'<UNK>')

        # Concatenate all bytes and decode
        all_bytes = b''.join(byte_sequences)
        try:
            return all_bytes.decode('utf-8', errors='replace')
        except:
            return all_bytes.decode('utf-8', errors='ignore')

    def tokenize_with_details(self, text: str):
        """Tokenize text and show detailed breakdown."""
        token_ids = self.encode(text)

        print(f"Original text: '{text}'")
        print(f"Length: {len(text)} characters")
        print(f"UTF-8 bytes: {len(text.encode('utf-8'))} bytes")
        print(f"Token count: {len(token_ids)} tokens")
        print(f"Compression ratio: {len(text.encode('utf-8')) / len(token_ids):.2f}x")
        print()

        print("Token breakdown:")
        for i, token_id in enumerate(token_ids):
            token_bytes = self.id_to_bytes[token_id]
            try:
                token_str = token_bytes.decode('utf-8', errors='replace')
                if token_str.isprintable():
                    print(f"  {i+1:2d}. Token {token_id:4d}: '{token_str}' ({len(token_bytes)} bytes)")
                else:
                    print(f"  {i+1:2d}. Token {token_id:4d}: {repr(token_str)} ({len(token_bytes)} bytes)")
            except:
                print(f"  {i+1:2d}. Token {token_id:4d}: {token_bytes} (binary)")

        # Verify round-trip
        decoded = self.decode(token_ids)
        print(f"\nDecoded text: '{decoded}'")
        print(f"Round-trip successful: {text == decoded}")

        return token_ids

Let’s put together some simple test cases:

def test_bpe_tokenizer():
    print("=" * 60)
    print("BPE TOKENIZER SAMPLE TESTS")
    print("=" * 60)

    # Load the trained tokenizer
    try:
        vocab, merges = load_tokenizer('tinystories_vocab.pkl', 'tinystories_merges.pkl')
        print(f"✓ Loaded tokenizer with {len(vocab)} vocab entries and {len(merges)} merges")
    except FileNotFoundError:
        print("Error: Tokenizer files not found!")
        print("Please run the training script first to create 'tinystories_vocab.pkl' and 'tinystories_merges.pkl'")
        return

    # Create tokenizer instance
    tokenizer = SimpleBPETokenizer(vocab, merges)
    print()

    # Example 1: Simple sentence
    print("EXAMPLE 1: Simple sentence")
    print("-" * 30)
    text1 = "Hello world! How are you today?"
    tokenizer.tokenize_with_details(text1)
    print()

    # Example 2: Text with special token
    print("EXAMPLE 2: Text with special token")
    print("-" * 30)
    text2 = "Once upon a time<|endoftext|>The end."
    tokenizer.tokenize_with_details(text2)
    print()

    # Example 3: Repeated words (should compress well)
    print("EXAMPLE 3: Repeated words")
    print("-" * 30)
    text3 = "the the the cat cat sat sat on on the the mat mat"
    tokenizer.tokenize_with_details(text3)
    print()

    # Example 4: Numbers and punctuation
    print("EXAMPLE 4: Numbers and punctuation")
    print("-" * 30)
    text4 = "I have 123 apples, 456 oranges, and 789 bananas!"
    tokenizer.tokenize_with_details(text4)
    print()

    # Example 5: Just encoding/decoding
    print("EXAMPLE 5: Simple encode/decode")
    print("-" * 30)
    text5 = "This is a test."
    token_ids = tokenizer.encode(text5)
    decoded_text = tokenizer.decode(token_ids)

    print(f"Original: '{text5}'")
    print(f"Token IDs: {token_ids}")
    print(f"Decoded: '{decoded_text}'")
    print(f"Match: {text5 == decoded_text}")
    print()

    # Show some vocabulary statistics
    print("VOCABULARY STATISTICS")
    print("-" * 30)
    byte_tokens = sum(1 for tid in vocab.keys() if tid < 256)
    special_tokens = sum(1 for tid, token_bytes in vocab.items() if b'<|' in token_bytes)
    merged_tokens = len(vocab) - byte_tokens - special_tokens

    print(f"Byte tokens (0-255): {byte_tokens}")
    print(f"Special tokens: {special_tokens}")
    print(f"Merged tokens: {merged_tokens}")
    print(f"Total vocabulary: {len(vocab)}")

    # Show some example merged tokens
    print(f"\nSample merged tokens:")
    merged_token_ids = [tid for tid in sorted(vocab.keys()) if tid >= 257]
    for i, token_id in enumerate(merged_token_ids[:10]):
        token_bytes = vocab[token_id]
        try:
            decoded = token_bytes.decode('utf-8', errors='replace')
            print(f"  Token {token_id}: '{decoded}' ({len(token_bytes)} bytes)")
        except:
            print(f"  Token {token_id}: {token_bytes} (binary)")

    print("\n" + "=" * 60)
    print("All examples completed successfully!")

BPE Tokenizer Sample Tests

Now run our complete test suite:

test_bpe_tokenizer()

Based on the training output from the TinyStories dataset, here are the testing results:

✓ Loaded tokenizer with 10000 vocab entries and 9743 merges

Example 1: Simple sentence

Original text: ‘Hello world! How are you today?’
Length: 31 characters
UTF-8 bytes: 31 bytes
Token count: 8 tokens
Compression ratio: 3.88x

Token breakdown:

  1. Token 1183: ‘Hello’ (5 bytes)
  2. Token 1569: ‘ world’ (6 bytes)
  3. Token 33: ‘!’ (1 bytes)
  4. Token 2683: ‘ How’ (4 bytes)
  5. Token 483: ‘ are’ (4 bytes)
  6. Token 349: ‘ you’ (4 bytes)
  7. Token 1709: ‘ today’ (6 bytes)
  8. Token 63: ‘?’ (1 bytes)

Decoded text: ‘Hello world! How are you today?’
Round-trip successful: True

Example 2: Text with special token

Original text: ‘Once upon a time<|endoftext|>The end.’
Length: 37 characters
UTF-8 bytes: 37 bytes
Token count: 8 tokens
Compression ratio: 4.62x

Token breakdown:

  1. Token 256: ‘<|endoftext|>’ (13 bytes)
  2. Token 430: ‘Once’ (4 bytes)
  3. Token 439: ‘ upon’ (5 bytes)
  4. Token 259: ‘ a’ (2 bytes)
  5. Token 398: ‘ time’ (5 bytes)
  6. Token 410: ‘The’ (3 bytes)
  7. Token 870: ‘ end’ (4 bytes)
  8. Token 46: ‘.’ (1 bytes)

Decoded text: ‘<|endoftext|>Once upon a timeThe end.’
Round-trip successful: False

Note that the round trip fails here: the simple special-token handling in encode() appends the <|endoftext|> ID as soon as the token is stripped from the text and only then tokenizes the remaining text, so the special token ends up at the front of the token sequence instead of in its original position.

Example 3: Repeated words

Original text: ‘the the the cat cat sat sat on on the the mat mat’
Length: 49 characters
UTF-8 bytes: 49 bytes
Token count: 13 tokens
Compression ratio: 3.77x

Token breakdown:

  1. Token 7199: ‘the’ (3 bytes)
  2. Token 263: ‘ the’ (4 bytes)
  3. Token 263: ‘ the’ (4 bytes)
  4. Token 459: ‘ cat’ (4 bytes)
  5. Token 459: ‘ cat’ (4 bytes)
  6. Token 1091: ‘ sat’ (4 bytes)
  7. Token 1091: ‘ sat’ (4 bytes)
  8. Token 354: ‘ on’ (3 bytes)
  9. Token 354: ‘ on’ (3 bytes)
  10. Token 263: ‘ the’ (4 bytes)
  11. Token 263: ‘ the’ (4 bytes)
  12. Token 1492: ‘ mat’ (4 bytes)
  13. Token 1492: ‘ mat’ (4 bytes)

Decoded text: ‘the the the cat cat sat sat on on the the mat mat’
Round-trip successful: True

Example 4: Numbers and punctuation

Original text: ‘I have 123 apples, 456 oranges, and 789 bananas!’
Length: 48 characters
UTF-8 bytes: 48 bytes
Token count: 19 tokens
Compression ratio: 2.53x

Token breakdown:

  1. Token 73: ‘I’ (1 bytes)
  2. Token 499: ‘ have’ (5 bytes)
  3. Token 6314: ‘ 1’ (2 bytes)
  4. Token 50: ‘2’ (1 bytes)
  5. Token 51: ‘3’ (1 bytes)
  6. Token 1836: ‘ apples’ (7 bytes)
  7. Token 44: ‘,’ (1 bytes)
  8. Token 9079: ‘ 4’ (2 bytes)
  9. Token 53: ‘5’ (1 bytes)
  10. Token 54: ‘6’ (1 bytes)
  11. Token 5193: ‘ oranges’ (8 bytes)
  12. Token 44: ‘,’ (1 bytes)
  13. Token 267: ‘ and’ (4 bytes)
  14. Token 32: ‘ ’ (1 bytes)
  15. Token 55: ‘7’ (1 bytes)
  16. Token 56: ‘8’ (1 bytes)
  17. Token 57: ‘9’ (1 bytes)
  18. Token 3898: ‘ bananas’ (8 bytes)
  19. Token 33: ‘!’ (1 bytes)

Decoded text: ‘I have 123 apples, 456 oranges, and 789 bananas!’
Round-trip successful: True

Example 5: Simple encode/decode

Original: ‘This is a test.’
Token IDs: [1531, 431, 259, 2569, 46]
Decoded: ‘This is a test.’
Match: True

Vocabulary Statistics

Byte tokens (0-255): 256
Special tokens: 1
Merged tokens: 9743
Total vocabulary: 10000

Sample merged tokens:

  • Token 257: ‘ t’ (2 bytes)
  • Token 258: ‘he’ (2 bytes)
  • Token 259: ‘ a’ (2 bytes)
  • Token 260: ‘ s’ (2 bytes)
  • Token 261: ‘ w’ (2 bytes)
  • Token 262: ‘nd’ (2 bytes)
  • Token 263: ‘ the’ (4 bytes)
  • Token 264: ‘ed’ (2 bytes)
  • Token 265: ‘ b’ (2 bytes)
  • Token 266: ‘ to’ (3 bytes)

All examples completed successfully!