COM-304: Foundation models (FM) track - NanoMaskGIT exercise solutions

2nd rundown session · 25th March 2026

MaskGIT vs. Next-Token Prediction

Unlike autoregressive models that generate one token at a time left-to-right, MaskGIT decodes multiple tokens in parallel at each step, leading to significantly faster inference.
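
As a rough illustration of the speed claim above, the sketch below only counts forward passes, using the default `seq_len=256` and `num_steps=8` that appear later in this document. The helper names are hypothetical, and the count ignores per-pass cost: an autoregressive model with a KV cache does less work per pass than a full bidirectional pass over the sequence.

```python
def forward_passes_autoregressive(seq_len: int) -> int:
    """Next-token prediction: one forward pass per generated token."""
    return seq_len


def forward_passes_maskgit(num_steps: int) -> int:
    """MaskGIT-style decoding: one forward pass per decoding step,
    with several tokens unmasked in each step."""
    return num_steps


seq_len, num_steps = 256, 8
ratio = forward_passes_autoregressive(seq_len) / forward_passes_maskgit(num_steps)
print(f"{seq_len} passes vs. {num_steps} passes: {ratio:.0f}x fewer forward passes")
```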

Figure: MaskGIT vs. next-token prediction sampling

Part 1 - Image Generation with nanoMaskGIT

Image

2.1.1  Initialize nanoMaskGIT

5 points

This task requires initializing the main building blocks of nanoMaskGIT: token embedding layer, transformer trunk (encoder), positional embeddings (PE), learnable mask token, output norm, and the embedding-to-logit layer.

File: nano4M/nanofm/models/maskgit.py - MaskGIT.__init__

def __init__(
    self,
    seq_read_key: str = 'input_ids',
    dim: int = 512,
    depth: int = 8,
    head_dim: int = 64,
    mlp_ratio: float = 4.0,
    use_bias: bool = False,
    vocab_size: int = 10000,
    seq_len: int = 256,
    init_std: float = 0.02,
):
    super().__init__()
    self.seq_read_key = seq_read_key
    self.init_std = init_std

    self.input_embedding = nn.Embedding(vocab_size, dim)
    self.positional_embedding = nn.Parameter(torch.randn(seq_len, dim))
    self.mask_token = nn.Parameter(torch.randn(dim))

    self.trunk = TransformerTrunk(
        dim=dim, depth=depth, head_dim=head_dim,
        mlp_ratio=mlp_ratio, use_bias=use_bias
    )

    self.out_norm = LayerNorm(dim, bias=use_bias)
    self.to_logits = nn.Linear(dim, vocab_size, bias=False)

    self.initialize_weights()

2.1.2  Implement the forward function and loss

10 points

File: nano4M/nanofm/models/maskgit.py - forward_model, compute_ce_loss

Figure: forward_model diagram

def forward_model(self, x: torch.LongTensor, mask: torch.BoolTensor) -> torch.Tensor:
    B, L = x.size()
    x = self.input_embedding(x)       # (B, L, D)
    x[mask] = self.mask_token          # replace masked positions with learned mask token
    x = x + self.positional_embedding  # add learned positional embeddings
    x = self.trunk(x)                  # full bi-directional self-attention
    logits = self.to_logits(self.out_norm(x))
    return logits


def compute_ce_loss(
    self,
    logits: torch.Tensor,
    target_seq: torch.LongTensor,
    ignore_index: int = -100
) -> torch.Tensor:
    B, L, vocab_size = logits.shape
    loss = F.cross_entropy(
        logits.reshape(-1, vocab_size),
        target_seq.reshape(-1),
        ignore_index=ignore_index
    )
    return loss
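
To make the role of `ignore_index` concrete, here is a minimal framework-free sketch of a mean cross-entropy that skips ignored targets. It assumes (this is not shown in the solution above) that the training code sets the targets of unmasked positions to `-100`, so that only masked positions contribute to the loss. The function name `cross_entropy_sketch` is hypothetical.

```python
import math

IGNORE_INDEX = -100


def cross_entropy_sketch(logits, targets, ignore_index=IGNORE_INDEX):
    """Mean negative log-likelihood over positions whose target is not ignore_index.

    logits:  list of per-position logit rows, shape (N, vocab_size)
    targets: list of N class indices, or ignore_index to skip a position
    """
    total, count = 0.0, 0
    for row, target in zip(logits, targets):
        if target == ignore_index:
            continue  # ignored positions contribute neither loss nor to the mean
        m = max(row)  # subtract the max for a numerically stable log-sum-exp
        log_z = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_z - row[target]
        count += 1
    return total / count if count else float("nan")


logits = [[2.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 5.0, 0.0]]
targets = [0, IGNORE_INDEX, 1]  # assumption: the middle position was not masked
loss = cross_entropy_sketch(logits, targets)
print(loss)
```

Changing the logits of the ignored middle position leaves `loss` unchanged, which is the property the masked-token objective relies on.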

2.1.3  Implement random masking

15 points

File: nano4M/nanofm/models/maskgit.py - generate_random_mask

def generate_random_mask(self, seq: torch.Tensor) -> torch.BoolTensor:
    B, L = seq.size()
    m = torch.randint(1, L + 1, (B,), device=seq.device)   # (B,)
    random_scores = torch.rand(B, L, device=seq.device)      # (B, L)
    # argsort twice: first gives ordering, second gives rank for each position
    ranks = random_scores.argsort(dim=1).argsort(dim=1)      # (B, L)
    return ranks < m.unsqueeze(1)                             # (B, L) bool

Figure: random masking diagram
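
The double-argsort trick in the function above can be traced by hand on a single sequence. The sketch below redoes it in plain Python for one row; `random_mask_sketch` is a hypothetical name, and the standard-library `random` module stands in for the torch random calls.

```python
import random


def random_mask_sketch(seq_len: int, rng: random.Random):
    """Mask exactly m positions chosen uniformly at random, with m drawn from [1, seq_len]."""
    m = rng.randint(1, seq_len)  # inclusive on both ends, like torch.randint(1, L + 1)
    scores = [rng.random() for _ in range(seq_len)]

    # First argsort: positions ordered by increasing score.
    order = sorted(range(seq_len), key=scores.__getitem__)

    # Second argsort: rank of each position within that ordering.
    ranks = [0] * seq_len
    for rank, pos in enumerate(order):
        ranks[pos] = rank

    # The m positions with the smallest scores (ranks 0 .. m-1) are masked.
    mask = [r < m for r in ranks]
    return mask, m


mask, m = random_mask_sketch(8, random.Random(0))
print(m, mask)
```

Because the ranks are a permutation of `0 .. seq_len - 1`, the comparison `rank < m` selects exactly `m` positions, and the random scores make every subset of that size equally likely.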

2.1.4  MaskGIT schedule and generation function

20 points

File: nano4M/nanofm/models/maskgit.py - get_maskgit_schedule, generate

def get_maskgit_schedule(self, mask: torch.BoolTensor, num_steps: int = 8) -> List[int]:
    total_tokens = int(mask.sum().item())

    assert total_tokens > 0,           "No tokens to unmask."
    assert num_steps > 0,              "num_steps must be > 0."
    assert num_steps <= total_tokens,  "num_steps must be <= total masked tokens."

    tokens_per_step = total_tokens // num_steps
    remainder       = total_tokens  % num_steps
    schedule        = [tokens_per_step] * num_steps
    schedule[-1]   += remainder

    assert sum(schedule) == total_tokens
    return schedule

Figure: MaskGIT schedule and random masking diagram

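
A short worked example of the schedule arithmetic above, as a standalone sketch. It takes the number of masked tokens directly instead of a mask tensor, and the name `uniform_schedule` is hypothetical.

```python
def uniform_schedule(total_tokens: int, num_steps: int) -> list:
    """Split total_tokens evenly over num_steps; the last step absorbs the remainder."""
    assert 0 < num_steps <= total_tokens
    tokens_per_step, remainder = divmod(total_tokens, num_steps)
    schedule = [tokens_per_step] * num_steps
    schedule[-1] += remainder
    return schedule


print(uniform_schedule(256, 8))  # [32, 32, 32, 32, 32, 32, 32, 32]
print(uniform_schedule(10, 4))   # [2, 2, 2, 4]
```
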
@torch.no_grad()
def generate(
        self,
        seq:          torch.LongTensor,
        mask:         torch.BoolTensor,
        num_steps:    int   = 8,
        temp:         float = 1.0,
        top_p:        float = 0.0,
        top_k:        float = 0.0,
        return_history: bool = False,
) -> torch.Tensor:
    L = seq.size(0)
    schedule = self.get_maskgit_schedule(mask, num_steps)

    seq  = seq.clone().unsqueeze(0)   # (1, L); clone so the caller's tensor is not modified in place
    mask = mask.clone().unsqueeze(0)  # (1, L)

    if return_history:
        seq_history  = [seq.clone().cpu()]
        mask_history = [mask.clone().cpu()]

    for step, k in enumerate(schedule):
        logits = self.forward_model(seq, mask)          # (1, L, vocab_size)

        masked_indices = torch.where(mask[0])[0]        # (M,) positions still masked
        masked_logits  = logits[0, masked_indices, :]   # (M, vocab_size)

        confidence = masked_logits.max(dim=-1)[0]       # (M,)

        _, topk_idx        = torch.topk(confidence, k)
        selected_positions = masked_indices[topk_idx]   # (k,)
        selected_logits    = logits[0, selected_positions, :]  # (k, vocab_size)

        samples, _ = sample_tokens(selected_logits, temperature=temp, top_k=top_k, top_p=top_p)

        seq[0, selected_positions]  = samples
        mask[0, selected_positions] = False

        if return_history:
            seq_history.append(seq.clone().cpu())
            mask_history.append(mask.clone().cpu())

    if return_history:
        return torch.cat(seq_history, dim=0), torch.cat(mask_history, dim=0)
    return seq
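
The control flow of the generation loop can be followed on a toy problem without PyTorch. In the sketch below, `toy_logits` is a hypothetical stand-in for `forward_model`, masked positions are represented by `None` instead of a mask tensor, and tokens are picked greedily by argmax rather than with `sample_tokens`. Only the confidence-based selection of which positions to unmask mirrors the loop above.

```python
from typing import List, Optional


def toy_logits(seq: List[Optional[int]]) -> List[List[float]]:
    """Hypothetical model: position i prefers token i % 3, and is more confident
    (larger peak logit) when a neighbouring position is already decoded."""
    n = len(seq)
    rows = []
    for i in range(n):
        has_neighbor = any(0 <= j < n and seq[j] is not None for j in (i - 1, i + 1))
        row = [0.0, 0.0, 0.0]
        row[i % 3] = 4.0 if has_neighbor else 1.0
        rows.append(row)
    return rows


def decode_sketch(seq, schedule, model=toy_logits):
    seq = list(seq)  # work on a copy
    history = [list(seq)]
    for k in schedule:
        logits = model(seq)
        masked = [i for i, tok in enumerate(seq) if tok is None]
        # Confidence of a position = its maximum logit; most confident first.
        masked.sort(key=lambda i: max(logits[i]), reverse=True)
        for i in masked[:k]:
            # Greedy argmax here; the solution samples with temperature / top-k / top-p.
            seq[i] = max(range(len(logits[i])), key=logits[i].__getitem__)
        history.append(list(seq))
    return seq, history


start = [0, None, None, None, None, None]  # position 0 given, 5 positions masked
final, history = decode_sketch(start, schedule=[1, 1, 1, 2])
for step, state in enumerate(history):
    print(step, state)
```

With this toy model the positions are filled in from left to right, because the position next to already-decoded content is always the most confident one.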

2.3  Show your loss curves

10 points

Figure: nanoMaskGIT MNIST loss curves

2.4  Evaluating the model

10 points

Figure: Evaluation results

2.5  Open-ended questions

-

The model first decodes the boundary tokens with highest confidence and gradually carves out the digit tokens towards the end.

Figure: Intermediate decoding steps

Effect of the number of decoding steps k: both too few steps (e.g. k=2) and too many hurt quality. With a low k, many tokens are decoded in parallel at each step, so the model cannot leverage enough context. A sweet spot of k=8 to k=16 gives the best results.

Figure: Effect of decoding steps

Part 2 - Text Generation on TinyStories

Text

3.3  Training model on TinyStories (loss curves)

-

Figure: nanoMaskGIT TinyStories loss curves

3.4  Evaluating the model - Generating text using our trained model

-

Once upon a time, there was an old newspaper.

--------------------------------------------------

Once upon a time, there was a large boy named Max. Max loved to play on the big with his toy dog, Max. One day, they went to the park to play on

Max and Max saw on the swings and a pretty slide. Max. a big tree was an orange tree. Max was to climb with Max fun, but he said to himself and kept, " Max, do have this orange tree!" Max
MaxMax looked up and saw a big pile of branches. Max ran up the tree to get. Max jumped and The tree was very but he didn't reach the tree. Finally, Max saw the orange. Max was sad happy again but then was happy little longer

Max and Max climbed the tree too. They jumped high watched on orange. in the big could go up Max loved running. catching orange orange was orange. Max was "Hi, Max! Both, here, slide until Max!" Then, jumped and got, and Max! Max was to the top. Max stood back behind forth, Max saw the orange tree, his paw. The orange tree was a big dog and Max. Max cheered and jumped on Max. dog said said to Max, and Max all went to the park. They were happy that

--------------------------------------------------

Once upon a time there there was an ordinary like story..

--------------------------------------------------

S upon a time, there was a little girl named Sara. She loved something special she wanted the day.!

--------------------------------------------------

One upon a time there was a lawyer.

3.5  Open-ended questions

-

Intermediate generation steps (positions that are still masked are shown as "!")

--- Step 0 ---
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

--- Step 1 ---
Once upon!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

--- Step 2 ---
Once upon a time!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

--- Step 3 ---
Once upon a time, there!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

--- Step 4 ---
Once upon a time, there was a!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

--- Step 5 ---
Once upon a time, there was a little boy!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

--- Step 6 ---
Once upon a time, there was a little boy named Tim!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

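The model discovers on its own that generating text from left to right is the most confident strategy: at each step, the masked positions it is most confident about are the ones directly following the text decoded so far. Decoding one token per step therefore leads to an autoregressive left-to-right generation pattern in nanoMaskGIT.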
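
One way to check this observation quantitatively, sketched under assumptions: given a mask history as nested Python lists of booleans (for example the `mask_history` returned by `generate` with `return_history=True`, converted with `.tolist()`), recover the order in which positions were unmasked and test whether it is left to right. Both helper names are hypothetical. Positions unmasked within the same step are listed by index, so only the ordering across steps is tested.

```python
def unmask_order(mask_history):
    """Positions in the order they were unmasked (True = still masked).
    Within a single step, positions are listed by increasing index."""
    order = []
    for prev, curr in zip(mask_history, mask_history[1:]):
        order.extend(i for i, (p, c) in enumerate(zip(prev, curr)) if p and not c)
    return order


def is_left_to_right(mask_history) -> bool:
    """True if no position is unmasked in a later step than a position to its right."""
    order = unmask_order(mask_history)
    return order == sorted(order)


left_to_right = [
    [True, True, True, True],
    [False, True, True, True],
    [False, False, False, True],
    [False, False, False, False],
]
middle_first = [
    [True, True, True],
    [True, False, True],
    [False, False, False],
]
print(is_left_to_right(left_to_right), is_left_to_right(middle_first))
```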

Effect of number of decoding steps (k)

--- Generation with k = 1 ---
Emily upon a time young was a a little to who in She One a in look's on family very. would wanted. The day, One all to was would.
mommy his and to came a too She. had and, was They's. ". got a and the had He day, and was She They The her She, Tim but.'s started she you
in but. that I .!" He his a. and and girl ran for it." very he the. said girl, put to. they . it to the so to " She that
. to little was . a up mom.I.
that Lily and..
the and and it with had so. . I,'s
the so of happy their and was
her too could and mom and " for the . . Tom you in She " Mom
and play. 's.
,

--- Generation with k = 4 ---
Once upon a time, and called He was Jack loved with study. things to day. do day He with
his study he
in that learn One He to
a he. and

excited the he.. very for he thought He lots Jack
at said He book,. excited little learned
and him
words the tried..!
study,'s thought something to asked,. to He he learning and what
lots the. did of and study! he and He. started looked.. when words the words and to It The studied ". and! his him he the.
...

--- Generation with k = 256 ---
Once upon a time, there was a little boy named Timmy who liked to play outside. One day, while playing, he found a rope and decided to swing it. It was a perfect place to slide down. He swung high and low until he was having so much fun! When he woke up, he decided to go back and forth again. He couldn't wait for next day. The end.

For text generation, using more decoding steps yields better results. With k=256 (one token decoded per step) the model produces coherent text, while very low k values (e.g. k=1 or k=4) produce incoherent output: many tokens are decoded in the same step, so the model cannot leverage sufficient context for each of them.