Building a nanoVLM — Solutions
TA session rundown · EPFL COM-304 Foundation Models
Exercise 1 · Modality Projector · 10 pts
Exercise 2 · VQA Collator · 15 pts
Exercise 3 · VLM Forward Pass · 15 pts
Exercise 4 · Training Curves · 15 pts
Exercise 5 · Autoregressive Generation · 20 pts
Exercise 6 · KV Cache · 25 pts
EX 1
Modality Projector
10 pts
Task: Implement pixel_shuffle() and forward() in models/modality_projector.py to bridge the ViT output (196 tokens × 768-dim) to the LM input (49 tokens × 576-dim).
📄 models/modality_projector.py
Skeleton:
def __init__(self, cfg):
    ...
    ## TODO
    self.proj = ...  # ← define linear layer here

def pixel_shuffle(self, x):
    bsz, seq, embed_dim = x.size()
    seq_root = int(seq**0.5)
    assert seq_root**2 == seq
    assert seq_root % self.scale_factor == 0
    ## TODO
    height = width = ...  # set height and width from seq_root
    x = ...  # reshape to (B, H, W, E)
    h_out = ...  # H // scale_factor
    w_out = ...  # W // scale_factor
    x = ...  # reshape: (B, h_out, sf, w_out, sf, E)
    x = ...  # permute: (B, h_out, w_out, sf, sf, E)
    x = ...  # merge: (B, h_out*w_out, sf*sf*E)
    return x

def forward(self, x):
    ## TODO
    x = ...  # pixel shuffle
    x = ...  # linear projection
    return x
Solution:
def __init__(self, cfg):
    ...
    self.proj = nn.Linear(self.input_dim, self.output_dim, bias=False)

def pixel_shuffle(self, x):
    bsz, seq, embed_dim = x.size()
    seq_root = int(seq ** 0.5)
    assert seq_root ** 2 == seq
    assert seq_root % self.scale_factor == 0
    height = width = seq_root
    x = x.view(bsz, height, width, embed_dim)  # (B, H, W, E)
    h_out = height // self.scale_factor
    w_out = width // self.scale_factor
    x = x.reshape(bsz, h_out, self.scale_factor, w_out, self.scale_factor, embed_dim)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()  # (B, h_out, w_out, sf, sf, E)
    x = x.reshape(bsz, h_out * w_out, embed_dim * self.scale_factor ** 2)
    return x

def forward(self, x):
    x = self.pixel_shuffle(x)  # (B, 196, 768) -> (B, 49, 3072)
    x = self.proj(x)  # (B, 49, 3072) -> (B, 49, 576)
    return x
Key Insight — Pixel Shuffle
- Treats 196 patch tokens as a 14×14 grid.
- Every 2×2 block of neighbours is folded into 1 token: count ÷4, embedding ×4 — no information lost.
- RoPE must be applied before concatenating with the cached K/V (relevant in Ex 6).
- Final linear maps 3072 → 576 (3072 = 4 × 768) to match the LM's hidden dimension.
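The shape bookkeeping above can be checked with a small numpy stand-in for pixel_shuffle() followed by the projection. This is a sketch, not the module itself: numpy replaces torch, a random matrix stands in for the learned nn.Linear(3072, 576, bias=False) weights, and scale_factor = 2 is assumed (as implied by 196 → 49 tokens).

```python
import numpy as np

def pixel_shuffle(x: np.ndarray, scale_factor: int = 2) -> np.ndarray:
    """Fold each scale_factor x scale_factor block of neighbouring patch tokens
    into one token: (B, S, E) -> (B, S / sf**2, E * sf**2)."""
    bsz, seq, embed_dim = x.shape
    side = int(round(seq ** 0.5))
    assert side * side == seq and side % scale_factor == 0
    h_out = w_out = side // scale_factor
    # Row-major layout lets us split rows and columns into (block, offset) pairs directly
    x = x.reshape(bsz, h_out, scale_factor, w_out, scale_factor, embed_dim)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # (B, h_out, w_out, sf, sf, E)
    return x.reshape(bsz, h_out * w_out, scale_factor * scale_factor * embed_dim)

rng = np.random.default_rng(0)
vit_tokens = rng.standard_normal((1, 196, 768))  # (B, 196, 768) ViT output
shuffled = pixel_shuffle(vit_tokens)             # (1, 49, 3072)

# Random stand-in for the learned projection weights (hypothetical values)
proj_weight = rng.standard_normal((3072, 576)) / np.sqrt(3072)
lm_tokens = shuffled @ proj_weight               # (1, 49, 576)
print(shuffled.shape, lm_tokens.shape)           # (1, 49, 3072) (1, 49, 576)
```

Output token 0 is the concatenation of patches 0, 1, 14 and 15 (the top-left 2×2 block), and the shuffle only reorders values, which is why no information is lost before the projection.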
EX 2
VQA Collator — Tokenisation & Label Masking
15 pts
Task: Complete VQACollator.__call__ in data/collators.py: stack images, tokenise question + answer with left-padding, build causal labels via a left shift, and mask padding, question and truncated positions with -100.
📄 data/collators.py
Skeleton:
def __call__(self, batch):
    images = [item["image"] for item in batch]
    texts = [item["text_data"] for item in batch]
    answers = [item["answer"] for item in batch]

    # Step 1: Stack images
    images = ...

    # Step 2: Build "question + answer" strings
    input_sequences = []
    for i in range(len(texts)):
        ...

    # Step 3: Tokenise (left-pad to max_length, right-truncate)
    encoded_full_sequences = ...
    input_ids = ...
    attention_mask = ...

    # Step 4: Causal labels (label[t] = input_id[t+1])
    labels = ...  # clone input_ids
    ...  # shift
    ...  # mask last position

    # Step 5: Per-sample masking
    original_lengths = ...  # untruncated lengths
    for i in range(len(batch)):
        question_length = ...
        # Case A: truncated → mask entire sample
        ...
        # Case B: left-padded → mask padding + question
        first_token_pos = ...
        question_end = ...
        ...

    return {"image": images, "input_ids": input_ids,
            "attention_mask": attention_mask, "labels": labels}
Solution:
def __call__(self, batch):
    images = [item["image"] for item in batch]
    texts = [item["text_data"] for item in batch]
    answers = [item["answer"] for item in batch]

    # Step 1
    images = torch.stack(images)

    # Step 2
    input_sequences = []
    for i in range(len(texts)):
        input_sequences.append(f"{texts[i]}{answers[i]}")

    # Step 3
    encoded_full_sequences = self.tokenizer.batch_encode_plus(
        input_sequences,
        padding="max_length",
        padding_side="left",
        return_tensors="pt",
        truncation=True,
        max_length=self.max_length,
    )
    input_ids = encoded_full_sequences["input_ids"]
    attention_mask = encoded_full_sequences["attention_mask"]

    # Step 4
    labels = input_ids.clone()
    labels[:, :-1] = input_ids[:, 1:].clone()
    labels[:, -1] = -100

    # Step 5
    original_lengths = [len(self.tokenizer.encode(seq)) for seq in input_sequences]
    for i in range(len(batch)):
        question_length = len(self.tokenizer.encode(texts[i], add_special_tokens=False))
        if original_lengths[i] > self.max_length:  # Case A: truncated
            labels[i, :] = -100
            continue
        first_token_pos = attention_mask[i].nonzero(as_tuple=True)[0][0].item()
        question_end = first_token_pos + question_length - 1  # -1 for shift
        labels[i, :question_end] = -100

    return {"image": images, "input_ids": input_ids,
            "attention_mask": attention_mask, "labels": labels}
Key Insight — Three Token Categories
- Padding: left-padded to max_length; not real content → -100.
- Question tokens: given as context; the model should not predict them → -100.
- Truncated samples: the answer is cut off; supervising a partial answer adds noise → mask the whole sample.
Shift rule: labels[:, :-1] = input_ids[:, 1:], so label[t] is the token the model must predict at position t (i.e. input_ids[t+1]).
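To see Steps 4 and 5 on concrete numbers, here is a plain-Python sketch for a single left-padded sample. The token ids, the pad id 0 and the helper name build_labels are made up for illustration; the real collator works on tokenizer tensors for the whole batch.

```python
IGNORE_INDEX = -100

def build_labels(input_ids, attention_mask, question_length, truncated=False):
    """Causal labels for one left-padded sample (lists of ints).

    label[t] = input_ids[t + 1]; padding and question positions (or every
    position, for a truncated sample) are set to IGNORE_INDEX."""
    labels = input_ids[1:] + [IGNORE_INDEX]      # left shift; last position has no target
    if truncated:                                # Case A: answer was cut off
        return [IGNORE_INDEX] * len(labels)
    first_token_pos = attention_mask.index(1)    # first real token after left padding
    question_end = first_token_pos + question_length - 1  # -1 because of the shift
    for t in range(question_end):                # Case B: padding + question
        labels[t] = IGNORE_INDEX
    return labels

# Toy sample: pad id 0, question = [11, 12, 13], answer = [21, 22], max_length = 8
input_ids = [0, 0, 0, 11, 12, 13, 21, 22]
attention_mask = [0, 0, 0, 1, 1, 1, 1, 1]
labels = build_labels(input_ids, attention_mask, question_length=3)
print(labels)  # [-100, -100, -100, -100, -100, 21, 22, -100]
```

Position 5 holds the last question token (13) and keeps label 21: that is where the model learns to produce the first answer token, which is why question_end subtracts 1.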
EX 3
VLM Forward Pass
15 pts
Task: Implement VisionLanguageModel.forward() in models/vision_language_model.py: encode the image, embed the text, concatenate, extend the attention mask, run the LM, and compute the cross-entropy loss on answer tokens only.
📄 models/vision_language_model.py · VisionLanguageModel.forward()
Skeleton:
def forward(self, input_ids, image, attention_mask=None, targets=None):
    # Step 1: Image embeddings (ViT → Projector)
    image_embeds = ...
    # Step 2: Text token embeddings
    text_embeds = ...
    # Step 3: Concatenate along sequence dim
    combined_embeds = ...
    # Step 4: Extend attention mask for image tokens
    if attention_mask is not None:
        image_attention = ...  # all-ones for image prefix
        attention_mask = ...  # concat [image_attn | text_attn]
    # Step 5: LLM forward
    output_token_embeddings = ...
    loss = None
    if targets is not None:
        logits = ...  # Step 6: project to vocab via decoder head
        logits = ...  # Step 7: drop image prefix from logits
        loss = ...  # Step 8: cross-entropy, ignore_index=-100
    return logits, loss
Solution:
def forward(self, input_ids, image, attention_mask=None, targets=None):
    image_embeds = self.vision_encoder(image)  # (B, 196, 768)
    image_embeds = self.MP(image_embeds)  # (B, 49, 576)
    text_embeds = self.decoder.token_embedding(input_ids)  # (B, T, 576)
    combined_embeds = torch.cat((image_embeds, text_embeds), dim=1)  # (B, 49+T, 576)
    if attention_mask is not None:
        B, img_len = image_embeds.size(0), image_embeds.size(1)
        image_attention = torch.ones((B, img_len),
                                     device=attention_mask.device,
                                     dtype=attention_mask.dtype)
        attention_mask = torch.cat((image_attention, attention_mask), dim=1)
    output_token_embeddings = self.decoder(combined_embeds, attention_mask)
    # Without targets the hidden states are returned as-is, so `logits` is always defined
    logits = output_token_embeddings
    loss = None
    if targets is not None:
        logits = self.decoder.head(output_token_embeddings)  # (B, 49+T, vocab)
        logits = logits[:, image_embeds.size(1):, :]  # (B, T, vocab)
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                               targets.reshape(-1), ignore_index=-100)
    return logits, loss
Key Insight
- Image tokens are prepended to text tokens; the LM runs both through the same layers and does not treat them differently.
- Loss is computed only on the text part of the logits: logits[:, img_len:, :].
- Padding and question tokens already carry label -100 from the collator (Ex 2), so only answer tokens contribute to the loss.
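A numpy sketch of Steps 4, 7 and 8 with toy sizes (2 samples, 4 image tokens, 5 text tokens, a vocabulary of 10; all hypothetical). masked_cross_entropy is an illustrative re-implementation of F.cross_entropy with ignore_index=-100 and mean reduction, not the PyTorch function itself.

```python
import numpy as np

def masked_cross_entropy(logits, targets, ignore_index=-100):
    """Mean cross-entropy over positions whose target != ignore_index."""
    logits = logits.reshape(-1, logits.shape[-1])
    targets = targets.reshape(-1)
    keep = targets != ignore_index
    z = logits[keep]
    z = z - z.max(axis=-1, keepdims=True)                        # stable log-softmax
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(len(z)), targets[keep]].mean()

B, img_len, T, vocab = 2, 4, 5, 10
rng = np.random.default_rng(0)

# Attention mask: all-ones image prefix + text mask -> (B, img_len + T)
text_mask = np.array([[0, 1, 1, 1, 1], [1, 1, 1, 1, 1]])
full_mask = np.concatenate([np.ones((B, img_len), dtype=text_mask.dtype), text_mask], axis=1)

# LM output over image + text positions, then drop the image prefix
all_logits = rng.standard_normal((B, img_len + T, vocab))
text_logits = all_logits[:, img_len:, :]                         # (B, T, vocab)

targets = np.array([[-100, -100, 7, 3, -100], [-100, 5, 1, 8, -100]])
loss = masked_cross_entropy(text_logits, targets)
```

Rescaling the logits at -100 positions leaves the loss unchanged, and all-zero logits give exactly log(vocab).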
EX 4
Training Curves Screenshot
15 pts
Task: Run training on 2× L40S GPUs, monitor W&B, and attach a screenshot of your own loss curves (training loss + val loss + MMStar accuracy) to the notebook.
Launch command:
OMP_NUM_THREADS=1 torchrun --nproc_per_node=2 train.py
Expected outcome after ~30k steps: training loss ↓, val loss ↓, MMStar ~25–28%.
Reference curves (students should produce something similar):
[Figure: reference nanoVLM-222M training and validation loss plus MMStar accuracy over 30k steps]
EX 5
Autoregressive Generation
20 pts
Task: Implement VisionLanguageModel.generate() in models/vision_language_model.py: encode image + text, then iteratively sample the next token until EOS (id = 2) or until max_new_tokens is reached.
📄 models/vision_language_model.py · generate()
Skeleton:
@torch.no_grad()
def generate(self, input_ids, image, attention_mask=None, max_new_tokens=5):
    # (i) Encode image + embed text + concatenate + extend mask
    image_embd = ...
    image_embd = ...
    token_embd = ...
    combined_embed = ...
    if attention_mask is not None:
        image_attention_mask = ...
        attention_mask = ...
    outputs = combined_embed
    generated_tokens = torch.zeros((batch_size, max_new_tokens), ...)
    for i in range(...):
        model_out = ...  # full forward pass on growing sequence
        last_token_logits = ...  # last position logits only
        if ...:  # apply LM head if in embedding mode
            last_token_logits = ...
        probs = ...  # softmax
        next_token = ...  # multinomial sample
        generated_tokens[:, i] = next_token.squeeze(-1)
        generated_embed = ...  # embed new token
        outputs = ...  # append to sequence
        if attention_mask is not None:
            attention_mask = ...  # extend mask by 1
        if ...:  # stop at EOS (token id = 2)
            break
    return ...
Solution:
@torch.no_grad()
def generate(self, input_ids, image, attention_mask=None, max_new_tokens=5):
    image_embd = self.vision_encoder(image)  # (B, 196, 768)
    image_embd = self.MP(image_embd)  # (B, 49, 576)
    token_embd = self.decoder.token_embedding(input_ids)
    combined_embed = torch.cat((image_embd, token_embd), dim=1)
    batch_size, img_seq_len = image_embd.size(0), image_embd.size(1)
    if attention_mask is not None:
        image_attention_mask = torch.ones((batch_size, img_seq_len),
                                          device=attention_mask.device, dtype=attention_mask.dtype)
        attention_mask = torch.cat((image_attention_mask, attention_mask), dim=1)
    outputs = combined_embed
    generated_tokens = torch.zeros((batch_size, max_new_tokens),
                                   device=input_ids.device, dtype=input_ids.dtype)
    # Which sequences have produced EOS (id 2) so far
    finished = torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device)
    for i in range(max_new_tokens):
        model_out = self.decoder(outputs, attention_mask)  # full pass on the growing sequence
        last_token_logits = model_out[:, -1, :]
        if not self.decoder.lm_use_tokens:
            last_token_logits = self.decoder.head(last_token_logits)
        probs = torch.softmax(last_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)  # (B, 1)
        generated_tokens[:, i] = next_token.squeeze(-1)
        generated_embed = self.decoder.token_embedding(next_token)
        outputs = torch.cat((outputs, generated_embed), dim=1)
        if attention_mask is not None:
            new_mask = torch.ones((batch_size, 1), device=attention_mask.device,
                                  dtype=attention_mask.dtype)  # keep the mask's dtype
            attention_mask = torch.cat((attention_mask, new_mask), dim=1)
        finished |= next_token.squeeze(-1) == 2  # EOS
        if finished.all():  # works for any batch size, unlike next_token.item()
            break
    return generated_tokens
Key Insight — Naive Generation
- At each step the full growing sequence is re-processed: step t attends over N + t positions, so a prompt of N tokens and T new tokens costs O(T·N²) in total. Exercise 6 fixes this.
- Only the logits at the last position are needed to predict the next token.
- Sampling: torch.multinomial is stochastic, so responses vary across runs, unlike greedy argmax.
- Stop condition: EOS token id = 2, or the loop reaches max_new_tokens.
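The loop logic can be sketched without any framework. Here toy_model is a hypothetical placeholder for the decoder plus LM head, random.choices plays the role of torch.multinomial, and the argmax branch shows greedy decoding for comparison.

```python
import math
import random

EOS_ID = 2  # stop token id, as in the exercise

def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def pick_next(logits, greedy, rng):
    probs = softmax(logits)
    if greedy:  # argmax: deterministic
        return max(range(len(probs)), key=probs.__getitem__)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]  # multinomial sample

def generate(next_logits_fn, prompt, max_new_tokens=5, greedy=False, rng=random):
    seq, out = list(prompt), []
    for _ in range(max_new_tokens):
        logits = next_logits_fn(seq)  # naive: the whole sequence is re-fed every step
        tok = pick_next(logits, greedy, rng)
        out.append(tok)
        seq.append(tok)
        if tok == EOS_ID:
            break
    return out

def toy_model(seq):
    """Hypothetical stand-in for the VLM: prefers token 3, then EOS once len(seq) >= 4."""
    logits = [0.0] * 6
    logits[EOS_ID if len(seq) >= 4 else 3] = 10.0
    return logits

print(generate(toy_model, prompt=[1], greedy=True))  # [3, 3, 3, 2]
```

With greedy=True the output is identical on every run; with sampling it depends on the random generator's state, which is the behaviour the bullet above describes.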
EX 6
KV Cache for Efficient Inference
25 pts
Task: Implement KV caching across 4 methods in models/language_model.py and models/vision_language_model.py, replacing O(T·N²) naive generation with O(N²) prefill + O(T·N) decode (N = prompt length including image tokens, T = number of new tokens).
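A rough way to see this gap is to count query-key scores per attention layer and head, ignoring projections and MLPs. The helper names are hypothetical; n_prompt plays the role of N and n_new of T.

```python
def naive_attention_pairs(n_prompt, n_new):
    """Query-key pairs scored when every step re-runs causal attention on the whole sequence."""
    total = 0
    for i in range(n_new):
        length = n_prompt + i                # sequence length at step i
        total += length * (length + 1) // 2  # causal: query j sees keys 0..j
    return total

def cached_attention_pairs(n_prompt, n_new):
    """Prefill once (causal), then one query per decode step against the growing cache."""
    total = n_prompt * (n_prompt + 1) // 2   # prefill produces the first new token
    for i in range(1, n_new):
        total += n_prompt + i                # 1 new query vs past_len + 1 keys
    return total

print(naive_attention_pairs(4, 3), cached_attention_pairs(4, 3))  # 46 21
```

Each decode step of the cached version adds a term linear in N, while each naive step adds a term quadratic in N.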
Part 1 of 4
VisionLanguageModel.generate_with_kv_cache()
📄 models/vision_language_model.py
Skeleton:
@torch.no_grad()
def generate_with_kv_cache(self, input_ids, image, ...):
    # Step 1: Build combined embeddings (same as generate())
    image_embd = ...
    token_embd = ...
    combined_embd = ...
    # Step 2: PREFILL: one forward pass, collect KV cache
    model_out, past_key_values = ...
    # Step 3: First generated token
    last_logits = ...
    if not self.decoder.lm_use_tokens:
        last_logits = ...
    probs = ...
    next_token = ...
    generated_tokens[:, 0] = next_token.squeeze(-1)
    # Step 4: DECODE LOOP
    for i in range(1, max_new_tokens):
        next_embd = ...  # embed only the last generated token (B, 1, D)
        model_out, past_key_values = ...  # 1-token forward, reuse cache
        last_logits = ...
        ...
        next_token = ...
        generated_tokens[:, i] = next_token.squeeze(-1)
        if ...:  # EOS check
            break
    return generated_tokens
Solution:
@torch.no_grad()
def generate_with_kv_cache(self, input_ids, image, attention_mask=None, max_new_tokens=5):
    # attention_mask is not passed to forward_kv, so padded prompts are not masked here
    image_embd = self.vision_encoder(image)
    image_embd = self.MP(image_embd)
    token_embd = self.decoder.token_embedding(input_ids)
    combined_embd = torch.cat((image_embd, token_embd), dim=1)
    batch_size = image_embd.size(0)

    # PREFILL: run the full prompt once, building the entire KV cache
    model_out, past_key_values = self.decoder.forward_kv(combined_embd)
    last_logits = model_out[:, -1, :]
    if not self.decoder.lm_use_tokens:
        last_logits = self.decoder.head(last_logits)
    probs = torch.softmax(last_logits, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)
    generated_tokens = torch.zeros((batch_size, max_new_tokens),
                                   device=input_ids.device, dtype=input_ids.dtype)
    generated_tokens[:, 0] = next_token.squeeze(-1)
    finished = next_token.squeeze(-1) == 2  # EOS (id 2)

    # DECODE LOOP: one token per step, reusing the cached K/V
    for i in range(1, max_new_tokens):
        if finished.all():  # every sequence has emitted EOS
            break
        next_embd = self.decoder.token_embedding(next_token)  # (B, 1, D)
        model_out, past_key_values = self.decoder.forward_kv(
            next_embd, past_key_values=past_key_values
        )
        last_logits = model_out[:, -1, :]
        if not self.decoder.lm_use_tokens:
            last_logits = self.decoder.head(last_logits)
        probs = torch.softmax(last_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        generated_tokens[:, i] = next_token.squeeze(-1)
        finished |= next_token.squeeze(-1) == 2
    return generated_tokens
Part 2 of 4
LanguageModel.forward_kv()
📄 models/language_model.py
Skeleton:
def forward_kv(self, x, past_key_values=None):
    if self.lm_use_tokens:
        x = self.token_embedding(x)
    B, T, _ = x.size()
    # Fix position IDs: new tokens continue from past_length, not 0
    past_length = ...  # 0 if no cache, else past_key_values[0][0].size(2)
    position_ids = ...  # arange(past_length, past_length + T)
    cos, sin = self.rotary_embd(position_ids)
    present_key_values = []
    for i, block in enumerate(self.blocks):
        past_kv = ...  # this layer's cache (None on first call)
        x, present_kv = ...  # call block.forward_kv
        ...  # append present_kv
    x = self.norm(x)
    if self.lm_use_tokens:
        x = self.head(x)
    return ...  # (hidden_states, present_key_values)
Solution:
def forward_kv(self, x, past_key_values=None):
    if self.lm_use_tokens:
        x = self.token_embedding(x)
    B, T, _ = x.size()
    past_length = 0
    if past_key_values is not None:
        past_length = past_key_values[0][0].size(2)  # layer-0 key length
    position_ids = torch.arange(past_length, past_length + T,
                                device=x.device).unsqueeze(0).expand(B, -1)
    cos, sin = self.rotary_embd(position_ids)
    present_key_values = []
    for i, block in enumerate(self.blocks):
        past_kv = past_key_values[i] if past_key_values is not None else None
        x, present_kv = block.forward_kv(x, cos, sin, past_kv)
        present_key_values.append(present_kv)
    x = self.norm(x)
    if self.lm_use_tokens:
        x = self.head(x)
    return x, present_key_values
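Why position_ids must continue from past_length: with RoPE, a query-key score depends only on the offset between their positions. The numpy sketch below uses the rotate-half formulation of RoPE (whether self.rotary_embd uses exactly this layout and base is an assumption) and scores one new query against keys that were rotated once, during prefill.

```python
import numpy as np

def rope(x, positions, base=10000.0):
    """Rotate-half RoPE on x of shape (T, d), d even; positions has shape (T,)."""
    d = x.shape[-1]
    inv_freq = base ** (-np.arange(0, d, 2) / d)                      # (d/2,)
    angles = np.outer(positions, inv_freq)                            # (T, d/2)
    cos = np.concatenate([np.cos(angles), np.cos(angles)], axis=-1)   # (T, d)
    sin = np.concatenate([np.sin(angles), np.sin(angles)], axis=-1)
    x1, x2 = x[:, : d // 2], x[:, d // 2 :]
    return x * cos + np.concatenate([-x2, x1], axis=-1) * sin

rng = np.random.default_rng(0)
d, past_length = 8, 6
keys = rng.standard_normal((past_length, d))
query = rng.standard_normal((1, d))

cached_k = rope(keys, np.arange(past_length))  # rotated once, during prefill

# Decode step: the new token sits at position past_length ...
right = rope(query, np.arange(past_length, past_length + 1)) @ cached_k.T
# ... restarting at position 0 would distort every relative offset
wrong = rope(query, np.arange(0, 1)) @ cached_k.T
```

Shifting every position by the same amount leaves the scores unchanged, while giving the new token position 0 changes all of them.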
Part 3 of 4
LanguageModelBlock.forward_kv()
📄 models/language_model.py
Skeleton:
def forward_kv(self, x, cos, sin, past_key_value=None):
    res = x
    x = self.norm1(x)
    x, present_key_value = ...  # call self.attn.forward_kv(...)
    x = res + x
    res = x
    x = self.norm2(x)
    x = self.mlp(x)
    x = res + x
    return ...  # (hidden_states, present_key_value)
Solution:
def forward_kv(self, x, cos, sin, past_key_value=None):
    res = x
    x = self.norm1(x)
    x, present_key_value = self.attn.forward_kv(x, cos, sin, past_key_value)
    x = res + x
    res = x
    x = self.norm2(x)
    x = self.mlp(x)
    x = res + x
    return x, present_key_value
Part 4 of 4
LanguageModelGroupedQueryAttention.forward_kv()
📄 models/language_model.py
Skeleton:
def forward_kv(self, x, cos, sin, past_key_value=None):
    B, T, C = x.size()
    q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
    k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
    v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
    q, k = apply_rotary_pos_embd(q, k, cos, sin)  # RoPE applied BEFORE concat!
    # Step 1: Prepend cached K/V if this is a decode step
    if past_key_value is not None:
        past_k, past_v = ...
        k = ...
        v = ...
    # Step 2: Save present (full) cache
    present_key_value = ...
    k = k.repeat_interleave(self.n_kv_groups, dim=1)
    v = v.repeat_interleave(self.n_kv_groups, dim=1)
    # Step 3: Attend (causal during prefill only)
    is_decode = ...
    y = F.scaled_dot_product_attention(q, k, v, is_causal=...)
    y = y.transpose(1, 2).contiguous().view(B, T, C)
    y = self.out_proj(y)
    # Step 4: Return output + cache
    return ...
Solution:
def forward_kv(self, x, cos, sin, past_key_value=None):
    B, T, C = x.size()
    q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
    k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
    v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
    q, k = apply_rotary_pos_embd(q, k, cos, sin)  # RoPE applied BEFORE concat!
    # Prepend cached K/V
    if past_key_value is not None:
        past_k, past_v = past_key_value
        k = torch.cat([past_k, k], dim=2)  # (B, n_kv_heads, past_len+T, head_dim)
        v = torch.cat([past_v, v], dim=2)
    # Store full K,V for the next step
    present_key_value = (k, v)
    k = k.repeat_interleave(self.n_kv_groups, dim=1)
    v = v.repeat_interleave(self.n_kv_groups, dim=1)
    is_decode = past_key_value is not None
    y = F.scaled_dot_product_attention(q, k, v,
                                       is_causal=not is_decode)  # no causal mask during decode
    y = y.transpose(1, 2).contiguous().view(B, T, C)
    y = self.out_proj(y)
    return y, present_key_value
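Note that present_key_value is stored before repeat_interleave, so the cache keeps only n_kv_heads heads per layer, n_kv_groups times less memory than caching the expanded tensors. The numpy sketch below uses hypothetical toy sizes; np.repeat along axis 1 repeats each head consecutively, which matches repeat_interleave(..., dim=1).

```python
import numpy as np

B, n_heads, n_kv_heads, S, head_dim = 1, 6, 2, 5, 4  # toy sizes (hypothetical)
n_kv_groups = n_heads // n_kv_heads

rng = np.random.default_rng(0)
k_cache = rng.standard_normal((B, n_kv_heads, S, head_dim))  # what present_key_value stores

# numpy equivalent of k.repeat_interleave(n_kv_groups, dim=1)
k_full = np.repeat(k_cache, n_kv_groups, axis=1)  # (B, n_heads, S, head_dim)
```

After the expansion, query head h reads KV head h // n_kv_groups.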
Key Insight — Two-Phase Caching
- PREFILL: run full prompt once → builds K/V cache for every layer.
- DECODE: feed only 1 new token per step → prepend cached K/V, compute attention over full history at O(N+t) cost.
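- Critical: RoPE is applied before the concat (cached keys are already rotated), and position_ids must start from past_length, not 0.
- Causal mask: needed during prefill (T_q > 1) but must be off during decode. With T_q = 1 the single query may attend to every cached token, whereas is_causal=True in F.scaled_dot_product_attention builds a top-left aligned mask that would let it see only the first key.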
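The two-phase scheme can be checked numerically with a single-head numpy sketch (no RoPE, no GQA, random projection weights: all simplifications). The causal=True branch builds a top-left aligned lower-triangular mask, the alignment PyTorch documents for is_causal=True when query and key lengths differ.

```python
import numpy as np

def attention(q, k, v, causal):
    """Single-head scaled dot-product attention on (T, d) arrays."""
    scores = q @ k.T / np.sqrt(q.shape[-1])  # (Tq, Tk)
    if causal:  # top-left aligned lower-triangular mask
        mask = np.tril(np.ones((q.shape[0], k.shape[0]), dtype=bool))
        scores = np.where(mask, scores, -np.inf)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

rng = np.random.default_rng(0)
n_prompt, n_new, d = 5, 3, 8
x = rng.standard_normal((n_prompt + n_new, d))
w_q, w_k, w_v = (rng.standard_normal((d, d)) for _ in range(3))

# Reference: causal attention over the whole sequence in one pass
reference = attention(x @ w_q, x @ w_k, x @ w_v, causal=True)

# PREFILL: causal attention over the prompt, keep K/V
k_cache, v_cache = x[:n_prompt] @ w_k, x[:n_prompt] @ w_v
outputs = [attention(x[:n_prompt] @ w_q, k_cache, v_cache, causal=True)]

# DECODE: one query per step, append its K/V to the cache, no causal mask
for t in range(n_prompt, n_prompt + n_new):
    k_cache = np.vstack([k_cache, x[t:t + 1] @ w_k])
    v_cache = np.vstack([v_cache, x[t:t + 1] @ w_v])
    outputs.append(attention(x[t:t + 1] @ w_q, k_cache, v_cache, causal=False))
cached = np.vstack(outputs)

print(np.allclose(cached, reference))  # True
```

Running the last decode step with causal=True instead makes the lone query attend only to key 0, so its output collapses to the first cached value vector.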