
Micro GPT

May 25, 2025 · 12 min read
gpt transformer pytorch

Building a GPT from Scratch in PyTorch

As the title suggests, this is part two of my previous work on bigram models. As we discussed, bigram models have an obvious flaw: they can only predict based on the current token.

This means that context, something we are all familiar with by now, is not possible. We talked briefly about the solution to this, and now it’s time to really see what it looks like.

Micro GPT is a decoder-only transformer built from scratch in PyTorch that extends the bigram approach with self-attention, positional embeddings, and stacked transformer blocks. It’s a character-level model trained on OpenWebText data, small enough to train on a laptop but built on the same decoder-only transformer design as GPT-2, just at a much smaller scale.


From Bigram to Transformer

The bigram model was a single embedding table: an 81x81 matrix where row i contained the model’s learned prediction for what follows character i. That’s one lookup per prediction with no context and no memory.

A transformer replaces that lookup with a pipeline:

Input Tokens
                 ↓
Token Embedding + Position Embedding
                 ↓
┌──────────────────────────────────┐
│       Transformer Block (×N)     │
│  ┌────────────────────────────┐  │
│  │  Multi-Head Self-Attention │  │
│  │  + Residual + Layer Norm   │  │
│  ├────────────────────────────┤  │
│  │  Feed-Forward Network      │  │
│  │  + Residual + Layer Norm   │  │
│  └────────────────────────────┘  │
└──────────────────────────────────┘
                 ↓
Final Layer Norm → Linear → Logits over Vocabulary

Every component in this pipeline solves a specific limitation of the bigram model. Positional embeddings give the model awareness of order. Self-attention lets it look at all previous characters. Feed-forward layers let it transform what it’s learned. Stacking blocks lets it build increasingly abstract representations.

Kind of like an open-note exam.

The training configuration:

Parameter              Value
Embedding dimension    384
Attention heads        8
Transformer blocks     8
Context length         128 characters
Batch size             64
Dropout                0.2
Learning rate          3e-4
Training iterations    10,000
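The code in the rest of this article refers to these settings as module-level globals. A minimal sketch of how they might be declared, assuming the names used in the snippets below; eval_iters = 500 is inferred from the training-loop section rather than stated in the table:

```python
# Hyperparameters from the table above, declared as module-level globals.
# Names match the ones used in the code snippets; eval_iters is an
# assumption inferred from the training-loop section.
n_embd = 384          # embedding dimension
n_head = 8            # attention heads per block
n_layer = 8           # transformer blocks
block_size = 128      # context length in characters
batch_size = 64
dropout = 0.2
learning_rate = 3e-4
max_iters = 10_000
eval_iters = 500      # evaluation interval and number of eval batches

# Each head works in an n_embd // n_head dimensional subspace.
assert n_embd % n_head == 0, "n_embd must split evenly across heads"
head_size = n_embd // n_head
print(head_size)
```

The divisibility check matters because the heads’ outputs are concatenated back to exactly n_embd dimensions.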

Memory-Mapped Chunks

The bigram model loaded the entire training text into memory as a single tensor. That works for The Wizard of Oz, but not for OpenWebText. The dataset is way too large to fit in RAM.

The solution is memory-mapped I/O. Instead of reading the whole file, the training script maps it into virtual memory and reads random chunks on demand:

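import mmap
import random

import torch

# block_size, batch_size and encode are defined elsewhere in the training script.

def get_random_chunk(split):
    filename = "train_split.txt" if split == 'train' else "val_split.txt"
    with open(filename, 'rb') as f:
        # Map the file read-only; the OS pages bytes in only as they are touched.
        with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
            file_size = len(mm)
            # Random byte offset that leaves room for one batch worth of bytes.
            start_pos = random.randint(0, file_size - block_size*batch_size)

            mm.seek(start_pos)
            block = mm.read(block_size*batch_size-1)

            # Drop bytes that don't decode (e.g. a character split at a chunk edge)
            # and strip carriage returns before encoding to character indices.
            decoded_block = block.decode('utf-8', errors='ignore').replace('\r', '')
            data = torch.tensor(encode(decoded_block), dtype=torch.long)

    return data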

mmap.mmap maps the file into the process’s address space without loading it. The OS pages data in and out as needed. random.randint picks a random starting position, and mm.read pulls a chunk just large enough for one batch of sequences.

The errors='ignore' on decode is important. Since we’re slicing into a file at arbitrary byte positions, we might land in the middle of a multi-byte UTF-8 character. Ignoring those partial characters avoids crashes at the cost of occasionally losing a character at chunk boundaries. For character-level training on gigabytes of text, this is a negligible loss. Niche but essential.
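A quick way to see the problem in plain Python: “é” is two bytes in UTF-8 (0xC3 0xA9), so slicing at byte offsets, the way a random seek does, can cut it in half. The sample string is made up for illustration:

```python
# "é" encodes to two bytes in UTF-8, so byte slices can split it.
raw = "héllo".encode("utf-8")   # b'h\xc3\xa9llo'

# Starting one byte into the "é" leaves an orphaned continuation byte.
tail = raw[2:]
try:
    tail.decode("utf-8")        # strict decoding raises
except UnicodeDecodeError as exc:
    print("strict decode failed:", exc.reason)

# errors='ignore' silently drops the broken bytes instead.
print(tail.decode("utf-8", errors="ignore"))      # llo
print(raw[:2].decode("utf-8", errors="ignore"))   # h
```

The same thing happens at the end of a chunk: a character cut off by mm.read is simply dropped.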

From each chunk, batches are sampled the same way as the bigram model:

def get_batch(split):
    data = get_random_chunk(split)
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

Each call reads a fresh random chunk from disk, samples batch_size random positions within it, and extracts input-target pairs. The target is always the input shifted by one position: predict the next character at every position in the sequence. Sound familiar?
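The one-position shift is easier to see on a toy sequence. A sketch with a plain Python list standing in for the data tensor and a hypothetical block_size of 4; note that a single window yields block_size training examples, one per prefix:

```python
# Toy stand-in for an encoded chunk: each number is a character index.
data = [10, 11, 12, 13, 14, 15, 16, 17]
block_size = 4  # hypothetical, much smaller than the real 128

i = 2  # one start position, as torch.randint would pick
x = data[i:i + block_size]          # model input
y = data[i + 1:i + block_size + 1]  # targets: the input shifted by one

# At every position t, the target is the character that follows x[t].
for t in range(block_size):
    print(f"context {x[:t + 1]} -> predict {y[t]}")
```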


Self-Attention

You probably read the last section and thought, “Similar enough, right?” Well, not really.

Self-attention is the core mechanism that separates a transformer from a simple lookup table. It lets every token in a sequence compute a weighted sum over all previous tokens, deciding dynamically which parts of the context matter for the current prediction.

Each attention head learns three linear transformations:

class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

Query: “what am I looking for?” Each token produces a query vector that represents what information it needs. Key: “what do I contain?” Each token produces a key vector that describes what information it offers. Value: “what do I actually give you?” Once attention decides which tokens to focus on, the value vectors are what get passed forward.

You’re at a networking event (sorry). You have a question in mind (query). Every person in the room is wearing a nametag describing their expertise (key). You scan the nametags, decide who’s relevant, walk over, and get their actual advice (value). Self-attention does this for every token in the sequence. Banger analogy, props to me.

The forward pass computes attention in four steps:

def forward(self, x):
    B,T,C = x.shape
    k = self.key(x)
    q = self.query(x)

    # 1. Compute attention scores
    wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5

    # 2. Mask future positions
    wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))

    # 3. Normalize to probabilities
    wei = F.softmax(wei, dim=-1)
    wei = self.dropout(wei)

    # 4. Weighted sum of values
    v = self.value(x)
    out = wei @ v
    return out

Bear with me here.

Step 1 computes raw attention scores with q @ k.transpose(-2,-1). This is a dot product between every query and every key, producing a (B, T, T) tensor: for each sequence, entry [i, j] measures how much token i should attend to token j. The * k.shape[-1]**-0.5 factor is the “scaled” in scaled dot-product attention: it divides every score by √head_size (√48 here). Without this scaling, the variance of the dot products grows with the head dimension, so the scores get large, softmax saturates toward a one-hot distribution, and its gradients vanish.
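The scaling argument can be checked numerically. If query and key entries are independent with unit variance, q·k is a sum of head_size terms that each have variance 1, so its standard deviation is about √head_size; multiplying by head_size**-0.5 brings it back to about 1. A pure-Python sketch with head_size = 48, where random Gaussian vectors are an assumption standing in for the learned projections:

```python
import math
import random

def dot_std(head_size, samples=2000, scale=False, seed=0):
    """Empirical std of q·k for random unit-variance vectors."""
    rng = random.Random(seed)
    dots = []
    for _ in range(samples):
        q = [rng.gauss(0, 1) for _ in range(head_size)]
        k = [rng.gauss(0, 1) for _ in range(head_size)]
        d = sum(a * b for a, b in zip(q, k))
        if scale:
            d *= head_size ** -0.5  # same factor as k.shape[-1]**-0.5
        dots.append(d)
    mean = sum(dots) / len(dots)
    return math.sqrt(sum((d - mean) ** 2 for d in dots) / len(dots))

print(f"unscaled std: {dot_std(48):.2f}")   # roughly sqrt(48), about 6.9
print(f"scaled std:   {dot_std(48, scale=True):.2f}")  # roughly 1
```

With scores spread that widely, a softmax over a row would put nearly all of its weight on one position.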

Step 2 is causal masking. tril is a lower-triangular matrix of ones, and masked_fill sets every score where tril is 0 (the upper triangle, i.e. future positions) to -inf. After softmax, -inf becomes 0. This ensures that position 5 can attend to positions 0 through 5 (including itself) but never to positions 6+. Without this, the model could cheat during training by looking at future characters.

It’s the equivalent of covering the answer key during an exam. The model has to predict the next character using only what came before, never what comes after.

Step 3 normalizes each row to a probability distribution and applies dropout for regularization.

Step 4 uses those attention weights to compute a weighted sum of value vectors. If token 5 attends strongly to token 2, it receives mostly token 2’s value. That’s how information flows from earlier positions to later ones.
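The four steps can be traced by hand on a three-token sequence. This is a pure-Python sketch rather than the PyTorch code above; the q, k and v vectors are made-up numbers standing in for the outputs of the query, key and value layers, and dropout is left out:

```python
import math

def causal_attention(q, k, v):
    """One attention head over T tokens, mirroring Head.forward (no dropout)."""
    T, d = len(q), len(k[0])
    out = []
    for i in range(T):
        # 1. Scaled dot-product scores against every key.
        scores = [sum(a * b for a, b in zip(q[i], k[j])) * d ** -0.5 for j in range(T)]
        # 2. Causal mask: future positions (j > i) get -inf.
        scores = [s if j <= i else float("-inf") for j, s in enumerate(scores)]
        # 3. Softmax; math.exp(-inf) is 0.0, so masked positions get zero weight.
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # 4. Weighted sum of the value vectors.
        out.append([sum(w * v[j][c] for j, w in enumerate(weights))
                    for c in range(len(v[0]))])
    return out

q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
out = causal_attention(q, k, v)
print(out[0])  # [1.0, 2.0]: token 0 can only attend to itself
```

Changing a later token’s value leaves every earlier output untouched, which is exactly the guarantee the mask is there to provide.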


Multi-Head Attention

Sorry about that last section; this one is less painful.

A single attention head can only capture one pattern. Maybe it learns to attend to the previous vowel, or to the start of the current word. But language has many simultaneous relationships: syntax, semantics, spacing, capitalization. Multi-head attention runs several heads in parallel, each free to learn a different pattern:

class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(head_size * num_heads, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

With 8 heads and an embedding dimension of 384, each head operates in a 48-dimensional subspace (384 / 8 = 48). The heads run independently, their outputs are concatenated back to 384 dimensions, and a final linear projection fuses the information.

This division of labor is efficient. Instead of one head trying to capture everything in a 384-dimensional space, eight heads each specialize in a 48-dimensional slice. The projection layer combines their findings.
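Looping over heads in Python is the clearest way to write multi-head attention, but a common alternative (used in nanoGPT, for example) computes every head at once: one linear layer produces the queries, keys and values for all heads, and a reshape turns (B, T, n_embd) into (B, n_head, T, head_size). A sketch of that formulation, which is not the code this project uses; the class name and the fused qkv layer are assumptions, and dropout is omitted:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BatchedCausalSelfAttention(nn.Module):
    """All heads in one pass: (B, T, C) -> (B, n_head, T, head_size) -> (B, T, C)."""

    def __init__(self, n_embd=384, n_head=8, block_size=128):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.qkv = nn.Linear(n_embd, 3 * n_embd, bias=False)  # q, k, v for every head
        self.proj = nn.Linear(n_embd, n_embd)
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape
        hs = C // self.n_head
        q, k, v = self.qkv(x).split(C, dim=-1)
        # Split the channel dimension into heads and move heads next to batch.
        q = q.view(B, T, self.n_head, hs).transpose(1, 2)  # (B, nh, T, hs)
        k = k.view(B, T, self.n_head, hs).transpose(1, 2)
        v = v.view(B, T, self.n_head, hs).transpose(1, 2)
        wei = q @ k.transpose(-2, -1) * hs ** -0.5          # (B, nh, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
        wei = F.softmax(wei, dim=-1)
        out = wei @ v                                       # (B, nh, T, hs)
        # Plays the role of torch.cat over heads: back to (B, T, C).
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.proj(out)

attn = BatchedCausalSelfAttention()
x = torch.randn(2, 16, 384)
print(attn(x).shape)  # torch.Size([2, 16, 384])
```

Each head still works on its own 48-dimensional slice, so the math matches the per-head version; only the batching differs.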


Feed-Forward Network

Attention allows our tokens to communicate. The feed-forward network is how each token independently processes the information it gathered:

class FeedForward(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

The first linear layer expands the representation from 384 to 1536 dimensions — a 4x expansion that gives the network more room to transform. ReLU introduces non-linearity, allowing the network to model complex patterns that a purely linear layer couldn’t. The second linear layer compresses back to 384.

If attention is the “talking to other tokens” phase, feed-forward is the “thinking alone” phase. Each token takes what it learned from attending to other tokens, expands it into a larger workspace, processes it, and compresses it back down. Like writing your notes out on a big whiteboard to work through a problem, then summarizing your conclusions back onto an index card.

This expand-activate-compress pattern appears in nearly every transformer. A 4x expansion is the classic default: the original “Attention Is All You Need” paper used 512 → 2048, and the ratio carried through GPT-2 and many later models, though some newer architectures pick different ratios.
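The claim that ReLU adds something a purely linear layer couldn’t can be checked with tiny matrices: two linear layers applied back to back equal a single linear layer whose matrix is their product, so without the activation the 384 → 1536 → 384 sandwich would add parameters but no expressive power. A pure-Python sketch with made-up 2×2 weights (biases omitted for brevity):

```python
def matvec(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def matmul(A, B):
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]

def relu(x):
    return [max(0.0, xi) for xi in x]

W1 = [[1.0, -2.0], [0.5, 1.0]]   # "expand" layer (made-up weights)
W2 = [[2.0, 1.0], [-1.0, 3.0]]   # "compress" layer (made-up weights)
x = [1.0, 1.0]

two_linear = matvec(W2, matvec(W1, x))       # no activation in between
one_linear = matvec(matmul(W2, W1), x)       # a single merged matrix
with_relu = matvec(W2, relu(matvec(W1, x)))  # ReLU breaks the equivalence

print(two_linear, one_linear, with_relu)     # [-0.5, 5.5] [-0.5, 5.5] [1.5, 4.5]
```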


The Transformer Block

A single transformer block combines attention and feed-forward processing with two crucial additions: residual connections and layer normalization:

class Block(nn.Module):
    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        y = self.sa(x)
        x = self.ln1(x + y)
        y = self.ffwd(x)
        x = self.ln2(x + y)
        return x

The residual connections (x + y) are the most important structural choice. Instead of replacing the input with the attention output, the block adds the attention output to the input. This means information can flow through the network unchanged if a block has nothing useful to add. It also mitigates the vanishing gradient problem: during backpropagation, gradients can flow directly through the addition, bypassing any blocks that might attenuate them.

It’s like optional pit stops. The data can drive straight through (the + x addition preserves the original signal), or it can pull off, get processed by attention, and merge back in.
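The gradient argument can be made concrete with scalars. For a block computing x + f(x), the derivative with respect to x is 1 + f'(x), so even a tiny f'(x) leaves a direct path of 1; without the residual the derivative is just f'(x). The chain rule multiplies these factors across the 8 blocks. A sketch that ignores the layer norms and assumes, purely for illustration, that every block has the same small local derivative of 0.01:

```python
n_layer = 8
local_grad = 0.01  # hypothetical f'(x) of each block, deliberately small

# Chain rule through 8 stacked blocks: multiply the per-block derivatives.
without_residual = local_grad ** n_layer     # y = f(x)     -> f'(x) per block
with_residual = (1 + local_grad) ** n_layer  # y = x + f(x) -> 1 + f'(x) per block

print(f"without residuals: {without_residual:.1e}")  # 1.0e-16
print(f"with residuals:    {with_residual:.3f}")     # 1.083
```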

Layer normalization stabilizes the values at each stage. Without it, activations can grow or shrink as they pass through successive blocks, making training unstable. nn.LayerNorm normalizes each position’s vector across the embedding dimension to zero mean and unit variance, then applies a learnable per-dimension scale and shift (initialized to 1 and 0) before the result enters the next operation.
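What nn.LayerNorm computes for one position can be written out by hand: subtract the mean over the embedding dimension, divide by the standard deviation (with a small eps, 1e-5 by default in PyTorch), then apply the learnable scale and shift. A pure-Python sketch on a made-up 4-dimensional vector, with the scale and shift at their initial values:

```python
import math

def layer_norm(x, gamma=None, beta=None, eps=1e-5):
    """Normalize one position's vector, as nn.LayerNorm does at initialization."""
    n = len(x)
    if gamma is None:
        gamma = [1.0] * n  # learnable scale, initialized to ones
    if beta is None:
        beta = [0.0] * n   # learnable shift, initialized to zeros
    mean = sum(x) / n
    var = sum((xi - mean) ** 2 for xi in x) / n  # biased variance, as PyTorch uses
    return [g * (xi - mean) / math.sqrt(var + eps) + b
            for xi, g, b in zip(x, gamma, beta)]

y = layer_norm([2.0, 4.0, 6.0, 8.0])
mean_y = sum(y) / len(y)
var_y = sum((v - mean_y) ** 2 for v in y) / len(y)
print(f"mean {mean_y:.6f}, var {var_y:.6f}")
```

The eps keeps the division safe when a vector is nearly constant, which is why the variance comes out very slightly below 1.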

The model stacks 8 of these blocks in sequence. Each block refines the representation: early blocks might learn character-level patterns, while later blocks might capture word-level or phrase-level structure. The depth of the stack limits how abstract the model’s internal representations can become.


The Full Model

The GPTLanguageModel class ties everything together:

class GPTLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

        self.apply(self._init_weights)

Two embedding tables handle input. token_embedding_table converts each character index into a 384-dimensional vector: the learned representation of each character’s identity. position_embedding_table converts each position (0 through 127) into a 384-dimensional vector: the learned representation of what each position means.

Adding these two embeddings means the model knows both what a character is and where it appears. The letter ‘a’ at position 0 and ‘a’ at position 50 are the same character but carry different positional context.

These two embeddings are added together in the forward pass:

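def forward(self, index, targets=None):
    B, T = index.shape
    tok_emb = self.token_embedding_table(index)                         # (B, T, n_embd)
    pos_emb = self.position_embedding_table(torch.arange(T, device=device))  # (T, n_embd)
    x = tok_emb + pos_emb        # broadcast add: (B, T, n_embd)
    x = self.blocks(x)
    x = self.ln_f(x)
    logits = self.lm_head(x)     # (B, T, vocab_size)

    # Loss (assumed to mirror the bigram model): cross-entropy at every position.
    if targets is None:
        loss = None
    else:
        B, T, C = logits.shape
        loss = F.cross_entropy(logits.view(B*T, C), targets.view(B*T))
    return logits, loss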

The addition of token and position embeddings means each vector entering the transformer stack encodes both what a character is and where it sits. The 8 transformer blocks then process this combined representation. A final layer norm and linear projection map the output back to vocabulary-sized logits: one score per possible next character.
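The configuration fully determines the model’s size, and counting parameters by hand is a good sanity check on the architecture. A sketch of that arithmetic; vocab_size is an argument because the article doesn’t state the OpenWebText vocabulary size, and the 81 used in the example call (the bigram model’s vocabulary) is only for illustration:

```python
def count_params(vocab_size, n_embd=384, n_head=8, n_layer=8, block_size=128):
    head_size = n_embd // n_head
    # Multi-head attention: key/query/value per head (no bias) + output projection.
    attn = n_head * 3 * n_embd * head_size + (n_embd * n_embd + n_embd)
    # Feed-forward: n_embd -> 4*n_embd -> n_embd, both layers with bias.
    ffwd = (n_embd * 4 * n_embd + 4 * n_embd) + (4 * n_embd * n_embd + n_embd)
    # Two LayerNorms per block, each with a weight and a bias vector.
    norms = 2 * 2 * n_embd
    per_block = attn + ffwd + norms
    embeddings = vocab_size * n_embd + block_size * n_embd
    final = 2 * n_embd + (n_embd * vocab_size + vocab_size)  # ln_f + lm_head
    return per_block, n_layer * per_block + embeddings + final

per_block, total = count_params(vocab_size=81)  # 81 is illustrative only
print(f"{per_block:,} per block, {total:,} total")
```

Whatever the vocabulary, the 8 blocks account for about 14.2M parameters; the token embedding table and lm_head are the parts that grow with vocab_size.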

Weight Initialization

Before training begins, every parameter is initialized:

def _init_weights(self, module):
    if isinstance(module, nn.Linear):
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

Small random weights (standard deviation 0.02) prevent any single neuron from dominating early in training. Biases start at zero. This N(0, 0.02) scheme goes back to the original GPT paper, carried over into GPT-2, and has become standard practice for GPT-style models. It keeps the forward pass numerically stable before the optimizer has had a chance to find good values.

It’s a surprisingly important detail. Initialize too large and activations and gradients can blow up. Initialize too small and the signal dies out (vanishing gradients). 0.02 has held up well in practice for GPT-style models of this size. One of those things that seems arbitrary until you try changing it.

Or so I’ve read, lol.
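The self.apply(self._init_weights) call in __init__ works because nn.Module.apply calls the function on every submodule recursively, then on the module itself. A small sketch to see the effect, using a hypothetical stand-in model rather than the full GPTLanguageModel (nn.Embedding’s default init is N(0, 1), so the change is easy to spot):

```python
import torch
import torch.nn as nn

def _init_weights(module):
    # Same rules as the method above, written as a free function.
    if isinstance(module, nn.Linear):
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

torch.manual_seed(0)
# Hypothetical stand-in with nested modules, so apply() has to recurse.
toy = nn.Sequential(
    nn.Embedding(81, 384),
    nn.Sequential(nn.Linear(384, 1536), nn.ReLU(), nn.Linear(1536, 384)),
    nn.LayerNorm(384),  # matched by neither branch: keeps weight=1, bias=0
)
toy.apply(_init_weights)

emb, ffn = toy[0], toy[1]
print(f"embedding std {emb.weight.std().item():.4f}")  # close to 0.02
print(f"linear std    {ffn[0].weight.std().item():.4f}")  # close to 0.02
```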


Generation

At inference time, the model generates text autoregressively: one character at a time, feeding each prediction back in as input:

def generate(self, index, max_new_tokens):
    for _ in range(max_new_tokens):
        index_cond = index[:, -block_size:]
        logits, loss = self.forward(index_cond)
        logits = logits[:, -1, :]
        probs = F.softmax(logits, dim=-1)
        index_next = torch.multinomial(probs, num_samples=1)
        index = torch.cat((index, index_next), dim=1)
    return index

The index[:, -block_size:] crop is critical. The model’s positional embeddings only cover block_size positions (128). If the generated sequence grows beyond 128 characters, only the last 128 are passed through the model. This is a sliding window in which the model always sees the most recent context, forgetting anything further back.

torch.multinomial samples from the probability distribution rather than always taking the argmax. This randomness is what makes generation interesting. The same prompt can produce different completions. Temperature and top-k sampling are common modifications to control this randomness, but the base mechanism is pure multinomial sampling.
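Temperature and top-k would slot in between the logits = logits[:, -1, :] line and the softmax in generate; this project’s generate doesn’t include them. A sketch of one common way to write it, where the function name and defaults are assumptions: dividing by a temperature below 1 sharpens the distribution and above 1 flattens it, and top-k sets every logit outside the k largest to -inf so those characters get zero probability:

```python
import torch
import torch.nn.functional as F

def sample_next(logits, temperature=1.0, top_k=None):
    """Sample one token per row from (B, vocab_size) logits; temperature must be > 0."""
    logits = logits / temperature
    if top_k is not None:
        k = min(top_k, logits.size(-1))
        kth_largest = torch.topk(logits, k, dim=-1).values[:, [-1]]  # (B, 1)
        logits = logits.masked_fill(logits < kth_largest, float("-inf"))
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)  # (B, 1), like index_next

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
print(sample_next(logits, temperature=0.8, top_k=2))  # index 0 or 1, never 2 or 3
```

Inside generate, the softmax and multinomial lines would become index_next = sample_next(logits, temperature, top_k).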


Training Loop

The training loop follows the same pattern as the bigram model: forward pass, loss, backward pass, optimizer step:

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):
    if iter % eval_iters == 0:
        losses = estimate_loss()
        print(f"step: {iter}, train loss: {losses['train']:.4f}, val loss: {losses['val']:.4f}")

    xb, yb = get_batch('train')
    logits, loss = model.forward(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

Every 500 iterations, the model evaluates on both training and validation splits to track overfitting. The estimate_loss function averages over 500 batches with gradients disabled. This produces a stable loss estimate without the noise of individual batches.
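The estimate_loss function itself isn’t shown. A sketch of what such a function typically looks like; unlike the zero-argument version called in the loop, this one takes the model, a batch function and eval_iters as parameters so it can stand alone, which is an assumption about the signature:

```python
import torch

@torch.no_grad()  # no autograd graph: evaluation only
def estimate_loss(model, get_batch, eval_iters=500):
    """Average loss over eval_iters batches for each split."""
    out = {}
    model.eval()  # disable dropout while measuring
    for split in ("train", "val"):
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            _, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean().item()
    model.train()  # back to training mode (dropout on)
    return out
```

Switching to eval mode matters here because dropout of 0.2 would otherwise add noise to every evaluation batch.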

After training, the model is saved with pickle:

with open('model-01.pkl', 'wb') as f:
    pickle.dump(model, f)

The saved model can be loaded for inference without retraining — model.py handles this with a load_model function that reads the pickle and moves the model to the correct device.
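The load_model function isn’t shown either. A minimal sketch of what the description implies, with the default path and device choice as assumptions. The model.eval() call is an addition worth checking for: the model was trained with dropout 0.2, and without eval() dropout stays active during generation. Unpickling also requires the model’s class to be importable, one reason the PyTorch docs recommend saving model.state_dict() with torch.save instead:

```python
import pickle

import torch

def load_model(path="model-01.pkl", device=None):
    """Unpickle a trained model and prepare it for inference."""
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    with open(path, "rb") as f:
        model = pickle.load(f)  # the model's class must be importable here
    model = model.to(device)
    model.eval()  # turn off dropout for generation
    return model
```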


Inference Mode

model.py doubles as both the model definition and an interactive inference CLI:

if __name__ == "__main__":
    m = load_model()

    while True:
        prompt = input("Prompt:\n")
        context = torch.tensor(encode(prompt), dtype=torch.long, device=device)
        generated_chars = decode(m.generate(context.unsqueeze(0), max_new_tokens=150)[0].tolist())
        print(f'Completion:\n{generated_chars}')

Type a prompt, and the model generates 150 characters of continuation. The output isn’t eloquent English literature; this is a character-level model with a 128-character context window trained on a consumer laptop. But it learns real patterns: word boundaries, common letter combinations, sentence structure, and punctuation placement. It mumbles, but it mumbles in English.
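The CLI relies on encode and decode, the character-to-integer mappings carried over from the bigram model. A sketch of the usual way to build them, with a tiny made-up text standing in for the real vocabulary (the article doesn’t show how the OpenWebText vocabulary is built, and the dictionary names are hypothetical). One practical consequence: a prompt containing a character the model never saw raises a KeyError in encode:

```python
text = "hello world"  # hypothetical stand-in for the training vocabulary
chars = sorted(set(text))
string_to_int = {ch: i for i, ch in enumerate(chars)}
int_to_string = {i: ch for i, ch in enumerate(chars)}

def encode(s):
    return [string_to_int[c] for c in s]  # KeyError for unseen characters

def decode(ids):
    return "".join(int_to_string[i] for i in ids)

print(chars)            # [' ', 'd', 'e', 'h', 'l', 'o', 'r', 'w']
print(encode("hello"))  # [3, 2, 4, 4, 5]
```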


Thanks for reading this far; this is by far the most information-rich article I’ve written. Maybe run a training loop and read it over a few times, just in case.

This project is open source at micro-gpt.*