TA walkthrough of Part 4 — Rectified Flow Matching, DiT-LLaMA architecture, and class-conditional image generation on MNIST & CIFAR-10.
Agenda overview: "Here's the roadmap for today. We'll go section by section — first a recap of the project structure, then the DiT model exercises, then the training loop, and finally the sampling with classifier-free guidance."
Point out that §6.6 is a sanity check — it only passes if ALL of §6.1 through §6.5 are correct. So if it fails, don't debug 6.6 directly, go back and check each earlier exercise.
nano4M/
└── nanofm/
    ├── modeling/
    │   └── dit.py                ← DiT_Llama, TransformerBlock, FinalLayer, modulate()
    ├── models/
    │   └── rectified_flow.py     ← RectifiedFlow.forward() + sample(): your target
    └── notebooks/
        └── COM304_FM_part4_nanoFlowMatching.ipynb
RectifiedFlow · DiT_Llama
Rectified flow, straight-line interpolation, velocity fields, and why it beats DDPM.
Section header — Part I: "Let's start with the overall architecture. Before we dive into code, I want to make sure everyone has a mental picture of the full pipeline — how does an image go in, and how does a velocity field come out?"
Each frame shows z_t = (1−t)·x + t·z₁ for one MNIST digit at increasing t, where x is the clean digit and z₁ is Gaussian noise. At t=0 the digit is fully visible; at t=1 it is indistinguishable from Gaussian noise.
Logit-normal timestep sampling: "This slide explains why we sample t from a logit-normal distribution instead of uniform. The key intuition is that the model finds the midpoint t≈0.5 much harder than the endpoints."
At t=0, the image is nearly clean — easy. At t=1, it's pure noise — the model just predicts a mean. But at t=0.5, it's half signal, half noise — this is where the model does the hard work. Logit-normal sampling concentrates more training examples in this hard region, which improves final sample quality.
DiT-LLaMA overview: "The DiT-LLaMA model combines two ideas: DiT (Diffusion Transformer) and LLaMA-style transformer blocks. The key ingredients are: a CNN frontend for local features, patch tokenization, transformer blocks conditioned on time+class via AdaLN, and a final projection back to image space."
Don't get intimidated by the name. It's a standard transformer, with two twists: rotary positional encoding from LLaMA, and Adaptive Layer Norm conditioning from the DiT paper. Both are covered in the exercises.
def forward(self, x, cond):
    b = x.size(0)
    # logit-normal sampling
    if self.ln:
        nt = torch.randn((b,)).to(x.device)
        t = torch.sigmoid(nt)  # [B]
    else:
        t = torch.rand((b,)).to(x.device)
    # reshape for broadcasting with [B,C,H,W]
    texp = t.view([b, *([1]*len(x.shape[1:]))])
    ...
Both histograms show 10,000 sampled timesteps t ∈ [0,1]. The left is flat (uniform); the right peaks near t≈0.5 — exactly the hardest region for the model.
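To make the histogram comparison concrete, here is a minimal dependency-free sketch that draws 10,000 timesteps both ways and measures how many land in the middle band. It uses plain Python instead of PyTorch, and the helper names, the seed, and the [0.25, 0.75] band are illustrative assumptions, not part of the notebook.

```python
import math
import random


def sample_timesteps(n, logit_normal=True, seed=0):
    """Draw n timesteps in (0, 1), either logit-normal or uniform."""
    rng = random.Random(seed)
    if logit_normal:
        # Sigmoid of a standard normal draw, mirroring torch.sigmoid(torch.randn(...)).
        return [1.0 / (1.0 + math.exp(-rng.gauss(0.0, 1.0))) for _ in range(n)]
    return [rng.random() for _ in range(n)]


def fraction_in_middle(ts, lo=0.25, hi=0.75):
    """Share of samples that fall inside the band [lo, hi]."""
    return sum(lo <= t <= hi for t in ts) / len(ts)


ln_samples = sample_timesteps(10_000, logit_normal=True)
uni_samples = sample_timesteps(10_000, logit_normal=False)
ln_share = fraction_in_middle(ln_samples)
uni_share = fraction_in_middle(uni_samples)
print(f"logit-normal: {ln_share:.2f}, uniform: {uni_share:.2f}")
```

In expectation the uniform sampler puts half of its draws in that band, while the logit-normal sampler puts roughly 73% there, since sigmoid(n) lies in [0.25, 0.75] exactly when |n| ≤ ln 3.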
CIFAR-10 context: "For this project we're using CIFAR-10 — 32×32 RGB images, 10 classes. The model parameters are sized accordingly: patch_size=2 gives us 16×16=256 tokens per image, RGB means C=3 output channels."
The FinalLayer output per token is patch_size²×C = 4×3 = 12 values. After unpatchify, we get back to [B, 3, 32, 32].
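The token and output-size arithmetic can be checked with a short sketch. The function name `patch_stats` is a hypothetical helper introduced only for this check.

```python
def patch_stats(image_size, patch_size, channels):
    """Return (number of tokens, output values per token) for a square image."""
    side = image_size // patch_size          # patches along one side
    num_tokens = side * side                 # sequence length N
    values_per_token = patch_size * patch_size * channels
    return num_tokens, values_per_token


cifar_tokens, cifar_values = patch_stats(image_size=32, patch_size=2, channels=3)
mnist_tokens, mnist_values = patch_stats(image_size=28, patch_size=2, channels=1)
print(cifar_tokens, cifar_values)  # 256 12
print(mnist_tokens, mnist_values)  # 196 4
```

Multiplying the two numbers recovers the full image size (256 × 12 = 3 × 32 × 32), which is why unpatchify can reshape the token outputs back to [B, 3, 32, 32] without losing or inventing values.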
Diffusion Transformer with adaptive layer norm, RoPE, SwiGLU, and class conditioning via CFG dropout.
Section header — DiT Architecture: "Now let's get into the actual exercises. We'll go one by one through §6.1 to §6.5. I'll show you the correct solution and explain the key design choices."
def modulate(x, shift, scale):
    # x: [B, N, dim], token sequence
    # scale: [B, dim], from adaLN MLP
    # shift: [B, dim], from adaLN MLP
    # unsqueeze(1): [B,dim] → [B,1,dim] for broadcast
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
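As a sanity check on the formula x · (1 + scale) + shift, here is a plain-Python sketch for a single image (no batch dimension, nested lists instead of tensors). `modulate_plain` is a hypothetical stand-in for the tensor version, used only to show the arithmetic.

```python
def modulate_plain(x, shift, scale):
    """x: [N][dim] tokens of one image; shift and scale: [dim], shared by all tokens."""
    return [
        [v * (1.0 + s) + b for v, s, b in zip(token, scale, shift)]
        for token in x
    ]


tokens = [[1.0, 2.0], [3.0, 4.0]]
# Zero shift and zero scale leave the tokens unchanged.
identity = modulate_plain(tokens, shift=[0.0, 0.0], scale=[0.0, 0.0])
# scale=1 doubles the first channel; shift=3 offsets the second channel.
shifted = modulate_plain(tokens, shift=[0.0, 3.0], scale=[1.0, 0.0])
print(identity)  # [[1.0, 2.0], [3.0, 4.0]]
print(shifted)   # [[2.0, 5.0], [6.0, 7.0]]
```

The "1 +" in the formula is what makes zero-valued conditioning a no-op, and the same shift and scale are applied to every token, which is what the unsqueeze(1) broadcast does in the tensor version.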
def token_drop(self, labels, force_drop_ids=None):
    # During training: randomly replace the class label
    # with the "null" token (= num_classes) with prob self.dropout_prob
    if force_drop_ids is None:
        drop_ids = torch.rand(
            labels.shape[0], device=labels.device
        ) < self.dropout_prob
    else:
        drop_ids = force_drop_ids == 1
    labels = torch.where(
        drop_ids,
        self.num_classes,  # null token index
        labels
    )
    return labels
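The label-dropout logic can be sketched without tensors to see what the model receives during training. Everything here is illustrative: `drop_labels` is a hypothetical helper, and the dropout probability of 0.1 is an assumed value, not one taken from the notebook.

```python
import random


def drop_labels(labels, num_classes, dropout_prob, force_drop=None, seed=0):
    """Replace labels by the null index (= num_classes) at random or on request."""
    rng = random.Random(seed)
    if force_drop is None:
        drop = [rng.random() < dropout_prob for _ in labels]
    else:
        drop = [flag == 1 for flag in force_drop]
    return [num_classes if d else y for y, d in zip(labels, drop)]


labels = [i % 10 for i in range(10_000)]
dropped = drop_labels(labels, num_classes=10, dropout_prob=0.1)  # 0.1 is an assumption
null_share = sum(y == 10 for y in dropped) / len(dropped)
forced = drop_labels([3, 7], num_classes=10, dropout_prob=0.0, force_drop=[1, 0])
print(f"share of null labels: {null_share:.3f}")
print(forced)  # [10, 7]
```

Because the null index sits one past the last real class, the label embedding needs num_classes + 1 entries for this scheme to work.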
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
    half = dim // 2  # e.g. 128
    freqs = torch.exp(
        -math.log(max_period) *
        torch.arange(0, half) / half
    ).to(t.device)  # [128] geometric
    args = t[:, None] * freqs[None]  # [B,128]
    embedding = torch.cat(
        [torch.cos(args), torch.sin(args)],
        dim=-1
    )  # [B, 256]
    return embedding

def forward(self, t):
    t_freq = self.timestep_embedding(
        t, self.frequency_embedding_size)
    return self.mlp(t_freq)  # [B, D]
sin(t · ωₖ) across embedding dims — each curve = one t value
TimestepEmbedder: "Before we talk about how conditioning enters the model, let me explain how a scalar timestep t becomes a useful vector."
If we just fed t as a single number into the MLP, the network would have very little to work with. Instead, we first expand t into 256 values using sinusoidal functions at different frequencies: low frequencies distinguish coarse time differences, high frequencies distinguish fine ones. The waveform chart shows this clearly. On the left side (the highest-frequency dims, ω ≈ 1), different t values produce very different heights. On the right (the lowest-frequency dims, ω ≈ 10⁻⁴), t·ω is tiny for t ∈ [0,1], so all sine values collapse to zero.
The MLP then maps these 256 features to the model dimension D. The class label y is likewise mapped to a D-dimensional embedding. The two embeddings are simply added: adaln_input = t_emb + y_emb.
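A scalar version of the sinusoidal expansion makes the frequency layout easy to inspect. This is a plain-Python sketch for a single timestep, following the same formula as timestep_embedding; the name `timestep_embedding_plain` is hypothetical.

```python
import math


def timestep_embedding_plain(t, dim=256, max_period=10000):
    """Return (embedding, freqs) for one scalar timestep t."""
    half = dim // 2
    # Geometric frequency ladder: index 0 has frequency 1, the last is close to 1/max_period.
    freqs = [math.exp(-math.log(max_period) * k / half) for k in range(half)]
    args = [t * w for w in freqs]
    embedding = [math.cos(a) for a in args] + [math.sin(a) for a in args]
    return embedding, freqs


emb, freqs = timestep_embedding_plain(0.5)
print(len(emb))               # 256
print(freqs[0], freqs[-1])    # largest and smallest frequency
print(emb[128], emb[255])     # first and last sine component
```

For t in [0, 1] the first sine component is sin(t), which changes visibly with t, while the last one is sin of a number below about 10⁻⁴ and is therefore almost zero for every t.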
Attention is permutation-invariant — without a position signal, patch (row 3, col 5) looks identical to patch (0, 0). RoPE injects position by rotating the Q and K vectors by an angle proportional to where the patch sits.
self.freqs_cis precomputes (cos θ, sin θ) for every (row, col) at __init__: a buffer, not a parameter
apply_rotary_emb(q, k, freqs_cis) applies the rotation inside each attention block
Move freqs_cis to x.device before the layer loop, otherwise device-mismatch crash on GPU. Bug #1 on slide 21.
Opening (one sentence): "Attention by itself doesn't know where each token lives. The original transformer added a sinusoidal position encoding to the input embeddings; RoPE does something cleverer: it rotates the Q and K vectors by an angle proportional to the patch position."
Walk the three-step strip left-to-right: ① pick an angle θ = p · ωₖ: position times a fixed frequency, zero learnable params. ② apply a standard 2D rotation to every adjacent pair of dims in Q and K, inside attention, before the dot product. ③ when you compute Q·Kᵀ, the absolute angles cancel and only the difference p_q − p_k survives, so attention sees relative position for free, which tends to help the model generalize to sequence lengths it never saw in training.
On the diagram: point out that the same (q₂ₖ, q₂ₖ₊₁) pair is drawn at four different positions p=0,1,2,3. Each step in p adds ωₖ to the rotation angle. Different frequency bands ωₖ rotate at different speeds, and that's how a single head encodes many scales of position.
2D images + code: we split the head dim in half — one half encodes row index, the other column index. freqs_cis is a precomputed (cos, sin) buffer; apply_rotary_emb applies the rotation inside attention. Your only job in §6.5 is to move freqs_cis to x.device before the layer loop — forgetting that is bug #1 on slide 21.
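The "only the difference survives" property can be verified numerically on a single 2D pair. This is a minimal sketch in plain Python, not the notebook's apply_rotary_emb; the function names and the frequency value 0.3 are arbitrary assumptions.

```python
import math


def rotate_pair(vec, angle):
    """Rotate a 2D vector counter-clockwise by the given angle."""
    x, y = vec
    c, s = math.cos(angle), math.sin(angle)
    return (x * c - y * s, x * s + y * c)


def rope_score(q, k, pos_q, pos_k, omega=0.3):
    """Dot product of q and k after rotating each by position * omega."""
    rq = rotate_pair(q, pos_q * omega)
    rk = rotate_pair(k, pos_k * omega)
    return rq[0] * rk[0] + rq[1] * rk[1]


q, k = (1.0, 2.0), (0.5, -1.0)
near = rope_score(q, k, pos_q=3, pos_k=1)    # offset 2
far = rope_score(q, k, pos_q=7, pos_k=5)     # same offset 2, shifted by 4
other = rope_score(q, k, pos_q=3, pos_k=2)   # offset 1
print(near, far, other)
```

The first two scores agree up to floating-point rounding because both pairs are two positions apart, while the third differs because the offset changed.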
Peebles & Xiao, Scalable Diffusion Models with Transformers (2023)
Slide intro: "The DiT paper tried three different ways to tell the transformer 'what class and timestep are we at?' Let's look at all three."
In-Context: "The simplest idea — just prepend the class and timestep tokens to the patch sequence and run normal self-attention. No extra code, no extra parameters. The problem is those condition tokens have to compete for attention with all the patch tokens, and as you go deeper the signal gets diluted."
Cross-Attention: "This is what Stable Diffusion uses for text conditioning — the patches attend to the text tokens via cross-attention. It's very expressive but you're adding a whole extra attention block at every layer. For a text prompt that's worth it. For just a class label and a timestep? That's overkill — too many extra parameters."
adaLN-Zero: "This is what DiT-LLaMA uses, and what you're implementing. The MLP takes the combined timestep+class embedding and outputs 6 vectors of size D: shift, scale, and gate for both attention and FFN. These are computed once per image rather than per token, so the extra compute is small compared to a cross-attention block at every layer. The zero-init is the secret weapon: at the start of training all gates are zero, so every block reduces to its skip connection and the model starts as the identity. Training gradually opens the gates. This gives very stable optimization, and the paper showed it beats the other two on FID."
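The effect of a zero gate is easy to see on a toy residual branch. This is a sketch only: `gated_residual` and `toy_branch` are hypothetical names, and the branch stands in for attention or the FFN.

```python
def gated_residual(x, gate, branch):
    """Compute x + gate * branch(x) elementwise, as in a gated skip connection."""
    return [xi + gate * bi for xi, bi in zip(x, branch(x))]


def toy_branch(x):
    """Stand-in for an attention or FFN sub-layer."""
    return [2.0 * xi + 1.0 for xi in x]


x = [0.5, -1.0, 3.0]
at_init = gated_residual(x, gate=0.0, branch=toy_branch)  # gate is zero at initialization
later = gated_residual(x, gate=0.1, branch=toy_branch)    # gate after some training
print(at_init)  # [0.5, -1.0, 3.0]
print(later)
```

With the gate at zero the output equals the input whatever the branch computes, so a freshly initialized stack of such blocks passes its input straight through.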
def forward(self, x, freqs_cis, adaln_input=None):
    if adaln_input is not None:
        # Step 1: get 6 params (dim=1, not dim=-1!)
        shift_msa, scale_msa, gate_msa, \
            shift_mlp, scale_mlp, gate_mlp = (
                self.adaLN_modulation(adaln_input)
                .chunk(6, dim=1)
            )
        # Step 2: AdaLN-modulated attention
        x = x + gate_msa.unsqueeze(1) * self.attention(
            modulate(self.attention_norm(x),
                     shift_msa, scale_msa),  # NO .unsqueeze here
            freqs_cis
        )
        # Step 3: AdaLN-modulated FFN
        x = x + gate_mlp.unsqueeze(1) * self.feed_forward(
            modulate(self.ffn_norm(x),
                     shift_mlp, scale_mlp)  # NO .unsqueeze here
        )
    else:
        x = x + self.attention(self.attention_norm(x), freqs_cis)
        x = x + self.feed_forward(self.ffn_norm(x))
    return x
(1) Use .chunk(6, dim=1), not dim=-1. (2) Pass shift/scale directly to modulate(); it adds .unsqueeze(1) internally.
def forward(self, x, c):
    # dim=1, no .unsqueeze (modulate does it)
    shift, scale = \
        self.adaLN_modulation(c).chunk(2, dim=1)
    x = modulate(self.norm_final(x), shift, scale)
    # Project to patch output dim
    x = self.linear(x)
    return x
    # x shape: [B, N, patch_size² × out_channels]
def forward(self, x, t, y):
    # Move RoPE to correct device first
    self.freqs_cis = self.freqs_cis.to(x.device)
    # 1. Shallow CNN features, then patch tokens
    x = self.init_conv_seq(x)  # [B, dim//2, H, W]
    x = self.patchify(x)       # [B, N, patch²×dim//2]
    x = self.x_embedder(x)     # [B, N, dim]
    # 2. Conditioning vector: t + y → adaln_input
    t = self.t_embedder(t)     # [B, dim]
    y = self.y_embedder(       # [B, dim]
        y, self.training)      # ← token_drop!
    adaln_input = t.to(x.dtype) + y.to(x.dtype)
    # 3. Transformer stack
    for layer in self.layers:
        x = layer(x, self.freqs_cis[:x.size(1)],
                  adaln_input=adaln_input)
    # 4. Project + reshape
    x = self.final_layer(x, adaln_input)
    x = self.unpatchify(x)     # [B,C,H,W]
    return x
init_conv_seq · x_embedder · True vs self.training · c vs adaln_input
Implement the RectifiedFlow training loop: sample t, interpolate, compute velocity target, minimize MSE.
Section header — Training the Model: "Now we move to the RectifiedFlow class. This is where the actual training loss is computed. The three exercises here are conceptually simple — the tricky part is getting the shapes exactly right."
def forward(self, x, cond):
    b = x.size(0)
    # §7.1: Sample one t per image in batch
    if self.ln:
        # Logit-normal: concentrate near t=0.5
        nt = torch.randn((b,)).to(x.device)
        t = torch.sigmoid(nt)  # shape [B]
    else:
        # Uniform fallback
        t = torch.rand((b,)).to(x.device)
    # texp for broadcasting with image dims
    texp = t.view([b, *([1] * len(x.shape[1:]))])  # [B,1,1,1]
t.view([b, *([1]*len(x.shape[1:]))]) generalizes beyond 2D images
    # §7.2: Sample noise and interpolate
    z1 = torch.randn_like(x)  # Gaussian noise, same shape as x
    # Linear interpolation: z_t = (1-t)*x + t*z1
    zt = (1 - texp) * x + texp * z1
    # zt shape: [B, C, H, W]
    # zt is the model input: noisy image at time t
torch.randn_like(x)?
z1 (not x1) matching the convention: z₀=data, z₁=noise.
At t=0: zt = x (pure data). At t=1: zt = z1 (pure noise). At t=0.5: equal mix. This matches the interpolation strip shown earlier.
x is the clean data (z₀), z1 is the sampled noise (z₁). The interpolation z_t moves from x toward z1 as t increases.
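The endpoint behaviour of the interpolation can be checked on a tiny example. This is a plain-Python sketch on a flat list of four values; `interpolate` is a hypothetical helper and the seed is arbitrary.

```python
import random


def interpolate(x, z1, t):
    """Straight-line mix z_t = (1 - t) * x + t * z1, elementwise."""
    return [(1.0 - t) * xi + t * zi for xi, zi in zip(x, z1)]


rng = random.Random(0)
x = [0.0, 0.5, 1.0, -1.0]                  # stand-in for clean pixel values
z1 = [rng.gauss(0.0, 1.0) for _ in x]      # stand-in for Gaussian noise

at_data = interpolate(x, z1, 0.0)
at_noise = interpolate(x, z1, 1.0)
halfway = interpolate(x, z1, 0.5)
print(at_data == x, at_noise == z1)
```

At t=0.5 each value is the plain average of the data value and the noise value, which is the "equal mix" case described above.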
    # §7.3a: Model forward, predict velocity
    vtheta = self.model(zt, t, cond)
    # target is z1 - x (direction: data → noise)
    # §7.3b: Batchwise MSE over C×H×W dims
    batchwise_mse = (
        (z1 - x - vtheta) ** 2
    ).mean(dim=list(range(1, len(x.shape))))
    # batchwise_mse shape: [B]
    # §7.3c: Logging and return
    tlist = batchwise_mse.detach().cpu()\
        .reshape(-1).tolist()
    ttloss = [(tv, tl) for tv, tl in zip(t, tlist)]
    return batchwise_mse.mean(), ttloss
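Why the target is z1 - x: it is the time derivative of the straight-line path, so it is the same at every t. The sketch below checks this with a finite difference and evaluates the per-sample MSE for two toy predictions. It is plain Python on short lists, and all helper names are hypothetical.

```python
def interpolate(x, z1, t):
    """Straight-line mix z_t = (1 - t) * x + t * z1, elementwise."""
    return [(1.0 - t) * xi + t * zi for xi, zi in zip(x, z1)]


def velocity_target(x, z1):
    """Regression target of the training loss: z1 - x."""
    return [zi - xi for xi, zi in zip(x, z1)]


def per_sample_mse(pred, target):
    """Mean squared error over all non-batch entries of one sample."""
    return sum((p - q) ** 2 for p, q in zip(pred, target)) / len(target)


x = [0.2, -0.4, 0.9]
z1 = [1.0, 0.5, -0.3]
target = velocity_target(x, z1)

# Finite-difference derivative of z_t with respect to t, taken at t = 0.4.
h = 1e-3
finite_diff = [
    (b - a) / h
    for a, b in zip(interpolate(x, z1, 0.4), interpolate(x, z1, 0.4 + h))
]

perfect_loss = per_sample_mse(target, target)          # a perfect velocity prediction
zero_pred_loss = per_sample_mse([0.0] * 3, target)     # predicting zero velocity
print(finite_diff, target, perfect_loss, zero_pred_loss)
```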
ODE integration with Euler steps, classifier-free guidance, and how sampling steps + CFG scale affect quality.
Section header — Sampling: "The last set of exercises is about inference. Given a trained model, how do we generate images? We start from pure noise and run the Euler ODE solver backwards."
# Start from pure noise at t=1
z = torch.randn(x_shape)  # x_shape: (B, C, H, W)
dt = 1.0 / num_steps
for i in range(num_steps, 0, -1):
    t = i / num_steps  # t: 1.0 → dt
    t_tensor = torch.full((z.size(0),), t, device=z.device)
    # Predict velocity at current z, t
    v = model(z, t_tensor, cond)
    # Step backwards (noise→data)
    z = z - dt * v
# z is now ~ p(data)
@torch.no_grad()
def sample(self, z, cond, null_cond=None,
           sample_steps=50, cfg=2.0):
    b = z.size(0)
    dt = 1.0 / sample_steps
    dt = torch.tensor([dt]*b).to(z.device)\
        .view([b, *([1]*len(z.shape[1:]))])
    images = [z]
    for i in range(sample_steps, 0, -1):
        t = i / sample_steps
        t = torch.tensor([t]*b).to(z.device)
        # §11.1: conditional velocity
        vc = self.model(z, t, cond)  # ← fill this
        ...
        # §11.2: CFG, written inline inside the loop
        if null_cond is not None:
            vu = self.model(z, t, null_cond)
            vc = vu + cfg * (vc - vu)
        # vc: [B,C,H,W] vu: [B,C,H,W]
        # cfg: scalar (e.g. 2.0)
@torch.no_grad()
def sample(self, z, cond, null_cond=None,
           sample_steps=50, cfg=2.0):
    b = z.size(0)
    dt = 1.0 / sample_steps
    dt = torch.tensor([dt]*b).to(z.device)\
        .view([b, *([1]*len(z.shape[1:]))])
    images = [z]
    for i in range(sample_steps, 0, -1):
        t = i / sample_steps
        t = torch.tensor([t]*b).to(z.device)
        # §11.1
        vc = self.model(z, t, cond)
        # §11.2
        if null_cond is not None:
            vu = self.model(z, t, null_cond)
            vc = vu + cfg * (vc - vu)
        # §11.3
        z = z - dt * vc
        images.append(z)
    return images
Each frame shows zt at a different step during sampling (t goes 1.0→0.0 from left to right). The model progressively de-noises, resolving global structure first, then fine details.
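The backward Euler loop can be exercised without a trained network by plugging in a velocity field that is known in closed form. The sketch assumes a toy "dataset" with a single point x, for which z_t = (1 − t)·x + t·z1 implies z1 − x = (z_t − x)/t. The names `euler_sample`, `oracle_velocity` and `data_point` are hypothetical.

```python
def euler_sample(z, velocity, num_steps):
    """Integrate from t=1 (noise) down to t=0 (data) with fixed-size Euler steps."""
    dt = 1.0 / num_steps
    for i in range(num_steps, 0, -1):
        t = i / num_steps
        v = velocity(z, t)
        z = [zi - dt * vi for zi, vi in zip(z, v)]
    return z


data_point = [0.3, -0.7]  # the only "image" in this toy dataset (assumption)


def oracle_velocity(z, t):
    """Exact velocity for a one-point dataset: (z_t - x) / t."""
    return [(zi - xi) / t for zi, xi in zip(z, data_point)]


start = [1.5, 2.0]  # stand-in for a noise draw
end = euler_sample(start, oracle_velocity, num_steps=50)
print(end)
```

The loop visits t = 1.0 down to t = 1/num_steps and never evaluates the model at t = 0, which matches range(sample_steps, 0, -1) in sample(). With this oracle the trajectory lands on the data point up to rounding; a learned model only approximates that.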
CFG ablation: "This side-by-side shows the effect of the CFG scale. Low scale: more diverse but less class-specific. High scale: very sharp class identity but some artifacts start appearing. The sweet spot depends on the application — for visual quality, scale 2–4 works well for CIFAR."
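The guidance formula vu + cfg · (vc − vu) is a linear extrapolation from the unconditional velocity through the conditional one. A minimal sketch with made-up two-element velocities shows the three regimes; `apply_cfg` is a hypothetical helper.

```python
def apply_cfg(v_cond, v_uncond, cfg):
    """Classifier-free guidance: move from v_uncond toward and past v_cond."""
    return [vu + cfg * (vc - vu) for vc, vu in zip(v_cond, v_uncond)]


v_cond = [1.0, 0.0]     # toy conditional velocity
v_uncond = [0.5, 0.5]   # toy unconditional velocity

unguided = apply_cfg(v_cond, v_uncond, cfg=0.0)      # ignores the class entirely
conditional = apply_cfg(v_cond, v_uncond, cfg=1.0)   # plain conditional prediction
pushed = apply_cfg(v_cond, v_uncond, cfg=2.0)        # overshoots past v_cond
print(unguided, conditional, pushed)
# [0.5, 0.5] [1.0, 0.0] [1.5, -0.5]
```

Scales above 1 amplify whatever separates the conditional prediction from the unconditional one, which is consistent with the sharper class identity and the artifacts seen at high scale.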
Section header — Evaluation: "The final part covers how we evaluate generative models quantitatively. This is important background for understanding the project grading criteria."
FID score: "FID — Fréchet Inception Distance — is the standard metric for image generation quality. It compares the distribution of generated images to real images in a feature space extracted by InceptionNet."
Lower FID is better. An FID of 0 would mean the Inception feature statistics of the generated and real image sets match exactly. State-of-the-art models on CIFAR achieve FID around 2–5. Your trained model should get below 20 for a good score. The key insight: FID measures distribution quality, not individual image quality. A model that always generates the same perfect image would have high FID because it lacks diversity.
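FID fits a Gaussian to the Inception features of each image set and computes the Fréchet distance ‖μ_r − μ_g‖² + Tr(Σ_r + Σ_g − 2(Σ_r Σ_g)^{1/2}). The sketch below is the one-dimensional special case of that formula, shown only to illustrate the diversity point; it is a toy and not a real FID computation, and the function name is hypothetical.

```python
import math


def frechet_distance_1d(mu_r, var_r, mu_g, var_g):
    """Fréchet distance between two 1D Gaussians N(mu_r, var_r) and N(mu_g, var_g)."""
    return (mu_r - mu_g) ** 2 + var_r + var_g - 2.0 * math.sqrt(var_r * var_g)


matched = frechet_distance_1d(0.0, 1.0, 0.0, 1.0)     # same mean and variance
collapsed = frechet_distance_1d(0.0, 1.0, 0.0, 0.0)   # right mean, zero diversity
shifted = frechet_distance_1d(0.0, 1.0, 2.0, 1.0)     # right spread, wrong mean
print(matched, collapsed, shifted)  # 0.0 1.0 4.0
```

The collapsed case has a perfectly matched mean and is still penalized, because its variance term cannot cancel against the real data's variance.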
Scale the same RectifiedFlow + DiT-LLaMA to 3-channel 32×32 colour images. 10 points.
IS score and precision/recall: "IS — Inception Score — measures both image quality (images should look like something specific) and diversity (different images should look like different things). Precision measures quality, recall measures diversity — you can trade one for the other with the CFG scale."
| Property | MNIST | CIFAR-10 |
| Image size | 28×28 | 32×32 |
| Channels | 1 (gray) | 3 (RGB) |
| Tokens (patch=2) | 14×14=196 | 16×16=256 |
| Model size (dim, layers) | 128, 6 layers | 256, 10 layers |
| Training time | ~5 min GPU | ~1–2h GPU |
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# DiT_Llama and RectifiedFlow are the classes from dit.py and rectified_flow.py
device = "cuda" if torch.cuda.is_available() else "cpu"

# 1. Data
cifar_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(),
    transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)),
])
cifar = datasets.CIFAR10("./data", train=True, download=True,
                         transform=cifar_transform)
loader = DataLoader(cifar, batch_size=256, shuffle=True)
# 2. Larger model (3-channel, 32×32)
model = DiT_Llama(in_channels=3, input_size=32,
                  dim=256, n_layers=10, n_heads=8,
                  num_classes=10).to(device)
rf = RectifiedFlow(model, ln=True)
opt = optim.Adam(model.parameters(), lr=5e-4)
# 3. Training loop (same as MNIST!)
for epoch in range(100):
    for x, c in loader:
        x, c = x.to(device), c.to(device)
        opt.zero_grad()
        loss, _ = rf.forward(x, c)
        loss.backward()
        opt.step()
in_channels=3 instead of 1
input_size=32 instead of 28
Closing on CIFAR: "This is roughly what students should turn in — class identity is correct, within-class variety is preserved, and the fidelity ceiling matches the 100-epoch training budget. If their grid looks worse, point them at the §15 ablations: steps, cfg scale, and logit-normal vs uniform."