↑↓ or ←→ to navigate
TA walkthrough of Part 1 — building an autoregressive Transformer from scratch, training on TinyStories & MNIST.
Implement Mlp → Attention → Block → Trunk → GPT model. Each piece builds on the last.
class Mlp(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear.

    Args:
        in_features: Size of the last dimension of the input.
        hidden_features: Width of the hidden layer. Falls back to
            ``in_features`` when omitted (or otherwise falsy).
        out_features: Size of the last dimension of the output. Falls back
            to ``in_features`` when omitted (or otherwise falsy).
        bias: Whether the two linear layers carry a bias term.
    """

    def __init__(self, in_features: int,
                 hidden_features=None, out_features=None,
                 bias: bool = False):
        super().__init__()
        width = hidden_features if hidden_features else in_features
        n_out = out_features if out_features else in_features
        # Submodules are created in the order fc1, act, fc2 so that parameter
        # registration (and seeded initialisation) is unchanged.
        self.fc1 = nn.Linear(in_features, width, bias=bias)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(width, n_out, bias=bias)

    def forward(self, x):
        """Return ``fc2(GELU(fc1(x)))``; only the last axis of ``x`` changes."""
        return self.fc2(self.act(self.fc1(x)))
class Attention(nn.Module):
    """Multi-head self-attention with a fused QKV projection.

    The number of heads is derived as ``dim // head_dim``; ``dim`` must
    therefore be an exact multiple of ``head_dim`` so that the per-head width
    used for reshaping matches the width used for the ``1/sqrt(d_k)`` scale.

    Args:
        dim: Model (embedding) dimension of the input and output.
        head_dim: Width of each attention head.
        qkv_bias: Whether the fused QKV projection has a bias term.
        proj_bias: Whether the output projection has a bias term.

    Raises:
        ValueError: If ``dim`` or ``head_dim`` is not positive, or ``dim`` is
            not divisible by ``head_dim``.
    """

    def __init__(self, dim: int, head_dim: int = 64,
                 qkv_bias: bool = False, proj_bias: bool = False):
        super().__init__()
        if head_dim <= 0 or dim <= 0 or dim % head_dim != 0:
            raise ValueError(
                f"dim ({dim}) must be a positive multiple of "
                f"head_dim ({head_dim})"
            )
        self.head_dim = head_dim
        self.num_heads = dim // head_dim
        self.scale = head_dim ** -0.5
        # Fused QKV projection — one Linear, 3× output dim
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_out_proj = nn.Linear(dim, dim, bias=proj_bias)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(B, L, D)`` with ``D == dim``.
            mask: Optional boolean tensor of shape ``(B, L, L)`` or
                ``(1, L, L)``; ``True`` marks positions that may be attended
                to, ``False`` positions are masked out. It is broadcast over
                the head axis.

        Returns:
            Tensor of shape ``(B, L, D)``.
        """
        B, L, D = x.shape
        # Fused QKV → reshape → (3, B, H, L, d_k)
        qkv = self.qkv(x).reshape(B, L, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)  # each: (B, H, L, d_k)
        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, H, L, L)
        if mask is not None:
            # (b, n, m) -> (b, 1, n, m) so the mask broadcasts across heads.
            # Masked scores get the most negative finite value rather than
            # -inf, so a fully masked row still yields a finite softmax.
            attn = attn.masked_fill(~mask.unsqueeze(1),
                                    -torch.finfo(attn.dtype).max)
        attn = attn.softmax(dim=-1)
        # (B, H, L, d_k) -> (B, L, H, d_k) -> (B, L, D): concatenate heads.
        x = (attn @ v).transpose(1, 2).reshape(B, L, D)
        return self.attn_out_proj(x)
class Block(nn.Module):
    """Pre-norm Transformer block: attention and MLP, each with a residual.

    ``LayerNorm``, ``Attention`` and ``Mlp`` are resolved from the enclosing
    module; ``LayerNorm`` is not defined in the visible part of this file
    (presumably a project-level wrapper — confirm its defaults there).

    Args:
        dim: Model dimension.
        head_dim: Per-head width forwarded to ``Attention``.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``
            (truncated to ``int``).
        use_bias: Forwarded as the bias flag of the attention projections
            and of the MLP linear layers.
    """

    def __init__(self, dim: int, head_dim: int = 64,
                 mlp_ratio: float = 4.0, use_bias: bool = False):
        super().__init__()
        # Registration order (norm1, attn, norm2, mlp) is kept so parameter
        # ordering and seeded initialisation do not change.
        self.norm1 = LayerNorm(dim)
        self.attn = Attention(
            dim,
            head_dim=head_dim,
            qkv_bias=use_bias,
            proj_bias=use_bias,
        )
        self.norm2 = LayerNorm(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            bias=use_bias,
        )

    def forward(self, x, mask=None):
        """Return ``x`` after the attention and MLP residual updates.

        ``mask`` is passed through unchanged to the attention layer.
        """
        # Xₐ = X + Attn(LN₁(X))
        attn_update = self.attn(self.norm1(x), mask)
        x = x + attn_update
        # Xᵦ = Xₐ + MLP(LN₂(Xₐ))
        mlp_update = self.mlp(self.norm2(x))
        return x + mlp_update
