
01 / 27
nanoGPT
Solutions

TA walkthrough of Part 1 — building an autoregressive Transformer from scratch, training on TinyStories & MNIST.

Part 1 of 3 TA Session EPFL Spring 2026
25.3M parameters · TinyStories · MNIST
02 / 27

What We Cover Today

Part 1 TinyStories — Text Generation
  • 2.1.1 MLP Layer (10 pts)
  • 2.1.2 Masked Self-Attention (20 pts)
  • 2.1.3 Transformer Block (10 pts)
  • 2.1.4 Transformer Trunk (10 pts)
  • 2.1.5 GPT init, forward & loss (20 pts)
  • 2.1.6 Generation Loop (20 pts)
  • 2.3 Loss curves (10 pts)
  • 2.4 Evaluation (10 pts)
  • 2.5 Open-ended Q2.1–Q2.3 (15 pts)
Part 2 MNIST — Image Generation
  • 3.1–3.2 MNIST Tokenization & Training
  • 3.3 Loss curves (10 pts)
  • 3.4 Evaluation (10 pts)
  • 3.5 Open-ended Q3.1–Q3.3 (15 pts)
Part 1: 125 pts  ·  Part 2: 35 pts  ·  Total: 160 pts
03 / 27

Codebase Structure

nano4M/
  cfgs/nanoGPT/  # training configs (yaml)
  nanofm/
    modeling/
      transformer_layers.py  # ← Mlp, Attention, Block, Trunk
    models/
      gpt.py  # ← GPT model, forward, loss, generate
    data/
      text/huggingface_datasets.py
      vision/tokenized_mnist.py
    utils/
      sampling.py  # sample_tokens helper
      checkpoint.py
  run_training.py  # main training loop
  setup_env.sh  # creates nanofm conda env
01
04 / 27
Building
nanoGPT

Implement Mlp → Attention → Block → Trunk → GPT model. Each piece builds on the last.

transformer_layers.py gpt.py
05 / 27 10 pts

Mlp

nanofm/modeling/transformer_layers.py
import torch.nn as nn

class Mlp(nn.Module):
    def __init__(self, in_features: int,
            hidden_features=None, out_features=None,
            bias: bool = False):
        super().__init__()
        out_features    = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x
06 / 27 20 pts

Attention — Init & QKV

import torch
import torch.nn as nn
from einops import rearrange

class Attention(nn.Module):
    def __init__(self, dim: int, head_dim: int = 64,
            qkv_bias: bool = False, proj_bias: bool = False):
        super().__init__()
        self.num_heads = dim // head_dim
        self.scale     = head_dim ** -0.5

        # Fused QKV projection — one Linear, 3× output dim
        self.qkv          = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_out_proj = nn.Linear(dim, dim, bias=proj_bias)
07 / 27 20 pts

Attention — Forward & Masking

    def forward(self, x, mask=None):
        B, L, D = x.shape

        # Fused QKV → reshape → (3, B, H, L, d_k)
        qkv = self.qkv(x).reshape(B, L, 3, self.num_heads,
                             D // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)   # each: (B, H, L, d_k)

        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, H, L, L)
        if mask is not None:
            mask = rearrange(mask, 'b n m -> b 1 n m')
            attn = attn.masked_fill(~mask,
                           -torch.finfo(attn.dtype).max)
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B, L, D)
        return self.attn_out_proj(x)
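As a sanity check, a manual masked attention like the one above can be compared against PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention` (PyTorch 2.0 or later) with `is_causal=True`. The sketch below re-implements only the core computation as a standalone function on random Q, K, V; `manual_causal_attention` is a hypothetical helper, not part of the codebase.

```python
import torch
import torch.nn.functional as F


def manual_causal_attention(q, k, v):
    """Masked self-attention on (B, H, L, d_k) tensors, mirroring the slide's forward."""
    L, d_k = q.shape[-2], q.shape[-1]
    attn = (q @ k.transpose(-2, -1)) * d_k ** -0.5         # (B, H, L, L) scaled scores
    mask = torch.tril(torch.ones(L, L)).bool()             # True = attend
    attn = attn.masked_fill(~mask, -torch.finfo(attn.dtype).max)
    attn = attn.softmax(dim=-1)                            # rows sum to 1
    return attn @ v                                        # (B, H, L, d_k)


torch.manual_seed(0)
B, H, L, d_k = 2, 4, 8, 16
q, k, v = (torch.randn(B, H, L, d_k) for _ in range(3))

out_manual = manual_causal_attention(q, k, v)
out_ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(out_manual, out_ref, atol=1e-4))  # True
```

Because the first token can only attend to itself, `out_manual[:, :, 0]` equals `v[:, :, 0]`: a quick way to confirm that no future information leaks in.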
08 / 27

Attention — QKV Tensor Shapes

① QKV decomposition
  • x (input): (B, L, D)
  • qkv(x), Linear D → 3D: (B, L, 3D), Q, K and V fused in one dim
  • reshape, split 3D → 3 × H × d_k: (B, L, 3, H, d_k)
  • permute(2, 0, 3, 1, 4), move the "3" axis to the front: (3, B, H, L, d_k) ← key step
  • unbind(0): Q, K, V as 3 separate tensors, each (B, H, L, d_k)
② Attention computation
  • Q @ Kᵀ · scale: (B, H, L, L) raw attention scores, every token vs. every token
  • causal mask (−∞) + softmax: (B, H, L, L) attention weights; rows sum to 1, future tokens get 0
  • @ V, weighted sum: (B, H, L, d_k), one vector per head, per token, per batch element
  • transpose(1, 2) → (B, L, H, d_k), reshape → (B, L, D): heads merged back, same shape as input
  • → attn_out_proj
Dims: B batch · L sequence length · D model dim · 3 Q/K/V · H heads · d_k head_dim; 3D = 3 × H × d_k
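The fused projection followed by reshape/permute is just three ordinary projections stored in one weight matrix. The sketch below slices the fused weight into thirds and checks that each `unbind` output matches a separate `x @ W.T` projection split into heads. The small dimensions are arbitrary values chosen for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, L, D, head_dim = 2, 5, 32, 8
H = D // head_dim                      # 4 heads

qkv = nn.Linear(D, 3 * D, bias=False)  # fused projection, as in Attention.__init__
x = torch.randn(B, L, D)

# Fused path: (B, L, 3D) -> (B, L, 3, H, d_k) -> (3, B, H, L, d_k) -> Q, K, V
fused = qkv(x).reshape(B, L, 3, H, head_dim).permute(2, 0, 3, 1, 4)
q, k, v = fused.unbind(0)

# Separate path: weight rows [0:D] give Q, [D:2D] give K, [2D:3D] give V
W_q, W_k, W_v = qkv.weight.split(D, dim=0)


def split_heads(t):
    # (B, L, D) -> (B, L, H, d_k) -> (B, H, L, d_k)
    return t.reshape(B, L, H, head_dim).transpose(1, 2)


q_ref, k_ref, v_ref = split_heads(x @ W_q.T), split_heads(x @ W_k.T), split_heads(x @ W_v.T)

print(q.shape)  # torch.Size([2, 4, 5, 8])
print(all(torch.allclose(a, b, atol=1e-5) for a, b in [(q, q_ref), (k, k_ref), (v, v_ref)]))  # True
```

Fusing the three projections into one `nn.Linear` launches a single larger matmul instead of three smaller ones, which is usually faster on GPUs.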
09 / 27 10 pts

Block — Pre-Norm + Residuals

class Block(nn.Module):
    def __init__(self, dim: int, head_dim: int = 64,
            mlp_ratio: float = 4.0, use_bias: bool = False):
        super().__init__()
        self.norm1 = LayerNorm(dim, bias=use_bias)
        self.attn  = Attention(dim, head_dim=head_dim,
                       qkv_bias=use_bias, proj_bias=use_bias)
        self.norm2 = LayerNorm(dim, bias=use_bias)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp   = Mlp(in_features=dim,
                      hidden_features=mlp_hidden_dim, bias=use_bias)

    def forward(self, x, mask=None):
        x = x + self.attn(self.norm1(x), mask)  # Xₐ = X + Attn(LN₁(X))
        x = x + self.mlp(self.norm2(x))         # Xᵦ = Xₐ + MLP(LN₂(Xₐ))
        return x
10 / 27 10 pts

TransformerTrunk

class TransformerTrunk(nn.Module):
    def __init__(self, dim: int = 512, depth: int = 8,
            head_dim: int = 64, mlp_ratio: float = 4.0,
            use_bias: bool = False):
        super().__init__()
        self.blocks = nn.ModuleList([
            Block(dim=dim, head_dim=head_dim,
                  mlp_ratio=mlp_ratio, use_bias=use_bias)
            for _ in range(depth)
        ])

    def forward(self, x, mask=None):
        for block in self.blocks:
            x = block(x, mask=mask)
        return x
11 / 27 20 pts

GPT — __init__

class GPT(nn.Module):
    def __init__(self, seq_read_key: str = 'input_ids',
            dim: int = 512, depth: int = 8, head_dim: int = 64,
            mlp_ratio: float = 4.0, use_bias: bool = False,
            vocab_size: int = 10000, max_seq_len: int = 256,
            padding_idx: int = -100, init_std: float = 0.02):
        super().__init__()
        self.seq_read_key = seq_read_key
        self.padding_idx  = padding_idx
        self.max_seq_len  = max_seq_len
        self.init_std     = init_std

        self.input_embedding      = nn.Embedding(vocab_size, dim)
        self.positional_embedding = nn.Parameter(
            torch.randn(max_seq_len, dim))
        self.trunk = TransformerTrunk(dim=dim, depth=depth,
                      head_dim=head_dim, mlp_ratio=mlp_ratio,
                      use_bias=use_bias)
        self.out_norm  = LayerNorm(dim, bias=use_bias)
        self.to_logits = nn.Linear(dim, vocab_size, bias=False)
        self.initialize_weights()
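A worked parameter count helps interpret the "25.3M parameters" quoted on the title slide. The sketch below counts the weights created by `__init__` for the d8w512 config (dim=512, depth=8, mlp_ratio=4, max_seq_len=256), assuming no biases anywhere, LayerNorms that carry only a weight vector, and a vocabulary of 50,260 (GPT-2's 50,257 plus 3 special tokens). Under these assumptions the trunk, norms and positional embedding add up to about 25.3M, which suggests the quoted figure excludes the token embedding and the output projection; that reading is our interpretation, not something the handout states. `gpt_param_count` is a hypothetical helper.

```python
def gpt_param_count(dim=512, depth=8, mlp_ratio=4.0, vocab_size=50260, max_seq_len=256):
    """Analytic parameter count for the GPT above (assumes no biases, weight-only LayerNorms)."""
    hidden = int(dim * mlp_ratio)
    attn = dim * 3 * dim + dim * dim      # fused qkv + attn_out_proj
    mlp = dim * hidden + hidden * dim     # fc1 + fc2
    norms = 2 * dim                       # norm1 + norm2 weights
    trunk = depth * (attn + mlp + norms)
    pos_emb = max_seq_len * dim           # positional_embedding
    out_norm = dim
    tok_emb = vocab_size * dim            # input_embedding
    to_logits = dim * vocab_size          # output projection (not tied to tok_emb)
    non_vocab = trunk + pos_emb + out_norm
    return non_vocab, non_vocab + tok_emb + to_logits


non_vocab, total = gpt_param_count()
print(f"{non_vocab:,}")  # 25,305,600
print(f"{total:,}")      # 76,771,840
```

Each block contributes 12·dim² weights (3·dim² for QKV, dim² for the output projection, 8·dim² for the MLP), so the trunk term dominates as soon as the model gets deeper or wider.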
12 / 27 20 pts

GPT — forward_model

    def forward_model(self, x):
        B, L = x.size()

        # Input + positional embedding
        x = self.input_embedding(x) + self.positional_embedding[:L]

        # Causal mask: tril(-inf matrix).bool() → True=attend, False=masked
        causal_mask = torch.tril(torch.full(
            (L, L), float('-inf'), device=x.device), diagonal=0).bool()
        causal_mask = causal_mask.unsqueeze(0)  # (1, L, L)

        x = self.trunk(x, mask=causal_mask)
        return self.to_logits(self.out_norm(x))  # (B, L, vocab_size)
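The `tril`-over-−∞ construction works because `tril` keeps the −∞ entries on and below the diagonal and zeros everything above it, and casting to `bool` maps any non-zero value (including −∞) to True. The sketch below checks that it gives the same mask as the more common `torch.tril(torch.ones(L, L)).bool()` form for a small L.

```python
import torch

L = 4
mask_inf = torch.tril(torch.full((L, L), float('-inf')), diagonal=0).bool()
mask_ones = torch.tril(torch.ones(L, L)).bool()

print(torch.equal(mask_inf, mask_ones))  # True
print(mask_ones.int())                   # row i is 1 for columns 0..i, 0 afterwards
```

Either form works; the ones-based version avoids relying on the −∞ → True cast.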
13 / 27 20 pts

GPT — compute_ce_loss & forward

    def compute_ce_loss(self, logits, target_seq, padding_idx=-100):
        B, L, vocab_size = logits.shape
        return F.cross_entropy(
            logits.reshape(-1, vocab_size),
            target_seq.reshape(-1),
            ignore_index=padding_idx
        )

    def forward(self, data_dict):
        seq = data_dict[self.seq_read_key]  # (B, L+1)
        input_seq  = seq[:, :-1]           # [SOS, t1, …]: drop last token → (B, L)
        target_seq = seq[:, 1:]            # [t1, …, EOS]: drop first token (shift by one) → (B, L)
        logits = self.forward_model(input_seq)
        loss   = self.compute_ce_loss(logits, target_seq, self.padding_idx)
        return loss, {'ppl': torch.exp(loss)}
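The shift and the padding logic can be checked on a single padded example. The sketch below reuses the token ids of the tokenizer example ("This is an example." with [SOS]=50258, [EOS]=50259 and three [PAD]=50257), assumes `padding_idx` is set to the [PAD] id, and uses random logits in place of a trained model. Positions whose target is [PAD] contribute nothing to the loss.

```python
import torch
import torch.nn.functional as F

PAD, VOCAB = 50257, 50260
seq = torch.tensor([[50258, 1212, 318, 281, 1672, 13, 50259, PAD, PAD, PAD]])  # (B=1, L+1=10)

input_seq = seq[:, :-1]   # [SOS, This, is, an, example, ., EOS, PAD, PAD]
target_seq = seq[:, 1:]   # [This, is, an, example, ., EOS, PAD, PAD, PAD]

torch.manual_seed(0)
logits = torch.randn(1, input_seq.shape[1], VOCAB)  # stand-in for forward_model(input_seq)

loss = F.cross_entropy(logits.reshape(-1, VOCAB), target_seq.reshape(-1), ignore_index=PAD)

# Same value by hand: average the per-token loss over the non-PAD targets only
valid = target_seq.reshape(-1) != PAD
per_tok = F.cross_entropy(logits.reshape(-1, VOCAB), target_seq.reshape(-1), reduction='none')
print(valid.sum().item())                           # 6
print(torch.allclose(loss, per_tok[valid].mean()))  # True
print(torch.exp(loss))                              # perplexity, as returned by forward()
```

If `padding_idx` were left at -100 while the data is padded with 50257, the model would also be trained to predict [PAD] after [EOS], which inflates the reported loss.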
14 / 27 20 pts

GPT — generate

    @torch.no_grad()
    def generate(self, context=[0], eos_idx=None,
            temp=1.0, top_p=0.0, top_k=0):
        was_training = self.training
        self.eval()

        # Initialize the sequence with the context tokens (e.g. [SOS] or a class label), batch size 1
        current_tokens = torch.tensor(
            [context], dtype=torch.long, device=self.device)

        for _ in range(self.max_seq_len - len(context)):
            logits = self.forward_model(current_tokens)
            # Get logits for the last token and reshape to (1, vocab_size)
            next_logits = logits[0, -1, :].unsqueeze(0)
            next_token, _ = sample_tokens(next_logits,
                temperature=temp, top_k=top_k, top_p=top_p)
            current_tokens = torch.cat(
                [current_tokens, next_token.unsqueeze(1)], dim=1)
            if next_token.item() == eos_idx:
                break

        if was_training:
            self.train()

        return current_tokens
15 / 27

GPT-2 Tokenizer + Special Tokens

Subword tokenizer (BPE): splits text into subword units. Base vocabulary of 50,257 tokens, extended with 3 special tokens (50,260 in total):

  • [PAD] id=50257 — pads sequences to equal length; loss is ignored on these
  • [SOS] id=50258 — start of sequence; used as unconditional generation seed
  • [EOS] id=50259 — end of sequence; signals the model is done
TemplateProcessing automatically wraps every sequence with [SOS]…[EOS] at encode time.
# "This is an example."
[50258, 1212, 318, 281, 1672, 13, 50259,
 50257, 50257, 50257]
# [SOS] This is an example [EOS] [PAD]×3

# "Once upon a time..."
[50258, 7454, 2402, 257, 640, 612,
 373, 257, 2068, 50259]
# [SOS] Once upon ... quick [EOS]
Truncated if longer than max_length. No [EOS] if cut off.
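The encode-time behavior described above (wrap with [SOS] and [EOS], truncate to max_length, right-pad with [PAD]) can be mimicked in plain Python. The helper `wrap_pad_truncate` below is hypothetical, works on already-tokenized ids, and follows this slide's rule that a truncated sequence loses its [EOS]; the real pipeline uses the tokenizer's TemplateProcessing and may handle truncation differently.

```python
SOS, EOS, PAD = 50258, 50259, 50257


def wrap_pad_truncate(ids, max_length):
    """Hypothetical helper: [SOS] + ids + [EOS], cut to max_length, then right-pad with [PAD]."""
    seq = [SOS] + list(ids) + [EOS]
    seq = seq[:max_length]                      # if cut off, [EOS] is the first token lost
    return seq + [PAD] * (max_length - len(seq))


print(wrap_pad_truncate([1212, 318, 281, 1672, 13], max_length=10))
# [50258, 1212, 318, 281, 1672, 13, 50259, 50257, 50257, 50257]
print(wrap_pad_truncate(range(100, 120), max_length=10))
# [50258, 100, 101, 102, 103, 104, 105, 106, 107, 108]
```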
16 / 27

Training on TinyStories

nanoGPT TinyStories loss curves

W&B: eval/loss converges to ~1.3, perplexity ~3.7

  • Config: tinystories_d8w512.yaml
  • Hardware: 2× V100, ~1 hour
  • Target: val loss ≈ 1.3
  • Params: 25.3M (d=8, w=512)
  • Optimizer: AdamW + cosine LR schedule
  • Checkpoints: auto-saved & auto-resumed
Tip: Debug on MNIST first (section 3); it uses the same code and trains in minutes.
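For reference, AdamW with a linear-warmup-then-cosine learning-rate schedule takes only a few lines of PyTorch. The peak LR, betas, weight decay, warmup length and minimum-LR ratio below are placeholder values, not the ones in tinystories_d8w512.yaml, and run_training.py may build its schedule differently; `cosine_with_warmup` is a hypothetical helper.

```python
import math

import torch
import torch.nn as nn


def cosine_with_warmup(step, warmup_steps, total_steps, min_ratio=0.1):
    """LR multiplier: linear warmup up to 1.0, then cosine decay down to min_ratio."""
    if step < warmup_steps:
        return (step + 1) / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return min_ratio + (1.0 - min_ratio) * 0.5 * (1.0 + math.cos(math.pi * progress))


model = nn.Linear(8, 8)  # stand-in for the GPT model
optimizer = torch.optim.AdamW(model.parameters(), lr=6e-4, betas=(0.9, 0.95), weight_decay=0.05)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda step: cosine_with_warmup(step, warmup_steps=100, total_steps=1000))

lrs = []
for step in range(1000):
    optimizer.step()     # in real training, loss.backward() comes first
    scheduler.step()
    lrs.append(optimizer.param_groups[0]['lr'])
```

The LR ramps up to 6e-4 over the first 100 steps, then decays along a half cosine to 10% of the peak by the last step.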
17 / 27

What Good Generations Look Like

Once upon a time, there was a little girl named Lily. She had a bed that was very cold and dark. Lily's mommy said, "Lily, your bed is too bad. You need to go to sleep, but I am scared to come back outside."

Lily said, "I regret going outside and I can't sleep. I am lost and I miss my mommy."

Mommy hugged Lily and said, "Don't worry, we will find you soon. And maybe we can read happy stories together."

Lily felt happy and safe in her bed. She said, "Thank you, mommy. You promised me we will always be together." And they both smiled and snuggled together, feeling safe by each other's side until it was time for bed.
Daisy was hungry, so she went to the kitchen. She saw a big bowl of warm soup on the table. Her mommy said, "I made it just for you!"

Daisy sat down and ate every last drop. It was so yummy and warm. When she was done, she gave her mommy a big hug and said, "Thank you, mommy. You always make the best food."

Her mommy smiled and hugged her back. "I love you, Daisy," she said. Daisy felt happy and full. She went back to play and did not feel hungry anymore. It was a very good day.
  • Grammatically coherent sentences
  • Stories have a beginning and end
  • Conditional context is followed through
  • No degenerate repetition loops
10 pts: Show loss curves screenshot in section 2.3. Target: val loss ≈ 1.3
18 / 27

Q2.1 — Effect of Temperature

Q: What effect does the temperature have on the generations?
Answer
Temperature τ divides the logits before softmax: logits / τ. This scales the sharpness of the probability distribution.
  • τ → 0: argmax (always pick most likely token) — deterministic, repetitive
  • τ = 1: standard sampling — default, balanced
  • τ → ∞: uniform distribution — purely random, incoherent
# sample_tokens applies temp like:
logits = logits / temp
probs  = F.softmax(logits, dim=-1)
token  = torch.multinomial(probs, 1)
Good range: τ ∈ [0.7, 1.0] for text. Lower → more coherent but less diverse.
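A quick way to see the effect is to apply several temperatures to one fixed set of logits and inspect the resulting distribution. The logits below are made-up values for a 4-token vocabulary.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.5, -1.0])  # made-up logits for a 4-token vocabulary

stats = {}
for temp in [0.1, 0.7, 1.0, 2.0, 10.0]:
    probs = F.softmax(logits / temp, dim=-1)
    entropy = -(probs * probs.log()).sum().item()  # in nats; the maximum is ln(4) ≈ 1.386
    stats[temp] = (probs.max().item(), entropy)
    print(f"temp={temp:>4}: top-token prob={stats[temp][0]:.3f}, entropy={entropy:.3f}")
```

As τ grows, the most likely token's probability shrinks and the entropy climbs toward the uniform limit ln(4); as τ shrinks, nearly all mass moves to the argmax.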
19 / 27

Q2.2 — top_k and top_p

top_k: Effect
Keep only the k most probable tokens; set the logits of all other tokens to −∞ before the softmax so they get zero probability. This prevents sampling from the long tail of improbable tokens.
  • top_k=0 → disabled
  • top_k=50 → consider only top 50 tokens
top_p (nucleus sampling): Effect
Keep the smallest set of tokens whose cumulative probability ≥ p. The size of the set adapts to the distribution — tighter when the model is confident.
  • top_p=0 → disabled
  • top_p=0.9 → keep the smallest set of tokens covering 90% of the probability mass
Combination: top_p is generally preferred over top_k because it adapts dynamically. Both can be combined with temperature.
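Both filters can be written as masks over the logits applied before the softmax. The function `filter_logits` below is an illustrative sketch, not the course's `sample_tokens`; it treats `top_k=0` and `top_p=0` as disabled, following the convention on this slide, and always keeps at least the most likely token.

```python
import torch
import torch.nn.functional as F


def filter_logits(logits, top_k=0, top_p=0.0):
    """Mask logits outside the top-k and/or nucleus (top-p) set with -inf. logits: (B, V)."""
    logits = logits.clone()
    if top_k > 0:
        k = min(top_k, logits.size(-1))
        kth_largest = torch.topk(logits, k, dim=-1).values[..., -1:]  # (B, 1)
        logits[logits < kth_largest] = float('-inf')
    if top_p > 0.0:
        sorted_logits, sorted_idx = torch.sort(logits, dim=-1, descending=True)
        probs = F.softmax(sorted_logits, dim=-1)
        mass_above = probs.cumsum(dim=-1) - probs             # mass of strictly higher-ranked tokens
        sorted_logits[mass_above >= top_p] = float('-inf')    # the top token always survives
        # Undo the sort: position sorted_idx[b, j] receives sorted_logits[b, j]
        logits = torch.empty_like(sorted_logits).scatter_(-1, sorted_idx, sorted_logits)
    return logits


# Made-up next-token distribution over 5 tokens
logits = torch.log(torch.tensor([[0.5, 0.3, 0.1, 0.06, 0.04]]))
print(filter_logits(logits, top_k=2).isfinite().tolist())     # [[True, True, False, False, False]]
print(filter_logits(logits, top_p=0.75).isfinite().tolist())  # [[True, True, False, False, False]]
print(filter_logits(logits, top_p=0.85).isfinite().tolist())  # [[True, True, True, False, False]]

# Sampling then proceeds as usual, here with temperature 0.8
probs = F.softmax(filter_logits(logits, top_k=2, top_p=0.9) / 0.8, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)  # (B, 1)
```

With top_p=0.75 the nucleus stops at two tokens (0.5 + 0.3 ≥ 0.75); raising it to 0.85 pulls in a third, showing how the kept set adapts to the distribution.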
20 / 27

Q2.3 — How to Improve the Model?

  • Bigger model: more layers (depth), larger width (dim) — Chinchilla scaling laws
  • More data: TinyStories is small; more diverse training data helps
  • Longer training: more steps, better LR schedule
  • RoPE positional embeddings for longer context
  • Flash Attention for faster training
  • Better tokenizer: a smaller subword vocabulary trained on TinyStories itself, e.g. with SentencePiece
  • RLHF / fine-tuning: align outputs to human preferences
  • Data quality: filter low-quality text
Reference: GPT-3 (175B) and LLaMA 3 (70B+) use essentially the same decoder-only architecture (LLaMA swaps in RoPE, SwiGLU and RMSNorm), just scaled up massively.
02
21 / 27
nanoGPT on
Images

Same GPT model, different data. Tokenize MNIST images → train autoregressive model → class-conditional image generation.

iGPT-style No code changes needed
22 / 27

MNIST → Discrete Tokens

  • Step 1: 28×28 grayscale → threshold → 14×14 binary image
  • Step 2: 14×14 → 7×7 patches of 2×2 → each patch = token 0–15 → 49 tokens
  • Step 3: Prepend class label (0–9) → total sequence: 50 tokens
  • Token shift: patch values +10 to avoid overlap with class tokens (0–9)
MNIST 7x7 patch tokenization grid
Why patch? 196 raw pixels → 49 patch tokens (4× shorter sequence).
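The three steps can be sketched in PyTorch. The exact downsampling and threshold used in tokenized_mnist.py are not shown on this slide, so this sketch assumes 2×2 average pooling followed by a 0.5 threshold, plus an arbitrary bit order inside each 2×2 patch (top-left = most significant bit). `tokenize_mnist` is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def tokenize_mnist(img, label):
    """img: (28, 28) float in [0, 1] -> (50,) LongTensor [label, 49 patch tokens]."""
    # Step 1: 28x28 -> 14x14 by 2x2 average pooling, then binarize (assumed threshold 0.5)
    small = F.avg_pool2d(img[None, None], kernel_size=2)[0, 0]   # (14, 14)
    binary = (small > 0.5).long()
    # Step 2: 7x7 grid of 2x2 patches -> (7, 7, 4) bits, ordered TL, TR, BL, BR
    patches = binary.reshape(7, 2, 7, 2).permute(0, 2, 1, 3).reshape(7, 7, 4)
    weights = torch.tensor([8, 4, 2, 1])                          # 4 bits -> id in 0..15
    patch_ids = (patches * weights).sum(-1).flatten()             # (49,), row-major
    # Token shift +10 (patch tokens 10..25, class tokens 0..9), then Step 3: prepend label
    return torch.cat([torch.tensor([label]), patch_ids + 10])


img = torch.zeros(28, 28)
img[4:24, 12:16] = 1.0          # a crude vertical stroke, roughly a "1"
tokens = tokenize_mnist(img, label=1)
print(tokens.shape)             # torch.Size([50])
print(tokens.tolist())
```

The full vocabulary is therefore 26 tokens: 10 class tokens plus 16 patch tokens.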
23 / 27

Class-Conditional Generation

Generated MNIST digits — 10 samples per class

Provide class label as first token → model generates remaining 49 patch tokens autoregressively

  • Each row = one digit class (0–9)
  • 10 samples per class
  • Most samples should be recognizable
  • Some diversity across samples
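Turning a generated sequence back into a picture is the inverse mapping: drop the class token, undo the +10 shift, and unpack each 4-bit patch id into a 2×2 block. The sketch below is hypothetical and assumes the patch bit order top-left, top-right, bottom-left, bottom-right (most significant bit first); it must match whatever order the tokenizer actually uses.

```python
import torch


def decode_mnist_tokens(seq):
    """seq: (50,) LongTensor [class label, 49 patch tokens] -> (14, 14) binary image."""
    patch_ids = seq[1:] - 10                                                  # undo shift -> 0..15
    bits = torch.stack([(patch_ids >> s) & 1 for s in (3, 2, 1, 0)], dim=-1)  # (49, 4), MSB first
    # (7, 7, 2, 2) -> (7, 2, 7, 2) -> (14, 14)
    return bits.reshape(7, 7, 2, 2).permute(0, 2, 1, 3).reshape(14, 14)


# Class 7, all patches empty (token 10) except one full patch (token 25) at grid cell (0, 0)
seq = torch.full((50,), 10, dtype=torch.long)
seq[0], seq[1] = 7, 25
img = decode_mnist_tokens(seq)
print(img.shape)             # torch.Size([14, 14])
print(img[:2, :2].tolist())  # [[1, 1], [1, 1]]
print(img.sum().item())      # 4
```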
nanoGPT MNIST training loss curves

Target: val loss < 0.45

10 pts for section 3.3
Show loss curves (val loss < 0.45). Show generated samples (should look like recognizable digits).
24 / 27

Q3.1 & Q3.2 — Sampling on Images

Q3.1: Temperature on images
Same principle as text. Low T → sharp, clean but less diverse digits. High T → noisy, malformed shapes. Optimal around 0.7–0.9.
Key difference from text
Image tokens have a much smaller vocabulary (26 tokens vs ~50k for text), so sampling is less sensitive to temperature than it is for text.
Q3.2: top_k / top_p on images
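The full vocabulary has only 26 tokens (16 of them patch types), so top_k ≥ 26 or top_p=1.0 applies no truncation at all; top_k=16 is nearly the same at patch positions, where a trained model puts almost no mass on class tokens. An effective top_k would be very small (e.g. 3–5) to restrict sampling to likely patches.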
Experiment: try temp=0.3 vs temp=1.5 on MNIST and compare digit clarity vs diversity.
25 / 27

Q3.3 — Extending to Text-to-Image

DALL-E 1 (2021) showed this works at scale. The key insight: treat text and image tokens as one unified sequence.

  • Encode text with a BPE tokenizer into discrete tokens.
  • Encode image into discrete tokens using VQ-VAE / VQ-GAN.
  • Concatenate [text tokens | image tokens] into a single sequence.
  • Train GPT with next-token-prediction loss to predict image tokens given text.
  • At inference, feed text → GPT autoregressively generates image tokens → decode to pixels.
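The sequence construction in these steps can be sketched as follows. The vocabulary sizes and the helper `build_t2i_example` are hypothetical: image codes are offset by the text vocabulary size so both token sets share one embedding table without overlapping, and targets at text positions are set to -100 so the loss only trains the model to predict image tokens, matching the step above.

```python
import torch

TEXT_VOCAB = 1000   # hypothetical BPE vocabulary size
IMAGE_VOCAB = 512   # hypothetical VQ codebook size; total vocab = TEXT_VOCAB + IMAGE_VOCAB
IGNORE = -100       # F.cross_entropy's default ignore_index


def build_t2i_example(text_ids, image_codes):
    """Return (input_seq, target_seq) for one [text tokens | image tokens] example."""
    image_ids = image_codes + TEXT_VOCAB                # shift codes into their own id range
    seq = torch.cat([text_ids, image_ids])              # unified sequence
    input_seq, target_seq = seq[:-1], seq[1:].clone()   # next-token shift, as in GPT.forward
    n_text = text_ids.numel()
    target_seq[: n_text - 1] = IGNORE                   # no loss on predicting text tokens
    return input_seq, target_seq


text_ids = torch.tensor([5, 17, 42])           # made-up BPE ids for a caption
image_codes = torch.tensor([3, 3, 200, 511])   # made-up codes for a 2x2 VQ grid
inp, tgt = build_t2i_example(text_ids, image_codes)
print(inp.tolist())  # [5, 17, 42, 1003, 1003, 1200]
print(tgt.tolist())  # [-100, -100, 1003, 1003, 1200, 1511]
```

Masking the text loss is one design choice; DALL-E 1 instead kept a down-weighted loss on the text tokens (1/8 text, 7/8 image).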
DALL-E 1: Ramesh et al. 2021, a 12B-parameter GPT-3-style transformer paired with a discrete VAE (dVAE) image tokenizer, at scale.
DALL-E generated trains image

"There are two trains on parallel tracks." — AI generated

26 / 27

Papers & Codebases

  • Attention Is All You Need — Vaswani et al. 2017
  • GPT-2 — Radford et al. 2019
  • GPT-3 — Brown et al. 2020
  • LLaMA 3 — Meta AI, 2024
  • Scaling Laws — Kaplan et al. 2020 · Chinchilla (Hoffmann et al.) 2022
  • DALL-E 1 — Ramesh et al. 2021
  • nanoGPT — Karpathy's minimal GPT implementation
  • LlamaGen — autoregressive image generation
  • torchtitan — PyTorch production LLM training
  • The Illustrated Transformer — Jay Alammar
  • Transformer Family v2.0 — Lilian Weng
27 / 27
Key Things
to Remember
  • Transformers = stacked Attention + MLP blocks with residual connections and pre-norm
  • Causal masking via a lower-triangular boolean mask: True = attend, False = masked out (score set to −∞)
  • Autoregressive generation = forward pass → sample from the last position's logits → append → repeat
  • Tokens are universal — same architecture works for text AND images
  • Temperature, top_k, top_p control the sharpness and diversity of sampling
Part 2 next: MaskGIT →