nanoVLM — Exercise Solutions 100 pts total

Building a nanoVLM — Solutions

TA session rundown · EPFL COM-304 Foundation Models

Exercise 1 · Modality Projector · 10 pts
Exercise 2 · VQA Collator · 15 pts
Exercise 3 · VLM Forward Pass · 15 pts
Exercise 4 · Training Curves · 15 pts
Exercise 5 · Autoregressive Gen · 20 pts
Exercise 6 · KV Cache · 25 pts
EX 1

Modality Projector

10 pts
Task: Implement pixel_shuffle() and forward() in models/modality_projector.py — bridge ViT (196 tokens × 768-dim) to the LM (49 tokens × 576-dim).
📄 models/modality_projector.py
def __init__(self, cfg):
    ...
    ## TODO
    self.proj = ...           # ← define linear layer here

def pixel_shuffle(self, x):
    bsz, seq, embed_dim = x.size()
    seq_root = int(seq**0.5)
    assert seq_root**2 == seq
    assert seq_root % self.scale_factor == 0

    ## TODO
    height = width = ...      # set height and width from seq_root
    x = ...                   # reshape to (B, H, W, E)
    h_out = ...               # H // scale_factor
    w_out = ...               # W // scale_factor

    x = ...                   # reshape: (B, h_out, sf, w_out, sf, E)
    x = ...                   # permute: (B, h_out, w_out, sf, sf, E)
    x = ...                   # merge:   (B, h_out*w_out, sf*sf*E)
    return x

def forward(self, x):
    ## TODO
    x = ...                   # pixel shuffle
    x = ...                   # linear projection
    return x
def __init__(self, cfg):
    ...
    self.proj = nn.Linear(self.input_dim, self.output_dim, bias=False)

def pixel_shuffle(self, x):
    bsz, seq, embed_dim = x.size()
    seq_root = int(seq ** 0.5)
    assert seq_root ** 2 == seq
    assert seq_root % self.scale_factor == 0

    height = width = seq_root
    x = x.view(bsz, height, width, embed_dim)
    h_out = height // self.scale_factor
    w_out = width // self.scale_factor

    x = x.reshape(bsz, h_out, self.scale_factor, w_out, self.scale_factor, embed_dim)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    x = x.reshape(bsz, h_out * w_out, embed_dim * self.scale_factor ** 2)
    return x

def forward(self, x):
    x = self.pixel_shuffle(x)
    x = self.proj(x)
    return x
Key Insight — Pixel Shuffle
  • Treats 196 patch tokens as a 14×14 grid.
  • Every 2×2 block of neighbours is folded into 1 token: count ÷4, embedding ×4 — no information lost.
  • With scale_factor = 2 the 14×14 grid becomes 7×7 = 49 tokens, each of width 4 × 768 = 3072.
  • Final linear maps 3072 → 576 to match the LM's hidden dimension.
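To make the folding concrete, the sketch below lists which flat patch indices end up inside each output token. It assumes a row-major 14×14 grid and scale_factor = 2, as in the task statement. The helper name shuffle_groups is hypothetical and is not part of models/modality_projector.py; it only mirrors the reshape/permute order of the solution without using tensors.

```python
def shuffle_groups(seq_len: int, scale_factor: int) -> list[list[int]]:
    """Return, for each output token, the flat input-patch indices it absorbs."""
    side = int(seq_len ** 0.5)
    assert side * side == seq_len, "patches must form a square grid"
    assert side % scale_factor == 0, "grid side must be divisible by scale_factor"
    out_side = side // scale_factor

    groups = []
    for h in range(out_side):                  # output row
        for w in range(out_side):              # output column
            block = [
                (h * scale_factor + a) * side + (w * scale_factor + b)
                for a in range(scale_factor)   # row offset inside the block
                for b in range(scale_factor)   # column offset inside the block
            ]
            groups.append(block)
    return groups


groups = shuffle_groups(seq_len=196, scale_factor=2)
print(len(groups))   # number of output tokens
print(groups[0])     # patches folded into the first token
print(groups[-1])    # patches folded into the last token

# Every input patch appears exactly once, which is why nothing is lost.
all_indices = sorted(i for g in groups for i in g)
print(all_indices == list(range(196)))
```

Each group holds 4 patch indices, so the merged embedding has 4 × 768 = 3072 features before the linear layer.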
EX 2

VQA Collator — Tokenisation & Label Masking

15 pts
Task: Complete VQACollator.__call__ in data/collators.py — stack images, tokenise question+answer with left-padding, build causal labels via a left-shift, and mask padding / question / truncated positions with -100.
📄 data/collators.py
def __call__(self, batch):
    images  = [item["image"]     for item in batch]
    texts   = [item["text_data"] for item in batch]
    answers = [item["answer"]    for item in batch]

    # Step 1 — Stack images
    images = ...

    # Step 2 — Build "question + answer" strings
    input_sequences = []
    for i in range(len(texts)):
        ...

    # Step 3 — Tokenise: left-pad to max_length, right-truncate
    encoded_full_sequences = ...
    input_ids      = ...
    attention_mask = ...

    # Step 4 — Causal labels (label[t] = input_id[t+1])
    labels =  ...           # clone input_ids
    ...                     # shift
    ...                     # mask last position

    # Step 5 — Per-sample masking
    original_lengths = ...  # untruncated lengths

    for i in range(len(batch)):
        question_length = ...
        # Case A: truncated → mask entire sample
        ...
        # Case B: left-padded → mask padding + question
        first_token_pos = ...
        question_end    = ...
        ...

    return {"image": images, "input_ids": input_ids,
            "attention_mask": attention_mask, "labels": labels}
def __call__(self, batch):
    images  = [item["image"]     for item in batch]
    texts   = [item["text_data"] for item in batch]
    answers = [item["answer"]    for item in batch]

    # Step 1
    images = torch.stack(images)

    # Step 2
    input_sequences = []
    for i in range(len(texts)):
        input_sequences.append(f"{texts[i]}{answers[i]}")

    # Step 3
    encoded_full_sequences = self.tokenizer.batch_encode_plus(
        input_sequences,
        padding="max_length",
        padding_side="left",
        return_tensors="pt",
        truncation=True,
        max_length=self.max_length,
    )
    input_ids      = encoded_full_sequences["input_ids"]
    attention_mask = encoded_full_sequences["attention_mask"]

    # Step 4
    labels = input_ids.clone()
    labels[:, :-1] = input_ids[:, 1:].clone()
    labels[:, -1]  = -100

    # Step 5
    original_lengths = [len(self.tokenizer.encode(seq)) for seq in input_sequences]

    for i in range(len(batch)):
        question_length = len(self.tokenizer.encode(texts[i], add_special_tokens=False))

        if original_lengths[i] > self.max_length:   # Case A: truncated
            labels[i, :] = -100
            continue

        first_token_pos = attention_mask[i].nonzero(as_tuple=True)[0][0].item()
        question_end    = first_token_pos + question_length - 1   # -1 for shift
        labels[i, :question_end] = -100

    return {"image": images, "input_ids": input_ids,
            "attention_mask": attention_mask, "labels": labels}
Key Insight — Three Token Categories
  • Padding — left-padded to max_length; not real content → -100.
  • Question tokens — given as context; model should not predict them → -100.
  • Truncated samples — answer is cut off; supervising a partial answer adds noise → mask whole sample.

Shift rule: labels[:, :-1] = input_ids[:, 1:] — so label[t] is the token to predict at position t.
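The shift and the masking are easiest to check on a toy sample. The sketch below uses plain Python lists instead of tensors and made-up token ids (0 for padding, 11 to 13 for the question, 21 and 22 for the answer). The function build_labels is a hypothetical stand-in for Steps 4 and 5 of the collator and does not cover the truncated case.

```python
IGNORE = -100

def build_labels(input_ids, attention_mask, question_length):
    """Shift-left labels, then mask padding and question positions."""
    labels = input_ids[1:] + [IGNORE]                       # label[t] = input_ids[t + 1]
    first_token_pos = attention_mask.index(1)               # first non-padding position
    question_end = first_token_pos + question_length - 1    # -1 accounts for the shift
    for t in range(question_end):
        labels[t] = IGNORE
    return labels

# Toy sample: 3 padding tokens, question [11, 12, 13], answer [21, 22].
input_ids      = [0, 0, 0, 11, 12, 13, 21, 22]
attention_mask = [0, 0, 0, 1, 1, 1, 1, 1]

labels = build_labels(input_ids, attention_mask, question_length=3)
for t, (tok, lab) in enumerate(zip(input_ids, labels)):
    print(t, tok, lab)
```

Only two positions keep a real label: the last question token (13) is trained to predict 21, and 21 is trained to predict 22. The final position has no next token and stays at -100.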

EX 3

VLM Forward Pass

15 pts
Task: Implement VisionLanguageModel.forward() in models/vision_language_model.py — encode image, embed text, concatenate, extend attention mask, run LM, compute cross-entropy loss on answer tokens only.
📄 models/vision_language_model.py — VisionLanguageModel.forward()
def forward(self, input_ids, image, attention_mask=None, targets=None):

    # Step 1: Image embeddings (ViT → Projector)
    image_embeds = ...

    # Step 2: Text token embeddings
    text_embeds = ...

    # Step 3: Concatenate along sequence dim
    combined_embeds = ...

    # Step 4: Extend attention mask for image tokens
    if attention_mask is not None:
        image_attention = ...   # all-ones for image prefix
        attention_mask  = ...   # concat [image_attn | text_attn]

    # Step 5: LLM forward
    output_token_embeddings = ...

    loss = None
    if targets is not None:
        logits = ...    # Step 6: project to vocab via decoder head
        logits = ...    # Step 7: drop image prefix from logits
        loss   = ...    # Step 8: cross-entropy, ignore_index=-100

    return logits, loss
def forward(self, input_ids, image, attention_mask=None, targets=None):

    image_embeds = self.vision_encoder(image)            # (B, 196, 768)
    image_embeds = self.MP(image_embeds)                 # (B, 49,  576)

    text_embeds = self.decoder.token_embedding(input_ids)  # (B, T, 576)

    combined_embeds = torch.cat((image_embeds, text_embeds), dim=1)  # (B, 49+T, 576)

    if attention_mask is not None:
        B, img_len = image_embeds.size(0), image_embeds.size(1)
        image_attention = torch.ones((B, img_len),
                                      device=attention_mask.device,
                                      dtype=attention_mask.dtype)
        attention_mask = torch.cat((image_attention, attention_mask), dim=1)

    output_token_embeddings = self.decoder(combined_embeds, attention_mask)

    logits, loss = output_token_embeddings, None   # decoder output is returned as-is when no targets are given
    if targets is not None:
        logits = self.decoder.head(output_token_embeddings)   # (B, 49+T, vocab)
        logits = logits[:, image_embeds.size(1):, :]          # (B, T,    vocab)
        loss   = F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                                  targets.reshape(-1), ignore_index=-100)

    return logits, loss
Key Insight
  • Image tokens are prepended to text tokens — the LM never distinguishes between them.
  • Loss is computed only on the text part of logits: logits[:, img_len:, :].
  • Non-answer positions (padding, question) already carry label -100 from the collator (Ex 2), so ignore_index=-100 restricts the loss to answer tokens.
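What ignore_index does to the loss can be sketched in a few lines of plain Python. The function masked_cross_entropy below is a hypothetical re-implementation for illustration only (a mean over non-ignored positions); it is not the PyTorch code, and the logits and targets are made up.

```python
import math

IGNORE = -100

def masked_cross_entropy(logits, targets, ignore_index=IGNORE):
    """Mean negative log-likelihood over positions whose target is not ignored."""
    total, count = 0.0, 0
    for row, target in zip(logits, targets):
        if target == ignore_index:
            continue                                  # position contributes nothing
        m = max(row)                                  # subtract max for numerical stability
        log_z = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_z - row[target]
        count += 1
    return total / count if count else float("nan")

logits = [
    [2.0, 0.0, -1.0],   # question position (ignored)
    [0.5, 0.5, 0.5],    # question position (ignored)
    [0.0, 0.0, 3.0],    # answer position, target 2
    [1.0, 0.0, 0.0],    # answer position, target 0
]
targets = [IGNORE, IGNORE, 2, 0]
loss = masked_cross_entropy(logits, targets)

# Perturbing logits at ignored positions leaves the loss unchanged.
perturbed = [[9.0, -9.0, 0.0], [-4.0, 4.0, 0.0]] + logits[2:]
loss_perturbed = masked_cross_entropy(perturbed, targets)
print(loss, loss_perturbed)
```

The two printed values agree, which is the point: whatever the model outputs at question or padding positions, only the answer positions shape the gradient.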
EX 4

Training Curves Screenshot

15 pts
Task: Run training on 2× L40S GPUs, monitor W&B, and attach a screenshot of your own loss curves (training loss + val loss + MMStar accuracy) to the notebook.

Launch command:

OMP_NUM_THREADS=1 torchrun --nproc_per_node=2 train.py

Expected outcome after ~30k steps: training loss ↓, val loss ↓, MMStar ~25–28%.

Reference curves (students should produce something similar):

Figure: nanoVLM-222M reference training and validation loss plus MMStar accuracy over 30k steps.
EX 5

Autoregressive Generation

20 pts
Task: Implement VisionLanguageModel.generate() in models/vision_language_model.py — encode image + text, then iteratively sample the next token until EOS (id=2) or max_new_tokens.
📄 models/vision_language_model.py — generate()
@torch.no_grad()
def generate(self, input_ids, image, attention_mask=None, max_new_tokens=5):

    # (i) Encode image + embed text + concatenate + extend mask
    image_embd = ...
    image_embd = ...
    token_embd = ...
    combined_embed = ...
    if attention_mask is not None:
        image_attention_mask = ...
        attention_mask = ...

    outputs = combined_embed
    generated_tokens = torch.zeros((batch_size, max_new_tokens), ...)

    for i in range(...):
        model_out = ...             # full forward pass on growing sequence

        last_token_logits = ...     # last position logits only

        if ...:                     # apply LM head if in embedding mode
            last_token_logits = ...

        probs      = ...            # softmax
        next_token = ...            # multinomial sample
        generated_tokens[:, i] = next_token.squeeze(-1)

        generated_embed = ...       # embed new token
        outputs = ...               # append to sequence

        if attention_mask is not None:
            attention_mask = ...    # extend mask by 1

        if ...:                     # stop at EOS (token id = 2)
            break

    return ...
@torch.no_grad()
def generate(self, input_ids, image, attention_mask=None, max_new_tokens=5):

    image_embd = self.vision_encoder(image)          # (B, 196, 768)
    image_embd = self.MP(image_embd)                 # (B, 49,  576)

    token_embd     = self.decoder.token_embedding(input_ids)
    combined_embed = torch.cat((image_embd, token_embd), dim=1)

    batch_size, img_seq_len = image_embd.size(0), image_embd.size(1)
    if attention_mask is not None:
        image_attention_mask = torch.ones((batch_size, img_seq_len),
                                           device=attention_mask.device, dtype=attention_mask.dtype)
        attention_mask = torch.cat((image_attention_mask, attention_mask), dim=1)

    outputs          = combined_embed
    generated_tokens = torch.zeros((batch_size, max_new_tokens),
                                    device=input_ids.device, dtype=input_ids.dtype)

    for i in range(max_new_tokens):
        model_out         = self.decoder(outputs, attention_mask)
        last_token_logits = model_out[:, -1, :]

        if not self.decoder.lm_use_tokens:
            last_token_logits = self.decoder.head(last_token_logits)

        probs      = torch.softmax(last_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        generated_tokens[:, i] = next_token.squeeze(-1)

        generated_embed = self.decoder.token_embedding(next_token)
        outputs         = torch.cat((outputs, generated_embed), dim=1)

        if attention_mask is not None:
            attention_mask = torch.cat(
                (attention_mask, torch.ones((batch_size, 1), device=attention_mask.device,
                                            dtype=attention_mask.dtype)), dim=1)   # keep the mask dtype

        if next_token.item() == 2:   # EOS; .item() assumes batch_size == 1
            break

    return generated_tokens
Key Insight — Naive Generation
  • At each step the full growing sequence is re-processed: one step costs O(N²) in attention, so T new tokens cost O(T·N²). Exercise 6 fixes this.
  • Only the last position's logits matter for predicting the next token.
  • Sampling: torch.multinomial (stochastic) — responses vary across runs, unlike greedy argmax.
  • Stop condition: EOS token id = 2, or loop reaches max_new_tokens.
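The difference between sampling and greedy decoding can be seen without a model. The sketch below uses only the standard library: rng.choices plays the role of torch.multinomial, and the three logits are made up for illustration.

```python
import math
import random

def softmax(logits):
    m = max(logits)                         # subtract max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    z = sum(exps)
    return [e / z for e in exps]

def greedy(logits):
    """Deterministic: always the highest-scoring token."""
    return max(range(len(logits)), key=lambda i: logits[i])

def sample(logits, rng):
    """Stochastic: draw one token id in proportion to its softmax probability."""
    probs = softmax(logits)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]

logits = [2.0, 1.0, 0.1]
rng = random.Random(0)
draws = [sample(logits, rng) for _ in range(2000)]
freqs = [draws.count(i) / len(draws) for i in range(len(logits))]

print(greedy(logits))
print([round(p, 3) for p in softmax(logits)])
print([round(f, 3) for f in freqs])
```

Greedy always returns token 0 here, while the sampled frequencies track the softmax probabilities, so lower-probability tokens still appear in some runs.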
EX 6

KV Cache for Efficient Inference

25 pts
Task: Implement KV caching across 4 methods in models/language_model.py and models/vision_language_model.py. Replace O(T·N²) naive generation with O(N²) prefill + O(T·N) decode.
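A rough count of query-key pairs shows where the saving comes from. The sketch below counts attention scores per layer and per head, ignores the causal mask, and uses a hypothetical prompt of N = 64 tokens (for example 49 image tokens plus 15 text tokens) with T = 5 new tokens. Both function names are made up for this illustration.

```python
def naive_pairs(n_prompt: int, n_new: int) -> int:
    """Each step re-runs attention over the whole sequence seen so far."""
    return sum((n_prompt + i) ** 2 for i in range(n_new))

def cached_pairs(n_prompt: int, n_new: int) -> int:
    """One full prefill, then a single query per decode step."""
    prefill = n_prompt ** 2
    # Step i feeds 1 new token that attends to n_prompt + i keys (cache + itself).
    decode = sum(n_prompt + i for i in range(1, n_new))
    return prefill + decode

N, T = 64, 5
print(naive_pairs(N, T))
print(cached_pairs(N, T))
print(naive_pairs(N, T) / cached_pairs(N, T))
```

For short generations the prefill dominates the cached cost. The gap widens as T grows, because the naive sum adds a quadratic term per step and the cached sum only a linear one.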
Part 1 of 4

VisionLanguageModel.generate_with_kv_cache()

📄 models/vision_language_model.py
@torch.no_grad()
def generate_with_kv_cache(self, input_ids, image, ...):

    # Step 1: Build combined embeddings (same as generate())
    image_embd    = ...
    token_embd    = ...
    combined_embd = ...

    # Step 2: PREFILL — one forward pass, collect KV cache
    model_out, past_key_values = ...

    # Step 3: First generated token
    last_logits = ...
    if not self.decoder.lm_use_tokens:
        last_logits = ...
    probs      = ...
    next_token = ...
    generated_tokens[:, 0] = next_token.squeeze(-1)

    # Step 4: DECODE LOOP
    for i in range(1, max_new_tokens):
        next_embd = ...   # embed only the last generated token (B, 1, D)

        model_out, past_key_values = ...   # 1-token forward, reuse cache

        last_logits = ...
        ...
        next_token = ...
        generated_tokens[:, i] = next_token.squeeze(-1)

        if ...:   # EOS check
            break

    return generated_tokens
@torch.no_grad()
def generate_with_kv_cache(self, input_ids, image, attention_mask=None, max_new_tokens=5):

    image_embd    = self.vision_encoder(image)
    image_embd    = self.MP(image_embd)
    token_embd    = self.decoder.token_embedding(input_ids)
    combined_embd = torch.cat((image_embd, token_embd), dim=1)
    batch_size    = image_embd.size(0)

    # PREFILL: run full prompt once, build entire KV cache
    model_out, past_key_values = self.decoder.forward_kv(combined_embd)

    last_logits = model_out[:, -1, :]
    if not self.decoder.lm_use_tokens:
        last_logits = self.decoder.head(last_logits)
    probs      = torch.softmax(last_logits, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)

    generated_tokens = torch.zeros((batch_size, max_new_tokens),
                                    device=input_ids.device, dtype=input_ids.dtype)
    generated_tokens[:, 0] = next_token.squeeze(-1)

    # DECODE LOOP: 1 token per step, reuse cached K/V
    for i in range(1, max_new_tokens):
        next_embd = self.decoder.token_embedding(next_token)   # (B, 1, D)

        model_out, past_key_values = self.decoder.forward_kv(
            next_embd, past_key_values=past_key_values
        )

        last_logits = model_out[:, -1, :]
        if not self.decoder.lm_use_tokens:
            last_logits = self.decoder.head(last_logits)
        probs      = torch.softmax(last_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        generated_tokens[:, i] = next_token.squeeze(-1)

        if next_token.item() == 2:   # EOS; .item() assumes batch_size == 1
            break

    return generated_tokens
Part 2 of 4

LanguageModel.forward_kv()

📄 models/language_model.py
def forward_kv(self, x, past_key_values=None):
    if self.lm_use_tokens:
        x = self.token_embedding(x)
    B, T, _ = x.size()

    # Fix position IDs: new tokens continue from past_length, not 0
    past_length  = ...   # 0 if no cache, else past_key_values[0][0].size(2)
    position_ids = ...   # arange(past_length, past_length + T)
    cos, sin = self.rotary_embd(position_ids)

    present_key_values = []
    for i, block in enumerate(self.blocks):
        past_kv = ...          # this layer's cache (None on first call)
        x, present_kv = ...    # call block.forward_kv
        ...                    # append present_kv

    x = self.norm(x)
    if self.lm_use_tokens:
        x = self.head(x)

    return ...   # (hidden_states, present_key_values)
def forward_kv(self, x, past_key_values=None):
    if self.lm_use_tokens:
        x = self.token_embedding(x)
    B, T, _ = x.size()

    past_length = 0
    if past_key_values is not None:
        past_length = past_key_values[0][0].size(2)  # layer-0 key length

    position_ids = torch.arange(past_length, past_length + T,
                                 device=x.device).unsqueeze(0).expand(B, -1)
    cos, sin = self.rotary_embd(position_ids)

    present_key_values = []
    for i, block in enumerate(self.blocks):
        past_kv = past_key_values[i] if past_key_values is not None else None
        x, present_kv = block.forward_kv(x, cos, sin, past_kv)
        present_key_values.append(present_kv)

    x = self.norm(x)
    if self.lm_use_tokens:
        x = self.head(x)

    return x, present_key_values
Part 3 of 4

LanguageModelBlock.forward_kv()

📄 models/language_model.py
def forward_kv(self, x, cos, sin, past_key_value=None):
    res = x
    x   = self.norm1(x)
    x, present_key_value = ...   # call self.attn.forward_kv(...)
    x = res + x

    res = x
    x   = self.norm2(x)
    x   = self.mlp(x)
    x   = res + x

    return ...   # (hidden_states, present_key_value)
def forward_kv(self, x, cos, sin, past_key_value=None):
    res = x
    x   = self.norm1(x)
    x, present_key_value = self.attn.forward_kv(x, cos, sin, past_key_value)
    x = res + x

    res = x
    x   = self.norm2(x)
    x   = self.mlp(x)
    x   = res + x

    return x, present_key_value
Part 4 of 4

LanguageModelGroupedQueryAttention.forward_kv()

📄 models/language_model.py
def forward_kv(self, x, cos, sin, past_key_value=None):
    B, T, C = x.size()
    q = self.q_proj(x).view(B, T, self.n_heads,    self.head_dim).transpose(1, 2)
    k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
    v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)

    q, k = apply_rotary_pos_embd(q, k, cos, sin)   # RoPE applied BEFORE concat!

    # Step 1: Prepend cached K/V if this is a decode step
    if past_key_value is not None:
        past_k, past_v = ...
        k = ...
        v = ...

    # Step 2: Save present (full) cache
    present_key_value = ...

    k = k.repeat_interleave(self.n_kv_groups, dim=1)
    v = v.repeat_interleave(self.n_kv_groups, dim=1)

    # Step 3: Attend — causal during prefill only
    is_decode = ...
    y = F.scaled_dot_product_attention(q, k, v, is_causal=...)

    y = y.transpose(1, 2).contiguous().view(B, T, C)
    y = self.out_proj(y)

    # Step 4: Return output + cache
    return ...
def forward_kv(self, x, cos, sin, past_key_value=None):
    B, T, C = x.size()
    q = self.q_proj(x).view(B, T, self.n_heads,    self.head_dim).transpose(1, 2)
    k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
    v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)

    q, k = apply_rotary_pos_embd(q, k, cos, sin)   # RoPE applied BEFORE concat!

    # Prepend cached K/V
    if past_key_value is not None:
        past_k, past_v = past_key_value
        k = torch.cat([past_k, k], dim=2)   # (B, n_kv_heads, past_len+T, head_dim)
        v = torch.cat([past_v, v], dim=2)

    # Store full K,V for the next step
    present_key_value = (k, v)

    k = k.repeat_interleave(self.n_kv_groups, dim=1)
    v = v.repeat_interleave(self.n_kv_groups, dim=1)

    is_decode = past_key_value is not None
    y = F.scaled_dot_product_attention(q, k, v,
                                        is_causal=not is_decode)  # no causal mask during decode

    y = y.transpose(1, 2).contiguous().view(B, T, C)
    y = self.out_proj(y)

    return y, present_key_value
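Why is_causal must be switched off during decode: the PyTorch documentation describes is_causal=True as a lower-triangular mask of shape (L, S) aligned to the top-left corner. The plain-Python sketch below rebuilds that mask shape with a hypothetical helper causal_mask, for a 4-token prefill and for a single decode query over 5 keys.

```python
def causal_mask(n_queries: int, n_keys: int) -> list[list[bool]]:
    """Top-left-aligned lower-triangular mask: query i may see key j iff j <= i."""
    return [[j <= i for j in range(n_keys)] for i in range(n_queries)]

# Prefill: 4 queries over 4 keys gives the usual causal triangle.
prefill = causal_mask(4, 4)

# Decode: 1 new query over 5 keys (4 cached + 1 new).
decode_if_causal = causal_mask(1, 5)   # what a causal mask of this shape would allow
decode_wanted = [[True] * 5]           # the new token should see the full history

for row in prefill:
    print(row)
print(decode_if_causal)
print(decode_wanted)
```

With one query row, the top-left-aligned mask exposes only the first key, so a decode step that kept the causal flag would ignore almost all of the cache. Passing is_causal=False in that case gives the single query access to every cached position, which is already causal because nothing later exists yet.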
Key Insight — Two-Phase Caching
  • PREFILL: run full prompt once → builds K/V cache for every layer.
  • DECODE: feed only 1 new token per step → prepend cached K/V, compute attention over full history at O(N+t) cost.
  • Critical: RoPE is applied before concat, and position_ids must start from past_length — not 0.
  • Causal mask: needed during prefill (T_q > 1), not during decode (T_q = 1 trivially attends all past tokens).
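The position_ids rule can be checked on a two-dimensional toy version of RoPE. The sketch below is a simplification with a single rotation frequency (theta = 0.5 is arbitrary) and made-up vectors. It compares a decode query placed at past_length with one wrongly placed at position 0, both scored against a key that was rotated once and then cached.

```python
import math

def rope(vec, pos, theta=0.5):
    """Rotate a 2-D vector by angle pos * theta (single-frequency RoPE)."""
    x, y = vec
    c, s = math.cos(pos * theta), math.sin(pos * theta)
    return (x * c - y * s, x * s + y * c)

def dot(a, b):
    return a[0] * b[0] + a[1] * b[1]

q_raw, k_raw = (1.0, 0.5), (0.3, -0.8)
past_length = 6                      # tokens already in the cache
key_pos = 2                          # position of one cached key

k_cached = rope(k_raw, key_pos)      # rotated once at prefill, then stored

# Correct: the new query continues at position past_length.
score_ok = dot(rope(q_raw, past_length), k_cached)
# Wrong: position ids restart at 0 during decode.
score_bad = dot(rope(q_raw, 0), k_cached)
# Reference: same relative offset at a different absolute position.
score_ref = dot(rope(q_raw, past_length + 10), rope(k_raw, key_pos + 10))

print(score_ok, score_ref, score_bad)
```

score_ok matches score_ref because the rotated dot product depends only on the relative offset between query and key. score_bad differs, which is the error a decode step would make if its positions restarted at 0 while the cache held keys rotated at their true positions.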