class TransformerTrunk(nn.Module):
    """A stack of ``depth`` identically configured ``Block`` layers.

    Args:
        dim: Model dimension of every block.
        depth: Number of blocks in the stack.
        head_dim: Per-head width forwarded to each block.
        mlp_ratio: MLP expansion ratio forwarded to each block.
        use_bias: Bias flag forwarded to each block.
    """

    def __init__(self, dim: int = 512, depth: int = 8,
                 head_dim: int = 64, mlp_ratio: float = 4.0,
                 use_bias: bool = False):
        super().__init__()
        block_kwargs = dict(
            dim=dim,
            head_dim=head_dim,
            mlp_ratio=mlp_ratio,
            use_bias=use_bias,
        )
        self.blocks = nn.ModuleList(
            [Block(**block_kwargs) for _ in range(depth)]
        )

    def forward(self, x, mask=None):
        """Run ``x`` through every block in order, sharing the same ``mask``."""
        for layer in self.blocks:
            x = layer(x, mask=mask)
        return x
class GPT(nn.Module):
    """Autoregressive Transformer language model over token-id sequences.

    Token embeddings plus learned positional embeddings are fed through a
    ``TransformerTrunk`` under a causal mask, normalised, and projected to
    vocabulary logits.

    ``TransformerTrunk``, ``LayerNorm``, ``sample_tokens`` and the
    ``initialize_weights`` method called at the end of ``__init__`` are not
    defined in this block; they are expected to be provided elsewhere.

    Args:
        seq_read_key: Key under which ``forward`` reads the token tensor from
            its ``data_dict`` argument.
        dim: Model dimension.
        depth: Number of Transformer blocks in the trunk.
        head_dim: Per-head attention width.
        mlp_ratio: MLP expansion ratio inside each block.
        use_bias: Bias flag forwarded to the trunk and the output norm.
        vocab_size: Number of token ids (embedding rows / logit columns).
        max_seq_len: Number of learned positions; the longest input
            ``forward_model`` accepts and the length cap for ``generate``.
        padding_idx: Target value ignored by the cross-entropy loss.
        init_std: Stored on the instance; presumably consumed by
            ``initialize_weights`` (not visible here — confirm).
    """

    def __init__(self, seq_read_key: str = 'input_ids',
                 dim: int = 512, depth: int = 8, head_dim: int = 64,
                 mlp_ratio: float = 4.0, use_bias: bool = False,
                 vocab_size: int = 10000, max_seq_len: int = 256,
                 padding_idx: int = -100, init_std: float = 0.02):
        super().__init__()
        self.seq_read_key = seq_read_key
        self.padding_idx = padding_idx
        self.max_seq_len = max_seq_len
        self.init_std = init_std
        self.input_embedding = nn.Embedding(vocab_size, dim)
        self.positional_embedding = nn.Parameter(
            torch.randn(max_seq_len, dim))
        self.trunk = TransformerTrunk(dim=dim, depth=depth,
                                      head_dim=head_dim, mlp_ratio=mlp_ratio,
                                      use_bias=use_bias)
        self.out_norm = LayerNorm(dim, bias=use_bias)
        self.to_logits = nn.Linear(dim, vocab_size, bias=False)
        self.initialize_weights()

    def forward_model(self, x):
        """Map token ids of shape ``(B, L)`` to logits ``(B, L, vocab_size)``.

        Raises:
            ValueError: If ``L`` exceeds ``max_seq_len`` (there would be no
                positional embedding for the extra positions).
        """
        _, seq_len = x.shape
        if seq_len > self.max_seq_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len "
                f"{self.max_seq_len}"
            )
        # Input + positional embedding
        x = self.input_embedding(x) + self.positional_embedding[:seq_len]
        # Causal mask: position i may attend to j only when j <= i.
        # True = attend, False = masked.
        positions = torch.arange(seq_len, device=x.device)
        causal_mask = positions[:, None] >= positions[None, :]
        causal_mask = causal_mask.unsqueeze(0)  # (1, L, L)
        x = self.trunk(x, mask=causal_mask)
        return self.to_logits(self.out_norm(x))  # (B, L, vocab_size)

    def compute_ce_loss(self, logits, target_seq, padding_idx=-100):
        """Return the mean cross-entropy between ``logits`` and ``target_seq``.

        ``logits`` is flattened to ``(-1, vocab_size)`` and ``target_seq`` to
        ``(-1,)``; targets equal to ``padding_idx`` are ignored.
        """
        vocab_size = logits.shape[-1]
        return F.cross_entropy(
            logits.reshape(-1, vocab_size),
            target_seq.reshape(-1),
            ignore_index=padding_idx,
        )

    def forward(self, data_dict):
        """Compute the next-token loss for one batch.

        Reads ``data_dict[self.seq_read_key]`` as a ``(B, L+1)`` token tensor
        and returns ``(loss, {'ppl': exp(loss)})``.
        """
        seq = data_dict[self.seq_read_key]  # (B, L+1)
        input_seq = seq[:, :-1]   # [SOS, t1, …, tL] — drop last
        target_seq = seq[:, 1:]   # [t1, …, EOS] — drop first (shift!)
        logits = self.forward_model(input_seq)
        loss = self.compute_ce_loss(logits, target_seq, self.padding_idx)
        return loss, {'ppl': torch.exp(loss)}

    @torch.no_grad()
    def generate(self, context=None, eos_idx=None,
                 temp=1.0, top_p=0.0, top_k=0.0):
        """Autoregressively extend ``context`` one token at a time.

        Args:
            context: Sequence of starting token ids; defaults to ``[0]``.
            eos_idx: Token id that stops generation once sampled; ``None``
                generates until ``max_seq_len`` tokens are reached.
            temp, top_p, top_k: Forwarded to ``sample_tokens`` as
                ``temperature``, ``top_p`` and ``top_k``.

        Returns:
            Long tensor of shape ``(1, T)`` holding the context followed by
            the sampled tokens (including the EOS token if one was sampled).

        Raises:
            ValueError: If ``context`` is empty.
        """
        # Avoid a mutable default argument shared between calls.
        if context is None:
            context = [0]
        if len(context) == 0:
            raise ValueError("context must contain at least one token")

        was_training = self.training
        self.eval()
        try:
            # nn.Module has no ``device`` attribute; use a parameter's device.
            device = self.positional_embedding.device
            # Initialize the sequence with the provided context tokens
            current_tokens = torch.tensor(
                [context], dtype=torch.long, device=device)
            for _ in range(self.max_seq_len - len(context)):
                logits = self.forward_model(current_tokens)
                # Logits for the last position, reshaped to (1, vocab_size)
                next_logits = logits[0, -1, :].unsqueeze(0)
                next_token, _ = sample_tokens(
                    next_logits, temperature=temp, top_k=top_k, top_p=top_p)
                current_tokens = torch.cat(
                    [current_tokens, next_token.unsqueeze(1)], dim=1)
                if eos_idx is not None and next_token.item() == eos_idx:
                    break
        finally:
            # Restore training mode even if sampling raised.
            if was_training:
                self.train()
        return current_tokens
