
01 / 39
nanoFlow
Solutions

TA walkthrough of Part 4 — Rectified Flow Matching, DiT-LLaMA architecture, and class-conditional image generation on MNIST & CIFAR-10.

Part 4 of 5 TA Session EPFL Spring 2026
Rectified Flow · DiT-LLaMA · MNIST · CIFAR-10 · 140 pts total
02 / 39

Score breakdown (140 pts)

Architecture (50 pts)
  • §6.1 modulate() helper — 5 pts
  • §6.2 LabelEmbedder.token_drop() — 5 pts
  • §6.3 TransformerBlock AdaLN forward — 15 pts
  • §6.4 FinalLayer.forward() — 5 pts
  • §6.5 DiT_Llama.forward() — 15 pts
  • §6.6 DiT sanity check — 5 pts
Training (35 pts)
  • §7.1 Sample timesteps — 10 pts
  • §7.2 Compute interpolation zt — 10 pts
  • §7.3 Velocity target + MSE loss — 10 pts
  • §8.3 Loss curve plot — 5 pts
Inference (25 pts)
  • §11.1 Conditional velocity prediction — 5 pts
  • §11.2 Classifier-free guidance — 5 pts
  • §11.3 Euler integration step — 5 pts
  • §12–14 Generated samples + trajectory — 10 pts
03 / 39

Codebase structure

nano4M/
└── nanofm/
    ├── modeling/
    │   └── dit.py   ← DiT_Llama, TransformerBlock, FinalLayer, modulate()
    ├── models/
    │   └── rectified_flow.py   ← RectifiedFlow.forward() + sample(): your target
    └── notebooks/
        └── COM304_FM_part4_nanoFlowMatching.ipynb
I
Flow Matching
Theory

Rectified flow, straight-line interpolation, velocity fields, and why it beats DDPM.

05 / 39

Why flow matching?

DDPM (denoising diffusion)
  • Curved trajectories through data space
  • Noise schedule: β₁…β_T (1000 steps typical)
  • Predicts noise ε or score ∇log p(x)
  • Complex variance schedule tuning needed
  • Slow sampling due to curved ODE paths
Rectified Flow (flow matching)
  • Straight-line paths from data to noise
  • Single hyperparameter-free interpolation
  • Predicts constant velocity v = x₁ − x₀
  • Converges in 20–50 steps (vs 1000)
  • Simpler loss: just MSE on velocity
[Figure: trajectory comparison. DDPM reaches x₀ in ~1000 curved steps; flow matching reaches x₀ in 20–50 straight steps. Both start from x₁ ∼ 𝒩(0, I) (pure noise).]
06 / 39

Rectified flow: the math

Interpolation:   zt = (1−t)·x0 + t·x1   where   x0 ∼ data,   x1 ∼ 𝒩(0,I),   t ∈ [0,1]
Target velocity:   v* = x1 − x0   (constant along the straight path)

Geometric picture

[Figure: straight path from x₀ (data, t=0) through z_t (t=0.5) to x₁ ∼ 𝒩(0, I) (noise, t=1), with constant velocity v* = x₁ − x₀.]

Why t ∈ [0,1]?

  • t=0: z0 = x0 (pure data)
  • t=1: z1 = x1 (pure noise)
  • t=0.5: 50/50 mixture — hardest to predict
Sanity check: dz_t/dt = d/dt[(1−t)·x₀ + t·x₁] = x₁ − x₀ = v*. The velocity is constant along each straight path, so for a fixed pair (x₀, x₁) the ODE integrates exactly in a single step.
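The sanity check can be reproduced numerically. The following is a minimal, dependency-free sketch that uses plain Python lists in place of tensors; the helper names `interpolate` and `velocity` are illustrative and are not part of the course codebase.

```python
def interpolate(x0, x1, t):
    """z_t = (1 - t) * x0 + t * x1, element-wise on flat lists."""
    return [(1.0 - t) * a + t * b for a, b in zip(x0, x1)]


def velocity(x0, x1):
    """Target velocity v* = x1 - x0; it does not depend on t."""
    return [b - a for a, b in zip(x0, x1)]


x0 = [0.0, 1.0, -1.0]   # stand-in for a data sample
x1 = [0.5, -0.3, 2.0]   # stand-in for a noise sample

# Endpoints: t=0 returns the data, t=1 returns the noise.
z_start = interpolate(x0, x1, 0.0)
z_end = interpolate(x0, x1, 1.0)

# A finite difference of z_t matches v* at any t, because the path is linear.
eps = 1e-3
t = 0.37
z_a = interpolate(x0, x1, t)
z_b = interpolate(x0, x1, t + eps)
finite_diff = [(b - a) / eps for a, b in zip(z_a, z_b)]
```

Changing `t` to any other value in [0, 1) should leave `finite_diff` unchanged up to rounding, which is the "constant velocity" property in numeric form.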
07 / 39

The interpolation in action

Each frame shows zt = (1−t)·x0 + t·x1 for one MNIST digit at increasing t. At t=0 the digit is fully visible; at t=1 it becomes indistinguishable Gaussian noise.

interp_strip
Exercise §7.2 asks you to implement this. The critical region is t≈0.5 where the model must distinguish signal from noise — this is exactly where logit-normal sampling focuses compute.
08 / 39

Training objective

ℒ(θ)  =  𝔼_{t, x₀, x₁}  [ ‖ v_θ(z_t, t, c)  −  (x₁ − x₀) ‖² ]

Term breakdown

vθ(zt, t, c) DiT's predicted velocity at noisy image zt, time t, class c
x1 − x0 Ground-truth velocity — constant along the straight path
‖·‖² MSE averaged over all B×C×H×W elements
𝔼[·] Expectation over t∼p(t), x0∼data, x1∼𝒩(0,I)

Full training step (§7.1–7.3)

1 Sample noise: z1 ~ 𝒩(0, I)
2 Sample timestep: t ~ logit-normal
3 Interpolate: zt = (1−t)·x + t·z1
4 Forward pass: v̂ = DiT(zt, t, c)
5 Loss: ‖v̂ − (z1−x)‖²
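The five steps above can be sketched end to end without PyTorch. This is only an illustration under simplifying assumptions: "images" are flat lists of floats, class conditioning is omitted, and `zero_model` is a hypothetical stand-in for the DiT that always predicts zero velocity.

```python
import math
import random


def training_step(batch, v_model, rng):
    """One rectified-flow loss evaluation on a batch of flat 'images'.

    batch   : list of samples, each a list of floats (clean data x)
    v_model : callable (zt, t) -> predicted velocity, same length as zt
    Returns (mean loss, list of per-sample losses).
    """
    per_sample = []
    for x in batch:
        z1 = [rng.gauss(0.0, 1.0) for _ in x]                # 1. noise
        t = 1.0 / (1.0 + math.exp(-rng.gauss(0.0, 1.0)))     # 2. logit-normal t
        zt = [(1.0 - t) * a + t * b for a, b in zip(x, z1)]  # 3. interpolate
        v_hat = v_model(zt, t)                               # 4. forward pass
        target = [b - a for a, b in zip(x, z1)]              # velocity z1 - x
        mse = sum((p - g) ** 2 for p, g in zip(v_hat, target)) / len(x)  # 5. loss
        per_sample.append(mse)
    return sum(per_sample) / len(per_sample), per_sample


def zero_model(zt, t):
    """Hypothetical stand-in for the DiT: predicts zero velocity everywhere."""
    return [0.0 for _ in zt]


rng = random.Random(0)
batch = [[0.2, -0.1, 0.4, 0.0], [1.0, 1.0, -1.0, 0.5]]
loss, per_sample_losses = training_step(batch, zero_model, rng)
```

Keeping the per-sample losses separate from the batch mean mirrors the structure of the real `forward()`, which logs one loss per image alongside its timestep.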
09 / 39

Logit-normal timestep sampling

Sample  n ∼ 𝒩(0,1), then  t = σ(n) = 1 / (1 + e^(−n))

Why not uniform t?

  • Near t=0: zt ≈ x0 — trivially easy, model learns nothing new
  • Near t=1: zt ≈ pure noise — also easy to fit
  • At t≈0.5: signal and noise are equally mixed — hardest region, most gradient signal
  • Uniform puts 60% of training steps outside the mixed band t∈[0.3, 0.7], in the easier regions

What σ(·) gives us

  • σ squashes ℝ → (0,1) and peaks the density near 0.5
  • Tails (t near 0,1) still covered — just undersampled
  • Empirically: lower final loss, better FID at same epochs

§7.1 solution

def forward(self, x, cond):
    b = x.size(0)
    # logit-normal sampling
    if self.ln:
        nt = torch.randn((b,)).to(x.device)
        t = torch.sigmoid(nt)  # [B]
    else:
        t = torch.rand((b,)).to(x.device)
    # reshape for broadcasting with [B,C,H,W]
    texp = t.view([b, *([1]*len(x.shape[1:]))])
    ...
self.ln = True in all notebook experiments. The uniform fallback exists only for ablation comparisons.
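To put a number on how much the two samplers differ, the sketch below compares the probability mass that each places on the band t ∈ [0.3, 0.7]. The band is an assumption chosen to match the "middle of the trajectory" range used in these slides; the closed form uses the standard normal CDF, and a seeded Monte Carlo run cross-checks it.

```python
import math
import random


def logit_normal_mass(lo, hi):
    """P(lo < sigmoid(n) < hi) for n ~ N(0, 1), in closed form."""
    def logit(p):
        return math.log(p / (1.0 - p))

    def std_normal_cdf(z):
        return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))

    return std_normal_cdf(logit(hi)) - std_normal_cdf(logit(lo))


def sample_logit_normal(rng):
    """Draw one timestep: n ~ N(0, 1), then t = sigmoid(n)."""
    return 1.0 / (1.0 + math.exp(-rng.gauss(0.0, 1.0)))


lo, hi = 0.3, 0.7
uniform_mass = hi - lo                 # uniform t on [0, 1]
ln_mass = logit_normal_mass(lo, hi)    # logit-normal t

# Monte Carlo cross-check of the closed form.
rng = random.Random(0)
draws = [sample_logit_normal(rng) for _ in range(20000)]
empirical = sum(lo < t < hi for t in draws) / len(draws)
```

Under these assumptions the band holds about 40% of uniform draws and about 60% of logit-normal draws, so the two samplers roughly swap how much of the budget goes to the middle versus the tails.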
10 / 39

Uniform vs logit-normal distributions

Both histograms show 10,000 sampled timesteps t ∈ [0,1]. The left is flat (uniform); the right peaks near t≈0.5 — exactly the hardest region for the model.

logit_hist
Takeaway: Logit-normal oversamples the middle of the trajectory (t≈0.3–0.7) where data and noise are mixed. This concentrates gradient signal where it matters most, similar to focal loss in classification.
II
DiT-LLaMA
Architecture

Diffusion Transformer with adaptive layer norm, RoPE, SwiGLU, and class conditioning via CFG dropout.

12 / 39

DiT-LLaMA: full forward pass

Input z_t [B,C,H,W] → Patchify (patch=2) [B,N,dim] → N × Block (AdaLN + Attn + FFN) → FinalLayer (AdaLN + Linear) → Unpatchify [B,C,H,W] → Output v̂ (velocity field)
Conditioning: timestep t + class c flow into every TransformerBlock via AdaLN.

Patchify / Unpatchify

  • 2×2 patches → token sequence [B, N, dim]
  • MNIST 28×28 → 14×14 = 196 tokens
  • FinalLayer maps dim → patch²×C, then reshape back

Conditioning

  • t → sinusoidal embed → MLP
  • c → learned embed (null for CFG)
  • cond = t_embed + c_embed → every block
5 pts
13 / 39

§6.1 modulate() — adaptive layer norm

modulate(x, shift, scale) = x · (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
def modulate(x, shift, scale):
    # x:     [B, N, dim], token sequence
    # scale: [B, dim], from adaLN MLP
    # shift: [B, dim], from adaLN MLP
    # unsqueeze(1): [B,dim] → [B,1,dim] for broadcast
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
Why (1 + scale) and not just scale?
At init the adaLN MLP output ≈ 0, so scale≈0 → (1+scale)≈1. modulate acts like identity: x·1+0 = x. Conditioning is learned incrementally from a neutral start — much more stable.
Key: unsqueeze(1) is done inside modulate. scale/shift arrive as [B, dim] from the MLP — do NOT unsqueeze before passing them in.
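The identity-at-init argument is easy to verify on a toy example. The sketch below is a plain-Python analogue of `modulate()` for a single sample (no batch axis, nested lists instead of tensors), so broadcasting over tokens is written out as an explicit loop; it is not the graded PyTorch implementation.

```python
def modulate(x, shift, scale):
    """Plain-Python analogue of modulate() for one sample.

    x     : list of N tokens, each a list of `dim` floats
    shift : list of `dim` floats, shared by every token
    scale : list of `dim` floats, shared by every token
    """
    return [
        [xi * (1.0 + sc) + sh for xi, sc, sh in zip(token, scale, shift)]
        for token in x
    ]


tokens = [[1.0, -2.0], [0.5, 3.0], [0.0, 4.0]]   # N=3 tokens, dim=2

# Zero shift/scale, as at initialization: output equals input.
identity_out = modulate(tokens, shift=[0.0, 0.0], scale=[0.0, 0.0])

# Non-zero conditioning: every token gets the same per-channel affine map.
conditioned = modulate(tokens, shift=[1.0, 0.0], scale=[0.5, -1.0])
```

Note that `scale = -1.0` on the second channel zeroes that channel for every token, which shows why the parameterization is `(1 + scale)` rather than `scale`: zero is the neutral value.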
[Figure: AdaLN modulation. cond (t+c) → MLP → (1 + scale) and shift; LayerNorm(x) × (1 + scale) + shift → modulated output.]
5 pts
14 / 39

§6.2 token_drop() — CFG training trick

def token_drop(self, labels, force_drop_ids=None):
    # During training: randomly replace the class label
    # with the "null" token (= num_classes) with prob p_uncond
    if force_drop_ids is None:
        drop_ids = torch.rand(
            labels.shape[0], device=labels.device
        ) < self.dropout_prob
    else:
        drop_ids = force_drop_ids == 1
    labels = torch.where(
        drop_ids,
        self.num_classes,  # null token index
        labels
    )
    return labels

Why this enables CFG

  • During training, ~10% of samples get label replaced with null token
  • Model learns both conditional v(z,t,c) and unconditional v(z,t,∅)
  • At inference, run model twice: once with c, once with null → apply guidance
What is force_drop_ids for?
Forces the null token for every sample whose entry in force_drop_ids equals 1, instead of dropping at random. Passing all ones yields the unconditional prediction v(z,t,∅) from the same trained model.
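The two code paths of `token_drop()` can be exercised in isolation. The sketch below is a list-based analogue, not the graded method: `num_classes`, `dropout_prob`, and the random generator are passed explicitly here (an assumption made for self-containment), whereas the real method reads them from `self`.

```python
import random


def token_drop(labels, num_classes, dropout_prob, rng, force_drop_ids=None):
    """Replace some labels with the null token (index == num_classes).

    force_drop_ids: optional list of 0/1 flags; 1 forces a drop for that sample.
    """
    if force_drop_ids is None:
        drop = [rng.random() < dropout_prob for _ in labels]
    else:
        drop = [flag == 1 for flag in force_drop_ids]
    return [num_classes if d else y for y, d in zip(labels, drop)]


rng = random.Random(0)
labels = [3, 7, 1, 9]

# Training-style random drop over a large batch: about 10% become null.
big_batch = [5] * 10000
dropped = token_drop(big_batch, num_classes=10, dropout_prob=0.1, rng=rng)
null_fraction = dropped.count(10) / len(dropped)

# Forced drop: only the flagged samples are replaced.
forced = token_drop(labels, 10, 0.1, rng, force_drop_ids=[1, 0, 0, 1])
all_null = token_drop(labels, 10, 0.1, rng, force_drop_ids=[1, 1, 1, 1])
```

The null index equals `num_classes`, so the label embedding table needs `num_classes + 1` rows for this scheme to work.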
15 / 39

TimestepEmbedder — scalar t → vector

Scalar t (e.g. 0.35), shape (B,) → sinusoidal expansion [cos(t·ω₀)…cos(t·ω₁₂₇), sin(t·ω₀)…sin(t·ω₁₂₇)] with ωₖ = exp(−log(10000)·k/128), shape (B, 256) → MLP: Linear(256→D) → SiLU → Linear(D→D) → t_emb (B, D), D = min(dim, 1024)
t_emb + y_emb (B, D) → adaln_input (B, D)
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
    half = dim // 2  # e.g. 128
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(0, half) / half
    ).to(t.device)  # [128], geometric schedule
    args = t[:, None] * freqs[None]  # [B,128]
    embedding = torch.cat(
        [torch.cos(args), torch.sin(args)], dim=-1
    )  # [B, 256]
    return embedding

def forward(self, t):
    t_freq = self.timestep_embedding(
        t, self.frequency_embedding_size)
    return self.mlp(t_freq)  # [B, D]
[Figure: sin(t · ωₖ) across embedding dims, one curve per t ∈ {0.05, 0.2, 0.4, 0.6, 0.8, 0.95}. x-axis: embedding dimension k (0 to 127); y-axis: sin value.]
Left (low k, ω≈1, the highest frequencies in the schedule): curves spread out, so different t gives different values. Right (high k, ω≈10⁻⁴, the lowest frequencies): sin ≈ 0 for all t. Because t ∈ [0,1] here, the low-k dims carry most of the timing information; together the dims uniquely fingerprint every t.
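The behaviour at the two ends of the frequency schedule can be checked directly. The sketch below recomputes the sinusoidal features for a scalar t in plain Python, following the same formula as `timestep_embedding` with `dim=256`; returning the frequency list alongside the embedding is an addition made here for inspection only.

```python
import math


def timestep_embedding(t, dim=256, max_period=10000):
    """Sinusoidal features for a scalar t: [cos(t*w_k)..., sin(t*w_k)...]."""
    half = dim // 2
    freqs = [math.exp(-math.log(max_period) * k / half) for k in range(half)]
    cos_part = [math.cos(t * w) for w in freqs]
    sin_part = [math.sin(t * w) for w in freqs]
    return cos_part + sin_part, freqs


emb_a, freqs = timestep_embedding(0.05)
emb_b, _ = timestep_embedding(0.95)

half = len(freqs)
# Gap between the sin features of the two timesteps at k=0 and at k=127.
sin_gap_first = abs(emb_a[half] - emb_b[half])
sin_gap_last = abs(emb_a[-1] - emb_b[-1])
```

The gap at k=0 is large while the gap at k=127 is close to zero, which matches the spread-out versus collapsed curves in the figure.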
16 / 39

RoPE — rotary position embedding

Self-attention is permutation-equivariant: without a position signal, patch (row 3, col 5) looks identical to patch (0, 0). RoPE injects position by rotating the Q and K vectors by an angle proportional to where the patch sits.

① PICK AN ANGLE
θ_{p,k} = p · ω_k
Position p times a fixed frequency ω_k. Geometric schedule like sinusoidal PE, with zero learnable parameters.
② ROTATE Q & K
(q_{2k}, q_{2k+1}) → R_θ · (q_{2k}, q_{2k+1})
2D rotation applied to every adjacent pair of dims, on Q and K, inside attention, before the dot product.
③ ABSOLUTE → RELATIVE
Q · Kᵀ ∝ cos((p_q − p_k) · ω_k)
Absolute angles cancel in the inner product. Attention sees only relative position and extrapolates to unseen lengths.
[Figure: the same pair (q_{2k}, q_{2k+1}) at four positions p = 0, 1, 2, 3. Each unit of p rotates the pair by ω_k; different frequency bands rotate at different speeds.]

2D images → split the head

  • First half of head dim encodes the row, second half encodes the column
  • self.freqs_cis precomputes (cos θ, sin θ) for every (row, col) at __init__ — a buffer, not a parameter
  • apply_rotary_emb(q, k, freqs_cis) applies the rotation inside each attention block
§6.5 gotcha: move freqs_cis to x.device before the layer loop — otherwise device-mismatch crash on GPU. Bug #1 on slide 21.
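The "absolute → relative" property of RoPE can be verified on a single 2-D pair. The sketch below is a toy illustration in plain Python, not the `apply_rotary_emb` implementation; the vectors and the frequency `omega` are arbitrary values chosen for the example.

```python
import math


def rotate_pair(vec, position, omega):
    """Rotate a 2-D pair (v0, v1) by the angle theta = position * omega."""
    theta = position * omega
    c, s = math.cos(theta), math.sin(theta)
    v0, v1 = vec
    return (c * v0 - s * v1, s * v0 + c * v1)


def dot(a, b):
    return a[0] * b[0] + a[1] * b[1]


q = (0.8, -0.3)    # one (q_2k, q_2k+1) pair
k = (0.1, 0.9)     # the matching key pair
omega = 0.25       # illustrative frequency for this pair

# Same relative offset (2 positions apart), different absolute positions.
score_near = dot(rotate_pair(q, 3, omega), rotate_pair(k, 1, omega))
score_far = dot(rotate_pair(q, 13, omega), rotate_pair(k, 11, omega))

# A different relative offset gives a different score.
score_other = dot(rotate_pair(q, 3, omega), rotate_pair(k, 2, omega))
```

Shifting both positions by the same amount leaves the score unchanged, while changing the offset between them does not; this is the sense in which attention only sees relative position.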
17 / 39

AdaLN — why this design?

DiT conditioning strategy comparison

Peebles & Xiao, Scalable Diffusion Models with Transformers (2023)

In-Context Conditioning
tokens = [t, y, x₁…xₙ]
standard self-attn
  • No extra parameters
  • Simple to implement
✗ Condition token competes with patch tokens for attention budget — signal dilutes at depth
Cross-Attention
Q=x, K=V=c
extra attn block/layer
  • Rich token↔cond interaction
  • Works well for text (LDMs)
✗ +O(D²) params per layer — expensive for class-conditioned image generation
adaLN-Zero
our choice
shift,scale,gate = MLP(t+y)
modulate(LN(x), …)
  • Only 6×D modulation values per block, produced by a small MLP on (t+y)
  • Zero-init → stable training
  • Each block specializes independently
✓ Lowest FID of the three while being the most compute-efficient (per DiT paper ablations)
15 pts
18 / 39

§6.3 TransformerBlock — AdaLN forward

def forward(self, x, freqs_cis, adaln_input=None):
    if adaln_input is not None:
        # Step 1: get 6 params by chunking [B, 6*dim] along the feature axis
        shift_msa, scale_msa, gate_msa, \
            shift_mlp, scale_mlp, gate_mlp = (
                self.adaLN_modulation(adaln_input)
                .chunk(6, dim=1)
            )
        # Step 2: AdaLN-modulated attention
        x = x + gate_msa.unsqueeze(1) * self.attention(
            modulate(self.attention_norm(x),
                     shift_msa, scale_msa),  # NO .unsqueeze here
            freqs_cis
        )
        # Step 3: AdaLN-modulated FFN
        x = x + gate_mlp.unsqueeze(1) * self.feed_forward(
            modulate(self.ffn_norm(x),
                     shift_mlp, scale_mlp)  # NO .unsqueeze here
        )
    else:
        x = x + self.attention(self.attention_norm(x), freqs_cis)
        x = x + self.feed_forward(self.ffn_norm(x))
    return x

6 modulation params

  • shift_msa, scale_msa: AdaLN for the attention sublayer
  • gate_msa: per-channel gate ([B, dim]) on the attention residual output
  • shift_mlp, scale_mlp: AdaLN for the FFN sublayer
  • gate_mlp: per-channel gate ([B, dim]) on the FFN residual output
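Two pitfalls: (1) Chunk along the feature axis with .chunk(6, dim=1). The MLP output is [B, 6·dim], so dim=0 would split the batch instead (dim=-1 is equivalent here because the tensor is 2-D). (2) Pass shift/scale directly to modulate(), which adds .unsqueeze(1) internally.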
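The role of the gates is easiest to see at initialization, when the adaLN MLP outputs are zero. The sketch below models one gated residual branch for a single token in plain Python; normalization is left out for brevity, and `toy_sublayer` is a hypothetical stand-in for attention or the FFN.

```python
def gated_residual(x, sublayer, shift, scale, gate):
    """x + gate * sublayer(modulate(x)) for a single token (list of floats).

    Normalization is omitted to keep the sketch short; `sublayer` stands in
    for attention or the feed-forward network.
    """
    modulated = [xi * (1.0 + sc) + sh for xi, sc, sh in zip(x, scale, shift)]
    update = sublayer(modulated)
    return [xi + g * u for xi, g, u in zip(x, gate, update)]


def toy_sublayer(v):
    """Hypothetical stand-in for attention/FFN: any non-trivial function."""
    return [3.0 * a - 1.0 for a in v]


x = [0.5, -1.5, 2.0]
zeros = [0.0, 0.0, 0.0]

# All modulation outputs zero (as at init): the branch is an identity map.
at_init = gated_residual(x, toy_sublayer, shift=zeros, scale=zeros, gate=zeros)

# Once the gate opens, the sublayer output is mixed in per channel.
opened = gated_residual(x, toy_sublayer, zeros, zeros, gate=[1.0, 0.5, 0.0])
```

With a zero gate the sublayer output is discarded entirely, so each block starts as a pass-through and the conditioning decides, channel by channel, how much of each sublayer to admit.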
5 pts
19 / 39

§6.4 FinalLayer — last projection

def forward(self, x, c):
    # chunk along the feature axis; no .unsqueeze, modulate does it
    shift, scale = \
        self.adaLN_modulation(c).chunk(2, dim=1)
    x = modulate(self.norm_final(x), shift, scale)
    # Project to patch output dim
    x = self.linear(x)
    return x
    # x shape: [B, N, patch_size² × out_channels]

What it does

  • Same AdaLN as TransformerBlock but only 2 params (no gate)
  • Linear layer maps dim → patch_size² × out_channels
  • For MNIST: 128 → 4×1 = 4 (grayscale, patch=2)
  • For CIFAR: 256 → 4×3 = 12 (RGB, patch=2)
Why no gate in FinalLayer?
No residual connection in the output layer — it projects directly to pixel space, so there is nothing to gate.
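The token counts and output widths quoted for MNIST and CIFAR-10 follow from two small formulas. The helper below, `dit_shapes`, is a hypothetical name introduced for this sketch; it assumes square images whose side is divisible by the patch size.

```python
def dit_shapes(image_size, channels, patch_size):
    """Token count and per-token output width for square images."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    side = image_size // patch_size
    num_tokens = side * side
    out_per_token = patch_size * patch_size * channels
    return num_tokens, out_per_token


mnist_tokens, mnist_out = dit_shapes(image_size=28, channels=1, patch_size=2)
cifar_tokens, cifar_out = dit_shapes(image_size=32, channels=3, patch_size=2)

# Unpatchify must recover every pixel value: tokens * out == C * H * W.
mnist_values = mnist_tokens * mnist_out
cifar_values = cifar_tokens * cifar_out
```

The product check is a quick way to confirm that the FinalLayer width is right before debugging `unpatchify`: if `tokens * out` does not equal `C * H * W`, the reshape cannot succeed.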
15 pts
20 / 39

§6.5 DiT_Llama.forward() — assembly

def forward(self, x, t, y):
    # Move RoPE to correct device first
    self.freqs_cis = self.freqs_cis.to(x.device)

    # 1. Shallow CNN features, then patch tokens
    x = self.init_conv_seq(x)  # [B, dim//2, H, W]
    x = self.patchify(x)       # [B, N, patch²×dim//2]
    x = self.x_embedder(x)     # [B, N, dim]

    # 2. Conditioning vector: t + y → adaln_input
    t = self.t_embedder(t)     # [B, dim]
    y = self.y_embedder(       # [B, dim]
        y, self.training)      # ← token_drop!
    adaln_input = t.to(x.dtype) + y.to(x.dtype)

    # 3. Transformer stack
    for layer in self.layers:
        x = layer(x, self.freqs_cis[:x.size(1)],
                  adaln_input=adaln_input)

    # 4. Project + reshape
    x = self.final_layer(x, adaln_input)
    x = self.unpatchify(x)     # [B,C,H,W]
    return x
Shape trace (the values shown correspond to dim=256 and 16 blocks):
input x (B, 1, 28, 28)
→ init_conv_seq: CNN features (B, 128, 28, 28)
→ patchify: 196 patch tokens (B, 196, 512)
→ x_embedder: projected to model dim (B, 196, 256)
t: (B,) → (B, 256); y: (B,) → (B, 256); summed into adaln_input (B, 256)
→ 16 × TransformerBlock: shape unchanged (B, 196, 256)
→ FinalLayer: per-patch velocity (B, 196, 4)
→ unpatchify: predicted velocity v_θ (B, 1, 28, 28)
III
Training
the Model

Implement the RectifiedFlow training loop: sample t, interpolate, compute velocity target, minimize MSE.

10 pts
22 / 39

§7.1 Sample timesteps

def forward(self, x, cond):
    b = x.size(0)
    # §7.1: sample one t per image in the batch
    if self.ln:
        # Logit-normal: concentrate near t=0.5
        nt = torch.randn((b,)).to(x.device)
        t = torch.sigmoid(nt)  # shape [B]
    else:
        # Uniform fallback
        t = torch.rand((b,)).to(x.device)
    # texp for broadcasting with image dims
    texp = t.view([b, *([1] * len(x.shape[1:]))])  # [B,1,1,1]

Shape details

  • t shape: [B] — one timestep per image (not per pixel)
  • Each image in the batch gets its own random t — more diverse training signal per step
  • texp: [B,1,1,1] — reshaping needed to broadcast with image [B,C,H,W]
  • Use t.view([b, *([1]*len(x.shape[1:]))]) — generalizes beyond 2D images
10 pts
23 / 39

§7.2 Compute zt — the noisy image

# §7.2: sample noise and interpolate
z1 = torch.randn_like(x)  # Gaussian noise, same shape as x
# Linear interpolation: z_t = (1-t)*x + t*z1
zt = (1 - texp) * x + texp * z1
# zt shape: [B, C, H, W]
# zt is the model input: the noisy image at time t
Why torch.randn_like(x)?
randn_like copies shape AND device/dtype from x automatically — no risk of device mismatch or dtype issues. The noise variable is z1 (not x1) matching the convention: z₀=data, z₁=noise.

Verify with the plot

At t=0: zt = x (pure data). At t=1: zt = z1 (pure noise). At t=0.5: equal mix. This matches the interpolation strip shown earlier.

Note: x is the clean data (z₀), z1 is the sampled noise (z₁). The interpolation z_t moves from x toward z1 as t increases.
10 pts
24 / 39

§7.3 Velocity target + MSE loss

# §7.3a: model forward, predict velocity
vtheta = self.model(zt, t, cond)
# target is z1 - x (direction: data → noise)

# §7.3b: batchwise MSE over C×H×W dims
batchwise_mse = (
    (z1 - x - vtheta) ** 2
).mean(dim=list(range(1, len(x.shape))))
# batchwise_mse shape: [B]

# §7.3c: logging and return
tlist = batchwise_mse.detach().cpu()\
    .reshape(-1).tolist()
ttloss = [(tv, tl) for tv, tl in zip(t, tlist)]
return batchwise_mse.mean(), ttloss

Complete forward() summary

1. t ~ logit-normal [B]
2. z1 = randn_like(x) [B,C,H,W]
3. zt = (1-texp)*x + texp*z1
4. vtheta = model(zt, t, cond)
5. batchwise_mse = mean((z1-x-vtheta)²)
6. return batchwise_mse.mean(), ttloss
Batchwise MSE, not F.mse_loss! The mean is over C×H×W dimensions, keeping the batch dimension to produce per-sample losses for logging.
5 pts §8.3
25 / 39

MNIST training loss curve

Reading the curve

  • Epoch 1: loss ~0.4 — model is random, poor velocity predictions
  • Epochs 1–5: rapid drop — model learns global structure quickly
  • Epochs 5–30: slow refinement — learning fine details and harder timesteps
  • Epoch 30: loss ~0.12 — converged for MNIST
§8.3 deliverable: submit this plot. If your loss doesn't go below ~0.2 by epoch 30, check: (1) learning rate, (2) logit-normal sampling implemented correctly, (3) loss computed with MSE not MAE.
loss_curve
26 / 39

MNIST: generated samples (§12–14, 10 pts)

What to observe

  • Each row = one digit class (0–9), 8 samples per class
  • Generated with steps=50, cfg=2.0
  • High diversity within each class — not memorized
  • Clear class identity in all rows
§12: generate a grid like this
Run rf.sample() for each class, concat results, display with matplotlib. Use steps=50 and cfg=2.0 as defaults.
§13: save a trajectory GIF
Record z_t at each Euler step and save as animated GIF — shows noise gradually resolving into a digit.
mnist_grid
IV
Inference &
Sampling

ODE integration with Euler steps, classifier-free guidance, and how sampling steps + CFG scale affect quality.

28 / 39

ODE integration: Euler method

Forward ODE (training): dz/dt = v*(z,t)  →  z moves data→noise
Backward ODE (sampling): z_{t−Δt} = z_t − Δt · v_θ(z_t, t, c)

Euler sampling loop

# Start from pure noise at t=1
z = torch.randn(x_shape, device=device)  # x_shape = (B, C, H, W)
dt = 1.0 / num_steps
for i in range(num_steps, 0, -1):
    t = i / num_steps  # t: 1.0 → dt
    t_tensor = torch.full((B,), t, device=device)
    # Predict velocity at current z, t
    v = model(z, t_tensor, cond)
    # Step backwards (noise → data)
    z = z - dt * v
# z is now approximately a sample from p(data)

Why only 20–50 steps?

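  • Rectified-flow training paths are straight lines, and Euler is exact when the velocity is constant along the path; the learned field is only approximately straight, so a few dozen steps are enough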
  • With enough steps, Euler error is negligible
  • DDPM needs 1000 steps because curved ODE paths require small Δt to stay accurate
  • Flow matching can even work in 1 step (with quality trade-off)
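The claim about straight versus curved paths can be made concrete with a toy integrator. In the sketch below, `constant_velocity` is the idealized rectified-flow case and `curved_velocity` is a hypothetical field (dz/dt = z) chosen only because its exact solution is known; neither is a trained model.

```python
import math


def euler_sample(z_start, velocity_fn, num_steps):
    """Integrate dz/dt = v(z, t) backwards from t=1 to t=0 with Euler steps."""
    z = list(z_start)
    dt = 1.0 / num_steps
    for i in range(num_steps, 0, -1):
        t = i / num_steps
        v = velocity_fn(z, t)
        z = [zi - dt * vi for zi, vi in zip(z, v)]
    return z


x0 = [0.2, -0.7, 1.1]    # stand-in data point
x1 = [1.5, 0.4, -0.9]    # stand-in noise point


def constant_velocity(z, t):
    """Ideal straight path: v = x1 - x0 everywhere."""
    return [b - a for a, b in zip(x0, x1)]


def curved_velocity(z, t):
    """Hypothetical curved field dz/dt = z; exact solution z(0) = z(1) / e."""
    return list(z)


# One Euler step recovers x0 when the velocity is constant.
one_step = euler_sample(x1, constant_velocity, num_steps=1)

# On the curved field the error shrinks only as the step count grows.
exact = math.exp(-1.0)
err_5 = abs(euler_sample([1.0], curved_velocity, 5)[0] - exact)
err_50 = abs(euler_sample([1.0], curved_velocity, 50)[0] - exact)
```

A trained model sits between these two extremes: its field is close to straight but not exactly so, which is consistent with good samples at a few dozen steps rather than at one.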
5 pts
29 / 39

§11.1 Conditional velocity prediction

@torch.no_grad()
def sample(self, z, cond, null_cond=None,
           sample_steps=50, cfg=2.0):
    b = z.size(0)
    dt = 1.0 / sample_steps
    dt = torch.tensor([dt]*b).to(z.device)\
        .view([b, *([1]*len(z.shape[1:]))])
    images = [z]
    for i in range(sample_steps, 0, -1):
        t = i / sample_steps
        t = torch.tensor([t]*b).to(z.device)
        # §11.1: conditional velocity
        vc = self.model(z, t, cond)  # ← fill this
        ...
What is "conditional" here?
The model is called with the actual class labels — e.g. [3, 3, 3, ...] for cats. No null tokens. This gives v(z,t,c) — velocity conditioned on the specific class we want to generate.
null_cond is passed in by the caller — it's already prepared (all labels = num_classes). The sample() method doesn't know how many classes exist; the caller handles that.
5 pts
30 / 39

§11.2 Classifier-free guidance (CFG)

v_guided = v_∅ + s · (v_c − v_∅)   where s is the guidance scale
# §11.2: CFG, written inline inside the loop
if null_cond is not None:
    vu = self.model(z, t, null_cond)
    vc = vu + cfg * (vc - vu)
# vc: [B,C,H,W]   vu: [B,C,H,W]
# cfg: scalar (e.g. 2.0)

Intuition

  • cfg=0: purely unconditional (ignores class)
  • cfg=1: standard conditional (same as vc)
  • cfg>1: amplifies conditional signal beyond vc
[Figure: CFG extrapolates past v_c (s = 2). From z_t, arrows show v_∅ (uncond), v_c (cond), and v_guided. With Δ = v_c − v_∅ and v_guided = v_∅ + s·Δ, s = 2 walks one extra Δ past v_c (extrapolation).]
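The three regimes listed under "Intuition" can be checked with a few numbers. The sketch below applies the guidance formula element-wise to plain lists; `cfg_combine` is an illustrative helper name and the velocity vectors are made-up stand-ins, not model outputs.

```python
def cfg_combine(v_uncond, v_cond, scale):
    """v_guided = v_uncond + scale * (v_cond - v_uncond), element-wise."""
    return [u + scale * (c - u) for u, c in zip(v_uncond, v_cond)]


v_u = [1.0, 0.0, -2.0]   # stand-in unconditional velocity
v_c = [2.0, 1.0, -2.0]   # stand-in conditional velocity

guided_0 = cfg_combine(v_u, v_c, 0.0)   # equals v_u
guided_1 = cfg_combine(v_u, v_c, 1.0)   # equals v_c
guided_2 = cfg_combine(v_u, v_c, 2.0)   # one extra delta past v_c
```

Components where the two predictions already agree (the third entry here) are left untouched at any scale; guidance only amplifies the directions in which the class label changes the prediction.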
5 pts
31 / 39

§11.3 Euler step — complete sample()

@torch.no_grad()
def sample(self, z, cond, null_cond=None,
           sample_steps=50, cfg=2.0):
    b = z.size(0)
    dt = 1.0 / sample_steps
    dt = torch.tensor([dt]*b).to(z.device)\
        .view([b, *([1]*len(z.shape[1:]))])
    images = [z]
    for i in range(sample_steps, 0, -1):
        t = i / sample_steps
        t = torch.tensor([t]*b).to(z.device)
        # §11.1
        vc = self.model(z, t, cond)
        # §11.2
        if null_cond is not None:
            vu = self.model(z, t, null_cond)
            vc = vu + cfg * (vc - vu)
        # §11.3
        z = z - dt * vc
        images.append(z)
    return images

Step annotations

§11.3 — z = z − dt × vc
Subtract because we walk backwards (t: 1→0). The velocity points data→noise, so subtracting moves us noise→data.
dt shape: [B,1,1,1]
dt is a tensor shaped [B,1,1,1] — required to broadcast correctly with vc [B,C,H,W].
returns images list
images[0] = initial noise, images[-1] = final generated sample. Full trajectory for visualization.
32 / 39

Generation trajectory: noise → digit

Each frame shows zt at a different step during sampling (t goes 1.0→0.0 from left to right). The model progressively de-noises, resolving global structure first, then fine details.

traj_5
traj_0
33 / 39

Effect of sampling steps (cfg=2.0 fixed)

steps = 1
steps_1
blurry, no detail
steps = 5
steps_5
class-correct, blurry
steps = 10
steps_10
good quality
steps = 50
steps_50
sharp, high diversity
Key insight: 1 step already predicts the right class! Flow matching paths are straight so a single coarse step captures class identity. More steps refine sharpness and intra-class diversity. Optimal is ~20–50 steps for this model size.
34 / 39

Effect of CFG scale (steps=50 fixed)

cfg = 0.0
cfg_0
no conditioning — wrong classes!
cfg = 1.0
cfg_1
weak conditioning
cfg = 2.0
cfg_2
good balance ✓
cfg = 10.0
cfg_10
sharp but less diverse
Key insight: cfg=0 proves the null-token training is working — the model truly forgets the class. cfg=10 shows over-guidance: samples look more like "canonical" digit shapes, less natural variation. Sweet spot is cfg=2–3 for MNIST.
20 pts §15
35 / 39

§15 analysis questions — model answers

§15.1 — Steps experiment
Observation. 1 step → blurry single jump that misses fine details. 5–10 steps → correct class identity, acceptable sharpness. ≥20 steps → diminishing returns, mostly cosmetic refinement.
Why. Rectified flow trains nearly straight paths, so Euler integration is almost exact even at low N. Extra steps refine; they don't correct trajectory error.
Trade-off. Pick the smallest N past the elbow of your own curve (usually 20–50) to cut inference latency without visible quality loss.
§15.2 — CFG scale experiment
Observation. cfg=0 → unconditional, mixed/wrong classes. cfg=1 → weak class signal. cfg=2–3 → sharp, varied, on-class (sweet spot). cfg≥5 → over-saturated, prototype-y, low diversity.
Why. CFG extrapolates v_∅ → v_c by factor s. Larger s amplifies the class-specific direction at the cost of the breadth captured by v_∅.
Trade-off. Quality-vs-diversity dial. Tune on your own ablation grid; the elbow shifts with model size and dataset.
§15.3 — Logit-normal vs uniform t
Observation. Logit-normal training reaches a lower final loss and better samples at the same epoch budget; uniform plateaus higher.
Why. Uniform spends 60% of steps outside t∈[0.3, 0.7], in the easy regions near t=0 (≈clean) and t=1 (≈pure noise). Logit-normal concentrates mass around t≈0.5, where signal and noise are mixed and gradient signal is largest.
Analogue. Same idea as focal loss — re-weight the budget toward hard examples. SD3 ablation confirms the FID gap.
§15.4 — Continuous (Flow Matching) vs discrete tokens (MaskGIT/GPT)
Flow matching — pros. Single-line MSE loss, no tokenizer or codebook, works natively on any continuous signal (images, audio, video, 3D), state-of-the-art fidelity (SD3, Flux).
Flow matching — cons. 20–50 sequential ODE steps × 2 forward passes for CFG → high inference cost. Hard to mix with discrete modalities (text) without an extra adapter.
Discrete tokens — pros. Reuse LLM tooling and tokenizers, easy to fuse with text in one vocab, MaskGIT-style sampling is fast (~8–12 iterations), training is plain cross-entropy.
Discrete tokens: cons. Quality bottlenecked by the VQ-VAE codebook (reconstruction ceiling). AR sampling needs O(N) sequential steps for N tokens. Categorical loss scales worse on continuous distributions.
When to pick which. FM for fidelity on continuous data. Discrete tokens for unified text+image models or fast iterative sampling.
V
CIFAR-10
Extension

Scale the same RectifiedFlow + DiT-LLaMA to 3-channel 32×32 colour images. 10 points.

10 pts §16
37 / 39

§16 CIFAR-10: what changes?

MNIST → CIFAR-10 differences

Property          | MNIST              | CIFAR-10
Image size        | 28×28              | 32×32
Channels          | 1 (gray)           | 3 (RGB)
Tokens (patch=2)  | 14×14 = 196        | 16×16 = 256
Model config      | dim 128, 6 layers  | dim 256, 10 layers
Training time     | ~5 min (GPU)       | ~1–2 h (GPU)

The RectifiedFlow training loop stays the same

  • Same forward() method — just with CIFAR data and CIFAR model
  • Same logit-normal timestep sampling
  • Same MSE velocity loss
  • Same Euler sampler at inference
Why a larger model for CIFAR?
A 3-channel 32×32 image has 3,072 values versus 784 for a 28×28 MNIST digit, about 3.9× more. The model needs more capacity to learn the richer distribution.
10 pts §16
38 / 39

CIFAR-10 training code

import torch
from torch import optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# Import paths follow the codebase structure slide
from nanofm.modeling.dit import DiT_Llama
from nanofm.models.rectified_flow import RectifiedFlow

device = "cuda" if torch.cuda.is_available() else "cpu"

# 1. Data
cifar_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(),
    transforms.Normalize((0.5, 0.5, 0.5),
                         (0.5, 0.5, 0.5)),
])
cifar = datasets.CIFAR10("./data", train=True,
                         download=True,
                         transform=cifar_transform)
loader = DataLoader(cifar, batch_size=256, shuffle=True)

# 2. Larger model (3-channel, 32×32)
model = DiT_Llama(in_channels=3, input_size=32,
                  dim=256, n_layers=10, n_heads=8,
                  num_classes=10).to(device)
rf = RectifiedFlow(model, ln=True)
opt = optim.Adam(model.parameters(), lr=5e-4)

# 3. Training loop (same as MNIST!)
for epoch in range(100):
    for x, c in loader:
        x, c = x.to(device), c.to(device)
        opt.zero_grad()
        loss, _ = rf.forward(x, c)
        loss.backward()
        opt.step()

Only the data pipeline and the model config change

  • in_channels=3 instead of 1
  • input_size=32 instead of 28
  • Larger dim, more layers for capacity
  • Everything else: identical
10 pts §16
39 / 39

CIFAR-10: generated samples (steps=50, cfg=2.0)

CIFAR-10 generated samples grid (10 classes × 4 samples)