Bigram Language Model
Building a Language Model from Scratch with PyTorch
Most every CS student I know took some sort of data science course during their time as an undergrad. We’ve seen the linear regressions, k-means clusters, blah blah blah. I was pretty sure I knew how to fit a model.
What I didn’t understand, both due to recency and stubbornness, was language models. The elephant in the room was the new surge of large language models (LLMs) which we all know and love.
Baby steps.
Bigram Language Model is a character-level language model built in PyTorch that learns to generate text by predicting the next character in a sequence. Trained on the full text of The Wizard of Oz (yeah, I know), it starts from pure noise and gradually learns English-like patterns through nothing more than statistical prediction. This post walks through how it works, from raw text to generated output.
This was the first step for me in learning how these robots can speak. I really do recommend reading this article fully; you may (or may not) learn something.
The Core Idea: Next-Character Prediction
A language model does one thing: given some context, predict what comes next. A bigram model is the simplest version of this. It predicts the next character based only on the current one.
The entire model boils down to a lookup table. For every character in the vocabulary, there’s a row of scores (logits) representing how likely each other character is to follow. The model learns these scores from data.
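To make the lookup-table idea concrete, here is a minimal sketch in plain Python. Fair warning: it fills the table by counting which character follows which, a swapped-in technique that is not how the PyTorch model in this post learns its scores. The toy sentence and the names `follow_counts` and `most_likely_next` are my own placeholders.

```python
from collections import Counter, defaultdict

# Toy corpus, purely illustrative (not the real training text).
toy_text = "the wizard of oz and the witch of the west"

# follow_counts[a][b] = how many times character b directly follows character a.
follow_counts = defaultdict(Counter)
for current_char, next_char in zip(toy_text, toy_text[1:]):
    follow_counts[current_char][next_char] += 1

def most_likely_next(ch):
    """Return the character that most often follows `ch` in the toy corpus."""
    return follow_counts[ch].most_common(1)[0][0]

print(dict(follow_counts["t"]))   # every character that followed "t", with counts
print(most_likely_next("t"))      # the single most frequent follower of "t"
```

Each row of `follow_counts` plays the same role as a row of logits: it only ever answers the question "what tends to come after this one character?"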
Before training, the table is completely random. Essentially monkeys on a typewriter, but they do not output Shakespeare.
zr;:bfavYZaBUaT2"[dS& CJ*gt:T6CsuP
S1*g(5:SYShk0kAEdzrNZjp(&[*WT,5raVE
After 100,000 training steps, it produces something recognizably English-adjacent:
"Nowhestrsqueme f The walefa eca thinthe.
bearabie os, oroofryerethem ey he an oucoke
It has yet to cure cancer, but for a model that only looks at one character at a time, the fact that it picks up word boundaries, punctuation, and rough spelling is impressive.
Character-Level Tokenization
Production language models use sophisticated tokenizers like BPE that break text into subword chunks. This model is far simpler: every unique character is its own token.
chars = sorted(set(text))
vocab_size = len(chars) # 81 unique characters
string_to_int = { ch:i for i,ch in enumerate(chars) }
int_to_string = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [string_to_int[c] for c in s]
decode = lambda l: ''.join([int_to_string[i] for i in l])
The vocabulary is 81 characters — letters, digits, punctuation, and whitespace. encode converts a string to a list of integers. decode reverses it. The entire text becomes a single long tensor of integer indices:
data = torch.tensor(encode(text), dtype=torch.long)
This is the rawest possible representation of text for a neural network. No subwords, no special tokens, no padding. Just characters mapped to numbers.
For perspective, production models like GPT use tokenizers that read in chunks. “Unbelievable” might be three tokens: “un”, “believ”, “able.” Our model spells it out one letter at a time. It’s slower and less efficient, but the simplicity is the point.
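A quick way to convince yourself the mapping is lossless is a round trip. This is a small sketch on a single word rather than the whole book, so the vocabulary here has 6 characters instead of 81. The name `sample_text` is a placeholder of mine.

```python
# One word standing in for the book; the real vocabulary is built from the full text.
sample_text = "Dorothy"

chars = sorted(set(sample_text))
string_to_int = {ch: i for i, ch in enumerate(chars)}
int_to_string = {i: ch for i, ch in enumerate(chars)}

def encode(s):
    """Map each character to its integer id."""
    return [string_to_int[c] for c in s]

def decode(ids):
    """Map integer ids back to a string."""
    return "".join(int_to_string[i] for i in ids)

ids = encode("Dorothy")
print(chars)        # ['D', 'h', 'o', 'r', 't', 'y']
print(ids)          # [0, 2, 3, 2, 4, 1, 5]
print(decode(ids))  # Dorothy
```

One thing to keep in mind: a character that never appeared in the source text has no id, so `encode` raises a `KeyError` for it.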
Batches and Blocks
Training requires pairs of inputs and targets. The input is a sequence of characters; the target is the same sequence shifted by one position. Every character in the input maps to a “correct next character” in the target.
Think of it as a flashcard deck where the front is a partial sentence and the back is the next letter: “Dorot” -> “h”, “Doroth” -> “y”. The model sees thousands of these per second.
block_size = 8 # context length
batch_size = 4 # sequences per batch
A single training example with block_size=8 encodes eight separate predictions:
when input is tensor([1]) target is tensor(1)
when input is tensor([1, 1]) target is tensor(28)
when input is tensor([1, 1, 28]) target is tensor(39)
when input is tensor([1, 1, 28, 39]) target is tensor(42)
...
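A printout like the one above can be produced with a short loop. Here is a sketch using plain lists instead of tensors. I shortened `block_size` to 4 so it only needs the five ids visible in the printout, and `pairs` is a name I made up.

```python
block_size = 4  # shortened from 8 so only the ids shown in the printout are needed

# The first five token ids from the printout.
data = [1, 1, 28, 39, 42]

x = data[:block_size]         # inputs
y = data[1:block_size + 1]    # targets: the same sequence shifted by one

pairs = []
for t in range(block_size):
    context = x[:t + 1]       # everything up to and including position t
    target = y[t]             # the character that actually came next
    pairs.append((context, target))
    print(f"when input is {context} target is {target}")
```

So one block of length `block_size` yields `block_size` training examples, not one.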
The get_batch function samples random positions from the dataset and extracts these input-target pairs:
def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size, ))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y
The data is split 80/20 between training and validation. Random sampling with torch.randint means each batch is drawn from different parts of the text. There are no epochs in the usual sense, but over many iterations the model gets broad exposure to the whole training split.
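The split itself isn't shown above, so here is a sketch of how an 80/20 split and the random sampling could look with plain lists. I'm assuming the split is positional (first 80% for training, the rest for validation), and the fake `data` is just a list of consecutive integers so the shift is easy to see.

```python
import random

block_size = 8
batch_size = 4

# Stand-in for the encoded book: 1,000 fake token ids (0, 1, 2, ...).
data = list(range(1000))

# Assumption: a positional split, first 80% train and last 20% validation.
n = int(0.8 * len(data))
train_data, val_data = data[:n], data[n:]

def get_batch(split, rng=random):
    source = train_data if split == "train" else val_data
    # A valid start must leave room for block_size inputs plus one shifted target.
    starts = [rng.randrange(len(source) - block_size) for _ in range(batch_size)]
    x = [source[i:i + block_size] for i in starts]
    y = [source[i + 1:i + block_size + 1] for i in starts]
    return x, y

xb, yb = get_batch("train")
for inputs, targets in zip(xb, yb):
    print(inputs, "->", targets)
```

Because the fake ids are consecutive, every target row is visibly the input row shifted one step to the right.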
The Model
The entire model is a single embedding layer:
import torch.nn as nn
from torch.nn import functional as F

class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        # Row i holds the logits for "what follows character i".
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, index, targets=None):
        logits = self.token_embedding_table(index)  # shape (B, T, C)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            # cross_entropy expects (N, C) logits and (N,) targets.
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss
nn.Embedding(vocab_size, vocab_size) creates an 81x81 matrix. Each row is the model’s learned prediction for what follows a given character. Looking up character 28 (“D”) returns 81 logits — one score for every possible next character.
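If the embedding layer feels abstract, it helps to picture it as a plain list of lists. This sketch is not PyTorch, and the zeros are only there for readability (nn.Embedding starts from random values), but the shape and the lookup are the same idea.

```python
vocab_size = 81

# One row per current character, one column per candidate next character.
table = [[0.0] * vocab_size for _ in range(vocab_size)]

current_char_id = 28               # the id used as the example in the text
logits = table[current_char_id]    # the "forward pass" for one character is a row lookup

num_parameters = vocab_size * vocab_size
print(len(logits))       # 81
print(num_parameters)    # 6561
```

That is the whole model: 6,561 numbers, and training only ever nudges them up or down.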
The forward pass reshapes the batch and time dimensions into a flat list of predictions, then computes cross-entropy loss against the true next characters. Cross-entropy measures the gap between the model’s probability distribution and reality, and the optimizer’s job is to shrink that gap.
If you’re not a math person: cross-entropy is essentially the model’s confidence score in reverse. When the model says “the next character is definitely ‘e’” and it is ‘e’, the loss is low. When it confidently predicts ‘e’ and the answer is ‘z’, the loss is high. It’s the mathematical equivalent of being wrong and loud about it.
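For anyone who wants to see the numbers, here is a hand-rolled sketch of softmax plus cross-entropy for a single prediction. It is plain Python rather than `F.cross_entropy` (which also averages over all positions in the batch by default), and the three-character vocabulary and logit values are made up.

```python
import math

def softmax(logits):
    """Turn raw scores into probabilities that sum to 1."""
    m = max(logits)                              # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(logits, target):
    """Negative log of the probability assigned to the correct class."""
    return -math.log(softmax(logits)[target])

# Made-up vocabulary: index 0 = 'e', index 1 = 'z', index 2 = 'q'.
confident_logits = [5.0, 0.0, 0.0]               # the model is very sure the answer is 'e'

right = cross_entropy(confident_logits, 0)       # answer really was 'e'
wrong = cross_entropy(confident_logits, 1)       # answer was actually 'z'
uniform_81 = cross_entropy([0.0] * 81, 0)        # no preference among 81 characters

print(f"{right:.4f}")       # 0.0134
print(f"{wrong:.4f}")       # 5.0134
print(f"{uniform_81:.4f}")  # 4.3944
```

The last number is a useful yardstick: a model that spreads its bets evenly over 81 characters scores ln(81), about 4.39, so anything below that means it has learned something.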
Generation
Text generation is autoregressive. Produce one character, feed it back in, repeat:
def generate(self, index, max_new_tokens):
    for _ in range(max_new_tokens):
        logits, loss = self(index)
        logits = logits[:, -1, :] # last time step only
        probs = F.softmax(logits, dim=-1) # convert to probabilities
        index_next = torch.multinomial(probs, num_samples=1) # sample
        index = torch.cat((index, index_next), dim=1)
    return index
At each step, the model produces logits for the last position, converts them to a probability distribution with softmax, and samples from that distribution. Sampling (rather than always picking the highest-probability character) introduces randomness, which makes the output more varied and natural.
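The difference between sampling and always taking the top choice is easy to see on a tiny distribution. This sketch uses the standard library's `random.choices` as a stand-in for `torch.multinomial`. The three-letter vocabulary and its probabilities are invented for illustration.

```python
import random

vocab = ["a", "b", "c"]
probs = [0.6, 0.3, 0.1]   # made-up next-character probabilities

def greedy(probs):
    """Always pick the index with the highest probability."""
    return max(range(len(probs)), key=lambda i: probs[i])

def sample(probs, rng):
    """Pick an index at random, weighted by probability."""
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]

rng = random.Random(0)    # seeded so the run is repeatable
greedy_picks = [vocab[greedy(probs)] for _ in range(1000)]
sampled_picks = [vocab[sample(probs, rng)] for _ in range(1000)]

print(set(greedy_picks))                                  # only ever one character
print({ch: sampled_picks.count(ch) for ch in vocab})      # roughly follows probs
```

Greedy decoding would make the bigram model loop through the same few characters forever, which is why the generate function samples instead.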
Training
The training loop runs for 100,000 iterations using the AdamW optimizer:
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
for iter in range(max_iters):
    if iter % eval_iters == 0:
        losses = estimate_loss()
        print(f"step: {iter}, train loss: {losses['train']:.4f}, val loss: {losses['val']:.4f}")
    xb, yb = get_batch('train')
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
Each iteration: fetch a batch, compute predictions and loss, zero out old gradients, backpropagate, update weights. The AdamW optimizer combines adaptive learning rates with decoupled weight decay. It’s a common default for training neural networks and the same optimizer family used in transformer training.
100,000 iterations sounds like a lot. It is. But each iteration only processes a tiny 4x8 batch, so the model needs that many reps to see enough of The Wizard of Oz to learn anything useful. Imagine trying to learn a language by reading random 8-character snippets of a single book. You need a lot of snippets.
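To put a number on "a lot of snippets", here is the back-of-the-envelope arithmetic using the settings from this post.

```python
max_iters = 100_000
batch_size = 4
block_size = 8

# Every position in every sequence is one next-character prediction.
predictions_per_step = batch_size * block_size
total_predictions = max_iters * predictions_per_step

print(predictions_per_step)   # 32
print(total_predictions)      # 3200000
```

So the whole run amounts to about 3.2 million single-character predictions, 32 at a time.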
Loss Evaluation
Loss is estimated by averaging over 250 batches with gradients disabled:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out
The model.eval() / model.train() toggle matters for models with dropout or batch normalization. This bigram model doesn’t use either, but it’s good practice. Once the architecture becomes more complex this is essential.
Averaging over many batches produces a stable loss estimate. A single batch’s loss is noisy.
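Here is a small simulation of why the averaging helps. None of this touches the real model: I'm faking batch losses as a made-up "true" loss plus Gaussian noise, and `TRUE_LOSS`, `NOISE`, and `estimate` are all hypothetical names and values.

```python
import random
import statistics

rng = random.Random(0)    # seeded so the run is repeatable
TRUE_LOSS = 2.5           # made-up underlying loss
NOISE = 0.3               # made-up batch-to-batch noise

def noisy_batch_loss():
    """Pretend loss for one small batch: the true value plus random noise."""
    return rng.gauss(TRUE_LOSS, NOISE)

def estimate(num_batches):
    """Average the loss over num_batches simulated batches."""
    return statistics.fmean([noisy_batch_loss() for _ in range(num_batches)])

# Repeat each kind of estimate 200 times and compare how much they wobble.
single = [estimate(1) for _ in range(200)]
averaged = [estimate(250) for _ in range(200)]

spread_single = statistics.stdev(single)
spread_averaged = statistics.stdev(averaged)
print(f"single batch spread: {spread_single:.3f}")
print(f"250-batch spread:    {spread_averaged:.3f}")
```

With independent noise, the spread of the average shrinks roughly with the square root of the number of batches, so 250 batches gives an estimate around 16 times steadier than one.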
Hardware: Apple Silicon with MPS
CUDA is the standard for GPU-accelerated PyTorch, but it requires an NVIDIA GPU. On Apple Silicon Macs, the Metal Performance Shaders (MPS) backend provides GPU acceleration instead:
device = 'mps' if torch.backends.mps.is_available() else 'cpu'
This single line only picks the device name. The model and every batch still have to be moved there with .to(device), the way get_batch does for x and y. Once that is done, the training data, model weights, and intermediate computations all live on the same device.
This section is specific to my hardware. The rest of the post is general, but keep this in mind if you try to recreate the project on your own machine.
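If you're not on a Mac, the same idea extends to a simple priority order. This is a sketch of one reasonable choice, not what my notebook does: `pick_device` is a hypothetical helper, written to take plain booleans so it runs without PyTorch installed.

```python
def pick_device(cuda_available: bool, mps_available: bool) -> str:
    """Prefer an NVIDIA GPU, then Apple's MPS backend, then fall back to the CPU."""
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"

print(pick_device(cuda_available=False, mps_available=True))    # mps
print(pick_device(cuda_available=False, mps_available=False))   # cpu
```

In a real script the two flags would come from `torch.cuda.is_available()` and `torch.backends.mps.is_available()`.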
Bigram to Transformer
A bigram model has a fundamental limitation: it only considers the current character. It can learn that “q” is usually followed by “u”, but it can’t learn that “ing” tends to end words, because it never sees more than one character of context.
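The limitation is easy to demonstrate. In this sketch I again use a count-based table in plain Python (not the trained PyTorch model) with a made-up toy sentence. The point is that `bigram_prediction` throws away everything except the last character of whatever context it is given.

```python
from collections import Counter, defaultdict

# Toy sentence, chosen only because it contains a lot of "ing".
toy_text = "singing kings bring rings"

follow_counts = defaultdict(Counter)
for current_char, next_char in zip(toy_text, toy_text[1:]):
    follow_counts[current_char][next_char] += 1

def bigram_prediction(context):
    """Next-character counts for a context; only the final character is used."""
    return follow_counts[context[-1]]

# Very different histories, same last character, identical prediction.
print(bigram_prediction("kin") == bigram_prediction("ran"))   # True
print(bigram_prediction("kin") == bigram_prediction("n"))     # True
```

Whether the text so far is "kin" or "ran", the model only knows it is standing on an "n".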
This is exactly the limitation that attention mechanisms solve. A transformer’s self-attention lets every position attend to every previous position, capturing long-range dependencies. The progression from this notebook to a full GPT-style model involves adding:
| Component | What It Adds |
|---|---|
| Positional embeddings | Awareness of where tokens sit in the sequence |
| Self-attention | Ability to look at all previous tokens, not just the last one |
| Multi-head attention | Multiple parallel attention patterns |
| Feed-forward layers | Non-linear transformations between attention layers |
| Layer normalization | Training stability at depth |
| Residual connections | Gradient flow through deep networks |
The bigram model is no more than an entry point for learning. It is the simplest possible language model that actually learns from data.
Thanks for reading this far. I know this article was a bit more dense than usual, but I wanted to make sure I did the architecture justice.
This project is open source at bigram-language-model.