Subword tokenizer (BPE) — splits text into subword units. Base vocabulary of 50,257 tokens, extended with 3 special tokens ([PAD] = 50257, [SOS] = 50258, [EOS] = 50259):
# "This is an example."
[50258, 1212, 318, 281, 1672, 13, 50259,
50257, 50257, 50257]
# [SOS] This is an example . [EOS] [PAD]×3
# "Once upon a time..."
[50258, 7454, 2402, 257, 640, 612,
373, 257, 2068, 50259]
# [SOS] Once upon ... quick [EOS]
W&B: eval/loss converges to ~1.3, perplexity ~3.7
# sample_tokens applies temp like:
# (illustrative excerpt — `logits`, `temp`, `F` and `torch` come from the surrounding context)
logits = logits / temp  # scale logits by 1/temp before normalising; temp must be non-zero
probs = F.softmax(logits, dim=-1)  # softmax over the last axis (presumably the vocabulary axis — confirm)
token = torch.multinomial(probs, 1)  # draw one index per row from the resulting distribution
Same GPT model, different data. Tokenize MNIST images → train autoregressive model → class-conditional image generation.
Provide class label as first token → model generates remaining 49 patch tokens autoregressively
Target: val loss < 0.45
DALL-E 1 (2021) showed this works at scale. The key insight: treat text and image tokens as one unified sequence.
"There are two trains on parallel tracks." — AI generated