Vae Vqvae Vqgan Tutorial En

Source: docs/tutorials/vae_vqvae_vqgan_tutorial_en.md SHA256: e6f6353f05d0 Rendered: 2026-05-19 19:02 UTC

§0 TL;DR Cheat Sheet

VAE / VQ-VAE / VQ-GAN / FSQ in 8 sentences

one page covering the interview essentials (see §2–§9 for derivations).

  1. Continuous VAE objective: maximize ELBO, $\log p(x) \geq \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] - D_\text{KL}(q_\phi(z|x)\,\|\,p(z))$; reparameterization $z = \mu + \sigma \odot \epsilon$ lets gradients flow through stochastic sampling.
  2. KL closed form (must-know): $D_\text{KL}(\mathcal{N}(\mu,\sigma^2 I)\,\|\,\mathcal{N}(0,I)) = \tfrac{1}{2}\sum_i (\mu_i^2 + \sigma_i^2 - \log \sigma_i^2 - 1)$.
  3. Posterior collapse: KL → 0 → decoder ignores $z$; mitigations: KL annealing, free bits, $\beta$ schedule, autoregressive prior.
  4. VQ-VAE: maps encoder output $z_e(x)$ to the nearest codebook vector $e_k$, loss = recon + $\|\text{sg}[z_e] - e\|^2$ (codebook) $+ \beta \|z_e - \text{sg}[e]\|^2$ (commitment).
  5. Straight-Through Estimator (STE): argmin / quantize is non-differentiable, use quantized value in forward, pass-through gradient $\partial \mathcal{L}/\partial z_q \to \partial \mathcal{L}/\partial z_e$ in backward.
  6. VQ-GAN: VQ-VAE + perceptual (LPIPS) + adversarial (PatchGAN) + post-trained Transformer prior; lays the foundation for LDM / Parti / Muse and other discrete-token models.
  7. FSQ (2024): per-dimension scalar quantization to $\{-L,\ldots,L\}$, implicit codebook size $\prod_i L_i$ (e.g. $L=8, d=6 \Rightarrow 8^6 = 262{144}$), no need for STE, no codebook collapse, rounding uses STE only, loss only has reconstruction.
  8. Ecosystem comparison: continuous latent (VAE / KL) suits LDM-style diffusion; discrete tokens (VQ-VAE / VQ-GAN / FSQ / LFQ) suit AR / MaskGIT Transformer priors, the core component of Parti / Muse / Cosmos.

§1 Intuition: Why Latent Variable Models

The core challenge of generative modeling: directly modeling $p(x)$ is hard, but if we introduce low-dim latent $z$:

$$p(x) = \int p(x|z)\, p(z)\, dz$$

we can decompose "complex image distribution" into "simple prior $p(z)$ (e.g. $\mathcal{N}(0, I)$)" plus "easy-to-learn conditional $p(x|z)$." Two paths:

Training vs inference asymmetry

VAE/VQ-VAE trains the full encoder + decoder (rate-distortion view: "compress-reconstruct"); at inference there are two cases depending on application:

§2 VAE: Core Formulas and Derivations

2.1 ELBO derivation (must-know, derive line by line)

Model family $p_\theta(x, z) = p_\theta(x|z)\, p(z)$, prior $p(z) = \mathcal{N}(0, I)$, likelihood $p_\theta(x|z)$ given by decoder. Marginal likelihood:

$$\log p_\theta(x) = \log \int p_\theta(x|z) p(z)\, dz$$

For any distribution $q_\phi(z|x)$ (encoder / variational posterior, $q_\phi(z|x) = \mathcal{N}(\mu_\phi(x), \mathrm{diag}(\sigma_\phi^2(x)))$), by Jensen's inequality / direct substitution:

$$ \begin{aligned} \log p_\theta(x) &= \log \int q_\phi(z|x) \frac{p_\theta(x|z) p(z)}{q_\phi(z|x)} dz \\ &\geq \mathbb{E}_{q_\phi(z|x)}\!\left[\log \frac{p_\theta(x|z) p(z)}{q_\phi(z|x)}\right] \quad \text{(Jensen)} \\ &= \underbrace{\mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)]}_{\text{reconstruction (negative)}} - \underbrace{D_\text{KL}(q_\phi(z|x)\,\|\,p(z))}_{\text{regularization}} \end{aligned} $$

So ELBO:

$$\boxed{\;\mathcal{L}_\text{ELBO}(\theta, \phi; x) = \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] - D_\text{KL}\!\big(q_\phi(z|x)\,\|\,p(z)\big)\;}$$

Tightness cost: the gap between ELBO and the true log-likelihood is $D_\text{KL}(q_\phi(z|x)\,\|\,p_\theta(z|x))$. The closer $q_\phi$ approximates the true posterior $p_\theta(z|x)$, the smaller the gap.

2.2 KL term closed-form derivation (L3 must derive)

Let $q_\phi(z|x) = \mathcal{N}(\mu, \mathrm{diag}(\sigma^2))$ (diagonal covariance, per-dim $\sigma_i^2$), $p(z) = \mathcal{N}(0, I)$.

For each dimension $i$ independently:

$$ \begin{aligned} D_\text{KL}(\mathcal{N}(\mu_i, \sigma_i^2) \,\|\, \mathcal{N}(0, 1)) &= \int \mathcal{N}(z; \mu_i, \sigma_i^2) \log \frac{\mathcal{N}(z; \mu_i, \sigma_i^2)}{\mathcal{N}(z; 0, 1)} dz \end{aligned} $$

Expanding the log of two Gaussian densities:

$$ \log \frac{\mathcal{N}(z; \mu_i, \sigma_i^2)}{\mathcal{N}(z; 0, 1)} = -\tfrac{1}{2}\log \sigma_i^2 - \tfrac{(z-\mu_i)^2}{2\sigma_i^2} + \tfrac{z^2}{2} $$

Taking expectation (using $\mathbb{E}_q[z] = \mu_i$, $\mathbb{E}_q[z^2] = \mu_i^2 + \sigma_i^2$, $\mathbb{E}_q[(z-\mu_i)^2] = \sigma_i^2$):

$$ \begin{aligned} D_\text{KL} &= -\tfrac{1}{2}\log \sigma_i^2 - \tfrac{1}{2} + \tfrac{1}{2}(\mu_i^2 + \sigma_i^2) \\ &= \tfrac{1}{2}\big(\mu_i^2 + \sigma_i^2 - \log \sigma_i^2 - 1\big) \end{aligned} $$

Sum over all dimensions:

$$\boxed{\;D_\text{KL}\big(\mathcal{N}(\mu, \mathrm{diag}(\sigma^2)) \,\|\, \mathcal{N}(0, I)\big) = \tfrac{1}{2}\sum_{i=1}^{d}\big(\mu_i^2 + \sigma_i^2 - \log \sigma_i^2 - 1\big)\;}$$

Numerical stability

In implementation have the encoder output $\log \sigma^2$ (log-variance) rather than $\sigma$, to avoid overflow when taking exp on $\sigma$. In code: kl = 0.5 * (mu**2 + logvar.exp() - logvar - 1).sum().

2.3 Reparameterization Trick (must-know)

The $\mathbb{E}_{q_\phi(z|x)}[\cdot]$ in ELBO is estimated by Monte Carlo: sample one $z \sim q_\phi(z|x)$, compute $\log p_\theta(x|z)$.

Problem: directly sampling $z = \text{sample}(\mathcal{N}(\mu, \sigma^2))$ is non-differentiable, gradients cannot flow back to $\phi$.

Solution: move the randomness into an independent noise:

$$\boxed{\;z = \mu_\phi(x) + \sigma_\phi(x) \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)\;}$$

Now $z$ is a deterministic function of $\phi$ (conditioned on $\epsilon$), and $\nabla_\phi \mathcal{L}$ can backpropagate normally. This is one of the core contributions of Kingma & Welling (ICLR 2014).

Interview bonus: reparameterization is not just for Gaussians

Concrete / Gumbel-softmax (§7) plays a similar trick for discrete variables: replace argmax with softmax + Gumbel noise, approximating discrete in forward, using softmax gradient in backward.

2.4 VAE training loss (practical formulation)

Negative ELBO (to minimize):

$$\mathcal{L}_\text{VAE}(x) = \underbrace{\|x - \hat{x}\|^2}_{\text{recon (Gaussian likelihood up to const)}} + \underbrace{D_\text{KL}(q_\phi(z|x)\,\|\,p(z))}_{\text{closed form}}$$

For Bernoulli / Categorical likelihoods (e.g. binary MNIST), replace the recon term with BCE / CE.

§3 Complete VAE Implementation (PyTorch)

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class VAE(nn.Module):
    """ Classic VAE: Gaussian encoder + Gaussian/Bernoulli decoder
        This implementation uses MNIST (28×28) as example, latent dim=20
        For production swap MLP for ResNet / U-Net encoder/decoder, latent can be a spatial map """

    def __init__(self, x_dim: int = 784, h_dim: int = 400, z_dim: int = 20,
                 likelihood: str = "bernoulli"):
        super().__init__()
        self.x_dim, self.z_dim = x_dim, z_dim
        self.likelihood = likelihood

        # Encoder: x -> (μ, logσ²)
        self.enc = nn.Sequential(
            nn.Linear(x_dim, h_dim), nn.ReLU(),
            nn.Linear(h_dim, h_dim), nn.ReLU(),
        )
        self.fc_mu = nn.Linear(h_dim, z_dim)
        self.fc_logvar = nn.Linear(h_dim, z_dim)

        # Decoder: z -> x̂
        self.dec = nn.Sequential(
            nn.Linear(z_dim, h_dim), nn.ReLU(),
            nn.Linear(h_dim, h_dim), nn.ReLU(),
            nn.Linear(h_dim, x_dim),
        )

    def encode(self, x: torch.Tensor):
        h = self.enc(x.view(x.size(0), -1))
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor):
        # z = μ + σ ⊙ ε,  σ = exp(0.5 · logvar)
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mu + std * eps
        else:
            # At inference use the posterior mean (deterministic)
            return mu

    def decode(self, z: torch.Tensor):
        logits = self.dec(z)
        if self.likelihood == "bernoulli":
            return torch.sigmoid(logits), logits
        return logits, logits  # Gaussian likelihood: treated as mean prediction

    def forward(self, x: torch.Tensor):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        x_hat, logits = self.decode(z)
        return x_hat, logits, mu, logvar


def vae_loss(x: torch.Tensor, logits: torch.Tensor, mu: torch.Tensor,
             logvar: torch.Tensor, likelihood: str = "bernoulli",
             beta: float = 1.0, free_bits: float = 0.0):
    """ Returns:
            (loss, recon, kl)
        beta:        β in β-VAE (default 1 = standard VAE)
        free_bits:   per-dim KL lower bound (nats). When > 0 enables free bits"""
    B = x.size(0)
    x_flat = x.view(B, -1)

    # 1) Reconstruction term: -E_q[log p(x|z)]
    if likelihood == "bernoulli":
        # BCE-with-logits is more numerically stable, equivalent to -log Bernoulli likelihood
        recon = F.binary_cross_entropy_with_logits(
            logits, x_flat, reduction="sum") / B
    elif likelihood == "gaussian":
        # Assuming σ² = 1 (constant), MSE differs from negative log-Gaussian by a constant
        recon = 0.5 * F.mse_loss(logits, x_flat, reduction="sum") / B
    else:
        raise ValueError(likelihood)

    # 2) KL term: D_KL(N(μ, σ²) || N(0, I))  closed form
    kl_per_dim = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1)   # [B, z_dim]

    if free_bits > 0:
        # Free bits: per-dim KL lower bound = free_bits (mitigates posterior collapse)
        kl_per_dim = torch.clamp(kl_per_dim, min=free_bits)

    kl = kl_per_dim.sum(dim=-1).mean()                            # scalar

    loss = recon + beta * kl
    return loss, recon, kl
Common bug list

Pitfalls when writing VAE.

3.1 Training loop + KL annealing

def train_vae(model, dataloader, total_steps=50_000, lr=1e-3, device="cuda",
              beta_max=1.0, anneal_steps=10_000, free_bits=0.0):
    """ KL annealing: β linearly grows from 0 to beta_max, preventing posterior from collapsing in early training """
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    model.to(device).train()

    step = 0
    while step < total_steps:
        for x, _ in dataloader:
            x = x.to(device)
            beta = min(beta_max, beta_max * step / max(anneal_steps, 1))

            x_hat, logits, mu, logvar = model(x)
            loss, recon, kl = vae_loss(x, logits, mu, logvar,
                                       beta=beta, free_bits=free_bits)
            opt.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            opt.step()
            step += 1
            if step >= total_steps:
                break

§4 VAE Variants: $\beta$-VAE / IWAE / NVAE / VAE-GAN

4.1 $\beta$-VAE (Higgins et al., ICLR 2017)

Weight the KL of ELBO by $\beta$:

$$\mathcal{L}_{\beta\text{-VAE}}(\theta, \phi; x) = \mathbb{E}_{q_\phi}[\log p_\theta(x|z)] - \beta \cdot D_\text{KL}(q_\phi(z|x)\,\|\,p(z))$$

Disentanglement controversy

Locatello et al. (ICML 2019, best paper) proved: pure unsupervised disentanglement is impossible without inductive bias / supervision. $\beta$-VAE's "emergent disentanglement" largely depends on architecture + dataset bias, not $\beta$ itself.

4.2 IWAE: Importance Weighted Autoencoder (Burda et al., ICLR 2016)

ELBO is a first-order bound on $\log p(x)$. With $K$ importance samples get a tighter bound:

$$\mathcal{L}_K^\text{IWAE}(x) = \mathbb{E}_{z_1,\ldots,z_K \sim q_\phi}\!\left[\log \frac{1}{K}\sum_{k=1}^K \frac{p_\theta(x, z_k)}{q_\phi(z_k|x)}\right]$$

Properties:

Tradeoff

IWAE makes the likelihood bound tighter, but the encoder's learned posterior no longer pursues "approximating the true posterior," instead aligning with the geometry of importance weighting. Not necessarily better for downstream representation learning.

4.3 NVAE: Hierarchical VAE (Vahdat & Kautz, NeurIPS 2020)

Multi-layer latent $z = (z_1, z_2, \ldots, z_L)$, each layer depends on the previous:

$$p(z) = p(z_1)\prod_{l=2}^L p(z_l | z_{

Engineering essentials:

NVAE's current role: one of the strongest continuous VAE priors before LDM, but surpassed in sample quality by diffusion series.

4.4 VAE-GAN (Larsen et al., ICML 2016)

VAE's reconstruction loss (pixel-MSE / BCE) is insensitive to high-frequency details → blurry generations. VAE-GAN replaces / supplements MSE with discriminator feature matching:

$$\mathcal{L}_\text{recon}^\text{VAE-GAN} = \|D_l(x) - D_l(\hat{x})\|^2$$

where $D_l$ is the discriminator's intermediate layer features. Combined with adversarial loss, reconstruction is perceptually sharper.

This idea culminates in VQ-GAN (§6): VQ-VAE framework + perceptual + adversarial + high-bitrate codebook + Transformer prior.

§5 Posterior Collapse (must-know)

5.1 Phenomenon

During training $D_\text{KL}(q_\phi(z|x)\,\|\,p(z)) \to 0$, i.e. $q_\phi(z|x) \approx p(z)$, independent of $x$. Consequence: decoder completely ignores $z$, VAE degenerates to an unconditional generative model.

5.2 Causes (intuitive analysis)

5.3 Mitigation methods (interview must list)

MethodHowSource
KL annealing$\beta(t) = \min(1, t / T)$ linear from 0 to 1Bowman et al. (2016)
Free bitsPer-dim KL lower bound $\lambda$ nats: $\max(D_\text{KL}^{(i)}, \lambda)$Kingma et al. (2016)
$\beta$ < 1Directly reduce KL weight$\beta$-VAE reverse usage
Weakened decoderWhen using strong AR decoders like PixelCNN, manually truncate context / add dropoutChen et al. (2017)
Auxiliary taskAdd word dropout, bag-of-words auxiliary lossBowman et al. (2016)
VAE-IAF / NF priorUse more complex prior or normalizing flow posteriorKingma et al. (2016)
Skip / lateral connectionsForce latent to participate in decoder (e.g. VLAE)Zhao et al. (2017)
VQ-VAEDiscrete latent + codebook commitment, structurally avoids collapsevan den Oord (2017)
Free bits formula

Implementation is extremely simple: kl_per_dim = max(kl_per_dim, λ). Intuition: guarantee a baseline of λ bits of information per latent dim, the optimizer cannot push it below 0. Common values $\lambda \approx 0.5$-$2$ nats / dim.

§6 VQ-VAE: Discrete Latent + Codebook + STE

6.1 Structure (van den Oord, Vinyals, Kavukcuoglu, NeurIPS 2017)


x ──Encoder──> z_e(x) ∈ R^{H'×W'×D}    # continuous spatial map
                       │
                       │   For each spatial position (h,w), find nearest codebook vector
                       │   k_{hw} = argmin_k ‖z_e(x)_{hw} - e_k‖²
                       ↓
            z_q(x)_{hw} = e_{k_{hw}}    # quantized spatial map (discrete code)
                       │
                       │
                       ↓
                   Decoder ──> x̂

Codebook $\mathcal{E} = \{e_1, \ldots, e_K\} \subset \mathbb{R}^D$, learnable. $z_e(x)$ and $z_q(x)$ have the same shape, but each spatial position of $z_q$ is a copy of some codebook vector (discrete index $k_{hw}$).

6.2 Loss derivation

VQ-VAE does not learn a stochastic posterior $q(z|x)$ (unlike VAE), but uses deterministic nearest neighbor for the "quantization" $z_e \to z_q$. The loss has three parts:

$$\boxed{\;\mathcal{L}_\text{VQ-VAE} = \underbrace{\|x - \hat{x}\|^2}_{\text{reconstruction}} + \underbrace{\|\text{sg}[z_e(x)] - e\|^2}_{\text{codebook}} + \beta \underbrace{\|z_e(x) - \text{sg}[e]\|^2}_{\text{commitment}}\;}$$

Meaning of each term:

sg[·] = stop_gradient (PyTorch's .detach()), defined: forward $\text{sg}[u] = u$, backward $\nabla \text{sg}[u] = 0$.

Why both codebook and commitment need sg

If neither has sg, $\|z_e - e\|^2$ pulls both sides simultaneously, with coupled directions easily oscillating. Split this into two sg versions: codebook term updates $e$ exclusively, commitment exclusively updates encoder, learning rate / speed can be decoupled. This is standard practice in the vector quantization literature (also called "alternating minimization").

6.3 Straight-Through Estimator (STE) derivation

Problem: $z_q = e_{\arg\min_k \|z_e - e_k\|^2}$'s argmin is non-differentiable (outputs discrete index).

STE solution:

PyTorch implementation trick (classic three lines):

z_q = z_e + (z_q_quantized - z_e).detach()

Forward: z_q = z_e + (z_q_q - z_e) = z_q_q ✓ (quantized value) Backward: (z_q_q - z_e).detach() does not participate in gradient, so dz_q/dz_e = 1, gradient flows straight through to encoder ✓

STE's equivalent surrogate

STE is equivalent to replacing the non-differentiable $z_q = \text{quantize}(z_e)$ with a differentiable surrogate $z_q^\text{surrogate} = z_e$ for backprop — i.e. "assume quantization is the identity map." This is a biased estimate (biased gradient estimator), but works well in practice; theoretical analysis in Bengio et al. (2013) "Estimating or Propagating Gradients Through Stochastic Neurons."

6.4 Complete VQ-VAE implementation

class VectorQuantizer(nn.Module):
    """ Codebook + nearest-neighbor quantization + STE
        embedding_dim = D, num_embeddings = K
        commitment_cost β usually = 0.25 """

    def __init__(self, num_embeddings: int, embedding_dim: int,
                 commitment_cost: float = 0.25):
        super().__init__()
        self.K, self.D = num_embeddings, embedding_dim
        self.beta = commitment_cost
        # Codebook small uniform init
        self.codebook = nn.Embedding(self.K, self.D)
        self.codebook.weight.data.uniform_(-1.0 / self.K, 1.0 / self.K)

    def forward(self, z_e: torch.Tensor):
        """ z_e: [B, D, H, W]  ->  z_q: [B, D, H, W], indices: [B, H, W],
            loss = codebook_loss + β·commitment_loss """
        # 1) Reshape: [B, D, H, W] -> [BHW, D]
        B, D, H, W = z_e.shape
        z_e_flat = z_e.permute(0, 2, 3, 1).contiguous().view(-1, D)   # [BHW, D]

        # 2) Compute L2 distance  ‖z_e - e_k‖² = ‖z_e‖² + ‖e_k‖² - 2 z_e · e_k
        e = self.codebook.weight                                       # [K, D]
        dist = (z_e_flat.pow(2).sum(1, keepdim=True)
                + e.pow(2).sum(1)
                - 2 * z_e_flat @ e.t())                                # [BHW, K]

        # 3) Nearest neighbor index
        indices = dist.argmin(dim=1)                                   # [BHW]
        z_q_flat = self.codebook(indices)                              # [BHW, D]

        # 4) Loss (mind the sg)
        codebook_loss = F.mse_loss(z_q_flat, z_e_flat.detach())
        commitment_loss = F.mse_loss(z_e_flat, z_q_flat.detach())
        vq_loss = codebook_loss + self.beta * commitment_loss

        # 5) STE: forward z_q, backward dz_q/dz_e = I
        z_q_flat = z_e_flat + (z_q_flat - z_e_flat).detach()

        # 6) Reshape back to [B, D, H, W]
        z_q = z_q_flat.view(B, H, W, D).permute(0, 3, 1, 2).contiguous()
        indices = indices.view(B, H, W)

        # 7) (optional) perplexity: measure of codebook usage
        one_hot = F.one_hot(indices.view(-1), self.K).float()
        avg_probs = one_hot.mean(dim=0)
        perplexity = torch.exp(-(avg_probs * torch.log(avg_probs + 1e-10)).sum())

        return z_q, indices, vq_loss, perplexity


class VQVAE(nn.Module):
    def __init__(self, channels=3, hidden=128, num_embeddings=512, embedding_dim=64,
                 commitment_cost=0.25):
        super().__init__()
        # Encoder: 64×64×3 -> 16×16×D  (downsample 4×)
        self.encoder = nn.Sequential(
            nn.Conv2d(channels, hidden, 4, 2, 1), nn.ReLU(),
            nn.Conv2d(hidden, hidden, 4, 2, 1), nn.ReLU(),
            nn.Conv2d(hidden, embedding_dim, 3, 1, 1),
        )
        self.quantizer = VectorQuantizer(num_embeddings, embedding_dim, commitment_cost)
        # Decoder: 16×16×D -> 64×64×3
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(embedding_dim, hidden, 3, 1, 1), nn.ReLU(),
            nn.ConvTranspose2d(hidden, hidden, 4, 2, 1), nn.ReLU(),
            nn.ConvTranspose2d(hidden, channels, 4, 2, 1),
        )

    def forward(self, x):
        z_e = self.encoder(x)
        z_q, indices, vq_loss, perplexity = self.quantizer(z_e)
        x_hat = self.decoder(z_q)
        return x_hat, vq_loss, perplexity, indices

def vqvae_loss(x, x_hat, vq_loss):
    recon = F.mse_loss(x_hat, x)
    return recon + vq_loss, recon

6.5 EMA Codebook (production standard)

Directly using codebook loss to update $e$ converges slowly with many dead codes (codebook vectors never selected). Production implementation uses EMA (Exponential Moving Average) update:

For each codebook vector $e_k$, maintain:

Update:

$$e_k^{(t)} = \frac{m_k^{(t)}}{N_k^{(t)} + \varepsilon} \quad \text{(Laplace smoothing)}$$

class VectorQuantizerEMA(nn.Module):
    """ EMA codebook update (van den Oord 2017 follow-up / VQ-VAE-2 standard practice)
        - codebook does not rely on gradients, but on running EMA updates
        - commitment loss is kept for the encoder
        - decay γ is typically 0.99, ε typically 1e-5 """

    def __init__(self, num_embeddings: int, embedding_dim: int,
                 commitment_cost: float = 0.25, decay: float = 0.99, eps: float = 1e-5):
        super().__init__()
        self.K, self.D = num_embeddings, embedding_dim
        self.beta, self.decay, self.eps = commitment_cost, decay, eps

        embed = torch.randn(num_embeddings, embedding_dim) * 0.01
        self.register_buffer("codebook", embed)
        self.register_buffer("cluster_size", torch.zeros(num_embeddings))
        self.register_buffer("embed_avg", embed.clone())

    def forward(self, z_e):
        B, D, H, W = z_e.shape
        z_e_flat = z_e.permute(0, 2, 3, 1).contiguous().view(-1, D)

        dist = (z_e_flat.pow(2).sum(1, keepdim=True)
                + self.codebook.pow(2).sum(1)
                - 2 * z_e_flat @ self.codebook.t())
        indices = dist.argmin(dim=1)                                # [BHW]
        z_q_flat = F.embedding(indices, self.codebook)              # [BHW, D]

        if self.training:
            # EMA update
            with torch.no_grad():
                one_hot = F.one_hot(indices, self.K).float()        # [BHW, K]
                cluster_size_new = one_hot.sum(dim=0)               # [K]
                embed_sum = one_hot.t() @ z_e_flat                  # [K, D]

                self.cluster_size.mul_(self.decay).add_(cluster_size_new, alpha=1 - self.decay)
                self.embed_avg.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)

                # Laplace smoothing to avoid division by zero
                n = self.cluster_size.sum()
                cluster_size = (self.cluster_size + self.eps) / (n + self.K * self.eps) * n
                self.codebook.copy_(self.embed_avg / cluster_size.unsqueeze(1))

        commitment_loss = F.mse_loss(z_e_flat, z_q_flat.detach())
        vq_loss = self.beta * commitment_loss                       # under EMA, no codebook loss

        z_q_flat = z_e_flat + (z_q_flat - z_e_flat).detach()        # STE
        z_q = z_q_flat.view(B, H, W, D).permute(0, 3, 1, 2).contiguous()
        return z_q, indices.view(B, H, W), vq_loss
Two benefits of EMA

6.6 VQ-VAE-2 (Razavi, Vinyals, van den Oord, NeurIPS 2019)

Hierarchical extension of VQ-VAE:

On ImageNet 256×256, the first time a VQ-based method approached BigGAN's sample quality; the direct predecessor to VQ-GAN.

§7 VQ-GAN: Adversarial + Perceptual + Transformer Prior

7.1 Core idea (Esser, Rombach, Ommer, CVPR 2021, "Taming Transformers")

VQ-VAE reconstruction on ImageNet has blurry texture details. VQ-GAN transforms to:

ComponentVQ-VAEVQ-GAN
Recon lossL2 / L1 pixelL1 pixel + LPIPS perceptual + PatchGAN adversarial
PriorPixelCNNTransformer (decoder-only) over code tokens
Codebook512-1024 codes1024-16384 codes
Compression4×-8×8×-32× (higher compression, relying on perceptual + adversarial to save quality)
Applicationunconditional / class-cond generationhigh-res image synthesis, "Taming Transformers"

7.2 Loss formula

$$\mathcal{L}_\text{VQ-GAN}^\text{stage1} = \mathcal{L}_\text{rec} + \mathcal{L}_\text{VQ} + \lambda \cdot \mathcal{L}_\text{GAN}$$

where:

$$ \begin{aligned} \mathcal{L}_\text{rec} &= \|x - \hat{x}\|_1 + \mathcal{L}_\text{LPIPS}(x, \hat{x}) \\ \mathcal{L}_\text{VQ} &= \|\text{sg}[z_e] - e\|^2 + \beta \|z_e - \text{sg}[e]\|^2 \end{aligned} $$

GAN term in generator/tokenizer stage (only backprops to generator's output, discriminator is updated separately in another stage):

$$\mathcal{L}_\text{GAN}^{(G)} = -\mathbb{E}_{\hat{x}}[\log D(\hat{x})]\quad\text{(non-saturating)}\quad\text{or}\quad \mathcal{L}_\text{GAN}^{(G)} = -\mathbb{E}_{\hat{x}}[D(\hat{x})]\quad\text{(hinge)}$$

Discriminator's own minimax term (independent step updating $D$):

$$\mathcal{L}_\text{GAN}^{(D)} = -\mathbb{E}_x[\min(0, -1+D(x))] - \mathbb{E}_{\hat{x}}[\min(0, -1-D(\hat{x}))]\quad\text{(hinge)}$$

Adaptive $\lambda$ (paper novelty, using gradient norm ratio of the last layer for auto-balancing, avoiding manual tuning):

$$\lambda = \frac{\lVert\nabla_{G_L} \mathcal{L}_\text{rec}\rVert}{\lVert\nabla_{G_L} \mathcal{L}_\text{GAN}^{(G)}\rVert + \delta}$$

$G_L$ is the last layer of the decoder; $\lVert\cdot\rVert$ is the Frobenius norm. Total generator loss:

$$\mathcal{L}_G = \mathcal{L}_\text{rec} + \mathcal{L}_\text{VQ} + \lambda \cdot \mathcal{L}_\text{GAN}^{(G)}$$

7.3 Stage 2: Transformer Prior

Stage 1 trains the VQ-GAN, converting the image into a token grid $\mathbf{c} = (c_1, \ldots, c_{HW})$ (raster-scan flattened). Stage 2 trains a decoder-only Transformer on the token sequence, standard AR:

$$p(\mathbf{c}) = \prod_{i=1}^{HW} p(c_i | c_{

Sampling: AR sample tokens → VQ-GAN decoder → image. This is the standard paradigm of translating "image generation" into "language model"; DALL·E / Parti / Muse all evolve from this idea.

VQ-GAN's role in LDM

Stable Diffusion's VAE tokenizer is actually the continuous-latent variant of a KL-regularized VQ-GAN (drops quantization, keeps KL + perceptual + adversarial), output is a continuous latent map (4 channels, downsampled 8×). Diffusion runs in this latent, and finally decoder restores it. Can be understood as "VQ-GAN encoder/decoder + continuous latent + KL."

7.4 PatchGAN Discriminator (production architecture)

VQ-GAN uses PatchGAN (Isola et al. CVPR 2017 "pix2pix"):

class PatchDiscriminator(nn.Module):
    """ PatchGAN: stack of strided convs with 70×70 receptive field
        Output [B, 1, H/8, W/8] patch-level real/fake decision """
    def __init__(self, in_ch=3, hidden=64, n_layers=3):
        super().__init__()
        layers = [nn.Conv2d(in_ch, hidden, 4, 2, 1), nn.LeakyReLU(0.2, True)]
        ch = hidden
        for i in range(1, n_layers):
            ch_next = min(hidden * (2 ** i), 512)
            layers += [
                nn.Conv2d(ch, ch_next, 4, 2, 1),
                nn.BatchNorm2d(ch_next),
                nn.LeakyReLU(0.2, True),
            ]
            ch = ch_next
        layers += [
            nn.Conv2d(ch, ch * 2, 4, 1, 1),
            nn.BatchNorm2d(ch * 2),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ch * 2, 1, 4, 1, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x): return self.net(x)


def hinge_d_loss(real_logits, fake_logits):
    real = F.relu(1.0 - real_logits).mean()
    fake = F.relu(1.0 + fake_logits).mean()
    return 0.5 * (real + fake)

def hinge_g_loss(fake_logits):
    return -fake_logits.mean()
GAN training trick list

7.5 LPIPS (Perceptual Loss)

$$\mathcal{L}_\text{LPIPS}(x, \hat{x}) = \sum_l w_l \cdot \|\phi_l(x) - \phi_l(\hat{x})\|^2$$

$\phi_l$ is the $l$-th layer feature map of pretrained VGG / AlexNet, $w_l$ is the learned channel-wise weight (Zhang et al. CVPR 2018). Closer to human perception than pixel-MSE; standard for VQ-GAN / SD / most image GAN / diffusion training.

§8 Discrete VAE and Gumbel-Softmax

8.1 dVAE (DALL·E 1, Ramesh et al. ICML 2021)

DALL·E uses dVAE (discrete VAE) as image tokenizer:

8.2 Gumbel-Softmax derivation

Goal: make categorical sampling differentiable.

Gumbel-Max trick: for logits $\pi = (\pi_1, \ldots, \pi_K)$ add independent Gumbel(0,1) noise $g_k = -\log(-\log u_k), u_k \sim \mathcal{U}(0, 1)$, then:

$$\arg\max_k \{\log \pi_k + g_k\}$$

follows categorical(softmax($\pi$)). Proof uses Gumbel CDF property: $P(\max_k X_k = X_j) = e^{\pi_j} / \sum_k e^{\pi_k}$.

Gumbel-softmax (Jang et al., ICLR 2017; Maddison et al., ICLR 2017 concurrent): replace the non-differentiable argmax with softmax:

$$\boxed{\;y_k = \frac{\exp((\log \pi_k + g_k) / \tau)}{\sum_j \exp((\log \pi_j + g_j) / \tau)}\;}$$

Straight-Through Gumbel-Softmax: forward uses argmax (discrete), backward uses softmax gradient — same idea as VQ-VAE's STE.

def gumbel_softmax_sample(logits, tau=1.0, hard=False, dim=-1):
    """ Input logits = log π   Output soft / hard one-hot """
    # 1) Add Gumbel noise
    g = -torch.log(-torch.log(torch.rand_like(logits) + 1e-20) + 1e-20)
    y_soft = F.softmax((logits + g) / tau, dim=dim)
    if not hard:
        return y_soft
    # ST: forward hard, backward soft gradient
    index = y_soft.argmax(dim=dim, keepdim=True)
    y_hard = torch.zeros_like(y_soft).scatter_(dim, index, 1.0)
    y = y_hard - y_soft.detach() + y_soft   # straight-through
    return y
VQ-VAE vs Gumbel-softmax / dVAE

Both are discrete latent models, differences:

8.3 MaskGIT (Chang et al., CVPR 2022)

Replace AR Transformer prior with BERT-style masked Transformer:

Successor: MUSE (Chang et al., 2023) extends the same idea to text-to-image, one of Google's main generative models.

§9 FSQ: Finite Scalar Quantization (focus)

9.1 Motivation (Mentzer, Minnen, Agustsson, Toderici, ICLR 2024)

VQ-VAE has persistent problems:

  1. Codebook collapse / underuse: most codes never used (dead codes), perplexity far below theoretical $K$
  2. STE bias: gradient estimate is biased, training unstable
  3. Complex loss balancing: commitment weight, EMA decay, dead code revival all need tuning
  4. Effective codebook size ceiling: practical limit ~$10^3$-$10^4$, larger fails

FSQ bypasses this with one trick: per-dimension scalar quantization (scalar quantization, not vector quantization).

9.2 Core formula (must derive)

Let encoder output $z \in \mathbb{R}^d$. Independently for each dim, do scalar quantization (FSQ paper Eq. 4):

$$z_i \longrightarrow z'_i = \tfrac{L_i-1}{2}\tanh(z_i) - s_i \longrightarrow \hat{z}_i = \text{round}(z'_i) + s_i$$

where $s_i = 0$ if $L_i$ is odd, $s_i = 0.5$ if $L_i$ is even. So:

Regardless of parity, each dim has $L_i$ levels; multiply the level counts of all dimensions:

$$\boxed{\;K_\text{implicit} = \prod_{i=1}^{d} L_i\;}$$

Implicit codebook size examples

9.3 Why doesn't FSQ have codebook collapse?

Key insight

In VQ-VAE, the root cause of codebook collapse is: the codebook is a free parameter, the optimizer lets most $e_k$ drift to useless regions, and only a few $e_k$ are used repeatedly. FSQ's "codebook" is not a parameter — it is fixed grid points on the number axis ($\{-L/2, \ldots, L/2\}$).

Empirical: FSQ's codebook usage is nearly 100% (compared to VQ-VAE's 50-70%), this conclusion is reproduced on ImageNet / Cosmos / OpenMagViT2.

9.4 Why doesn't FSQ need "explicit STE wrapping" and why is its loss minimal

$$\mathcal{L}_\text{FSQ} = \|x - \hat{x}\|^2 \quad \text{(plus optional perceptual + adversarial)}$$

Simplified VQ-VAE vs FSQ comparison

FSQ "trades spatial dimensions for codebook size": VQ-VAE uses 1 dim ($D$ continuous values + 1 discrete choice from $K$), FSQ uses $d$ independent discrete dims with $L_i$ levels each; the final discrete entropy is actually larger and collapse nearly impossible. Cost: embedding expressiveness is slightly weaker (each dim independent, not sharing representation), but reconstruction-side decoder compensates.

9.5 FSQ implementation (10 lines)

class FSQ(nn.Module):
    """ Finite Scalar Quantization (Mentzer et al., ICLR 2024)
        levels: tuple, number of quantization levels per dim (must be all odd or all even, odd guarantees 0 included)
        eps:    bounding safety margin, avoiding round jumping out of grid after tanh """

    def __init__(self, levels=(8, 5, 5, 5)):
        super().__init__()
        levels_t = torch.tensor(levels, dtype=torch.float32)
        self.levels = levels_t
        self.d = len(levels)
        self.K = int(torch.prod(levels_t).item())            # implicit codebook size = ∏ L_i
        # FSQ paper Eq. 4: half = (L-1)/2; shift = 0.5 if L even else 0
        half = (levels_t - 1) / 2                            # [d]
        shift = ((levels_t % 2) == 0).float() * 0.5          # [d]
        self.register_buffer("half_l", half)
        self.register_buffer("shift", shift)
        # mixed-radix basis for token id encoding
        cumprod = torch.tensor([1.0] + list(torch.cumprod(levels_t[:-1], dim=0)),
                               dtype=torch.float32)
        self.register_buffer("basis", cumprod)               # [d]

    @staticmethod
    def round_ste(z):
        """STE for non-differentiable round: forward round, backward identity"""
        return z + (z.round() - z).detach()

    def forward(self, z):
        """ z: [B, d, ...]  ->  z_hat: [B, d, ...] (quantized values), codes: [B, ...] (∈ 0..K-1) """
        view = (1, -1) + (1,) * (z.dim() - 2)
        half = self.half_l.view(*view).to(z.device)
        shift = self.shift.view(*view).to(z.device)
        # 1) Bound: tanh(z) * half - shift  → z'∈[-half-shift, half-shift]
        z_bounded = torch.tanh(z) * half - shift
        # 2) Round (STE) + add back shift → odd L gives {-half,…,half} (integers), even L gives {-half,…,half} (half-integers)
        z_hat = self.round_ste(z_bounded) + shift
        # 3) Token ID (mixed-radix): map each d-dim ∈ {-half_i,…,half_i} to 0..L_i-1 then encode as single index
        shifted = (z_hat + half).round().long()              # ∈ 0..L_i-1 (round to handle floating-point error)
        basis = self.basis.view(*view).to(z.device).long()
        codes = (shifted * basis).sum(dim=1)                 # [B, ...]
        return z_hat, codes


# Usage example:
# fsq = FSQ(levels=(8, 5, 5, 5))    # K = 8·5·5·5 = 1000
# z = encoder(x)                     # [B, 4, H, W]
# z_hat, tokens = fsq(z)             # z_hat: [B, 4, H, W], tokens: [B, H, W] ∈ 0..999
# x_hat = decoder(z_hat)
# loss = F.mse_loss(x_hat, x)        # That's the only term!
FSQ level selection experience

9.6 LFQ: Lookup-Free Quantization (MAGVIT-v2, Yu et al., ICLR 2024)

A binary special case of FSQ:

$$\text{LFQ}(z) = \text{sign}(z) \in \{-1, +1\}^d$$

Only 2 levels per dim, implicit codebook = $2^d$: with $d=18$, codebook = $2^{18} = 262{144}$ (same order as FSQ-equivalent).

Features:

class LFQ(nn.Module):
    """ Lookup-Free Quantization (MAGVIT-v2)
        Per-dim independent sign quantize, implicit codebook = 2^d """
    def __init__(self, dim: int, entropy_weight: float = 0.1):
        super().__init__()
        self.d = dim
        self.K = 2 ** dim
        self.entropy_weight = entropy_weight

    def forward(self, z):
        # z: [B, d, ...]
        q = torch.sign(z)
        # Avoid sign(0) = 0
        q = torch.where(q == 0, torch.ones_like(q), q)
        # STE
        z_hat = z + (q - z).detach()

        # Entropy regularization (prevents some dim from always having same sign)
        # p_+ = sigmoid(z), p_- = 1 - p_+
        if self.training:
            p = torch.sigmoid(z)
            per_dim_entropy = -(p * torch.log(p + 1e-9)
                                + (1 - p) * torch.log(1 - p + 1e-9))
            entropy_loss = -per_dim_entropy.mean()    # maximize entropy → minimize -H
        else:
            entropy_loss = z.new_tensor(0.0)

        return z_hat, self.entropy_weight * entropy_loss

9.7 Cosmos / OpenMagViT2 / modern video tokenizers

TokenizerSourceQuantizationUsed in
MAGVIT-v2Google 2024 (ICLR)LFQEarly text-to-video demo
OpenMagViT2Open reproduction 2024LFQPublic video tokenizer baseline
Cosmos TokenizerNVIDIA 2024FSQ + video spatiotemporal compressionNVIDIA Cosmos world model
VideoPoet tokenizerGoogle 2024LFQ-styletext-to-video

Engineering essentials:

§10 Complexity and Resource Comparison

ModelLatent typeTrain parameters (encoder+decoder)Main lossCodebook collapseSTE dependency
VAEcontinuous Gaussian$\sim$10-100Mrecon + KL (closed form)N/ANo (reparameterization)
$\beta$-VAEcontinuous GaussianSame as VAErecon + $\beta$·KLN/ANo
NVAEhierarchical continuous80M-200Mrecon + multi-layer KLN/ANo
VQ-VAEdiscrete via codebook50-200Mrecon + codebook + $\beta$·commitmentFrequently occursYes
VQ-VAE-2hierarchical discrete100-500MSame as VQ-VAE × 2 layersSame as aboveYes
VQ-GANdiscrete + adversarial50-300M (+ D)recon + LPIPS + GAN + codebook + commitmentSame as aboveYes
dVAEcategorical (logits)50-200Mrecon + KL to uniformRare (categorical distribution learning)No (Gumbel-softmax backprops soft)
FSQscalar quantize per dim30-150Mrecon (+ perceptual)Almost never occursYes (but minimal)
LFQbinary scalar quantize30-150Mrecon (+ entropy reg)Almost never occursYes
Ecological niche positioning

11.1 VAE vs GAN vs Diffusion vs Flow / FM

ModelLikelihoodTraining stabilityDiversitySample qualityInference speed
VAEYes (ELBO)✅ Stable✅ Good⚠️ Blurry✅ 1-step
GANNone❌ Hard❌ Mode collapse✅ Sharp✅ 1-step
DiffusionApproximate (VLB)✅ Stable✅ Good✅ SOTA❌ Many NFE
Flow / FMYes (ODE)✅ Stable✅ Good✅ Strong⚠️ Several NFE

11.2 Role of tokenizer series in large models


Tokenizer Stage 1                Generative Stage 2 (prior)
────────────────                 ──────────────────────────
VQ-GAN  →   discrete token grid  →   Transformer AR  (Parti, DALL·E 1, Cogview)
VQ-GAN  →   discrete token grid  →   Masked Transformer (MaskGIT, Muse)
FSQ    →   discrete token grid  →   Transformer AR  (Cosmos, OpenMagViT2)
LFQ    →   binary token grid →   AR / bit predictor (MAGVIT-v2, VideoPoet)
KL-VAE →   continuous latent map  →   Diffusion / Flow Matching (LDM, SD, SD3, FLUX)

11.3 Reconstruction-Perception Tradeoff (advanced question)

Blau & Michaeli (ICML 2018) proved: between reconstruction (MSE / PSNR) and perception (perceptual / FID) there is a strict Pareto boundary. VQ-GAN / SD VAE introduces LPIPS + adversarial to trade higher perceptual quality for slightly worse PSNR.

PSNR doesn't equal "looks good"

VQ-GAN paper's PSNR is not necessarily better than VQ-VAE, but perceptual (LPIPS / FID) is much better. In interviews, often asked "why does SOTA tokenizer have lower PSNR" — this is distortion-perception tradeoff.

§12 25 Frequently-Asked Interview Questions

Listed from the perspective of a top-lab interviewer by codex (gpt-5.5 xhigh), divided into 3 tiers by difficulty. Each question expands to answer points + common pitfalls.

L1 Must-Know (any ML engineering position will ask)

Q1. What is the ELBO of VAE? Write the formula.
  • $\log p(x) \geq \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] - D_\text{KL}(q_\phi(z|x)\,\|\,p(z))$
  • First term: reconstruction expected log-likelihood
  • Second term: KL pulls posterior toward prior $\mathcal{N}(0, I)$
  • Gap between ELBO and $\log p(x)$ = $D_\text{KL}(q_\phi(z|x)\,\|\,p_\theta(z|x))$

Pitfalls: writing $\log p(x|z) - D_\text{KL}(...)$ (missing the expectation symbol); only saying "recon + regularization" without writing the formula.

Q2. What problem does the reparameterization trick solve?
  • Direct sampling $z \sim q_\phi(z|x)$ is non-differentiable, gradient cannot reach encoder
  • Rewrite as $z = \mu + \sigma \odot \epsilon, \epsilon \sim \mathcal{N}(0, I)$, moving randomness to independent noise
  • $z$ becomes a deterministic function of $\phi$, can backprop normally
  • Not just Gaussian; Gumbel-softmax is the same idea

Pitfalls: saying "for speed-up" — actually it's for differentiability.

Q3. What does $\beta$-VAE's $\beta$ control?
  • $\beta > 1$: stronger KL regularization, encourages disentangled latent
  • $\beta < 1$: relaxes KL, more precise recon but worse prior fit
  • $\beta = 1$: standard VAE
  • But Locatello 2019 proved: pure unsupervised disentanglement is infeasible, needs inductive bias

Pitfalls: only saying "larger $\beta$ → more disentanglement" — wrong, depends on data + architecture.

Q4. What is posterior collapse?
  • During training KL → 0, i.e. $q_\phi(z|x) \approx p(z)$ (independent of $x$)
  • Decoder completely ignores $z$, VAE degenerates to unconditional model
  • Common with strong AR decoders (e.g. PixelCNN / LSTM) + simple data

Pitfalls: only saying "latent is useless," not mentioning the quantitative indicator KL → 0.

Q5. What is VQ-VAE's codebook? How is it used?
  • $\{e_1, \ldots, e_K\} \subset \mathbb{R}^D$ a set of learnable "codebook" vectors
  • Encoder outputs continuous $z_e(x)$; use nearest neighbor $k = \arg\min \|z_e - e_k\|^2$ to replace with $z_q = e_k$
  • Decoder decodes $z_q$ back to pixels
  • Trained via codebook loss (pulls $e$ to $z_e$) + commitment loss (pulls $z_e$ to $e$)

Pitfalls: saying codebook is fixed / pretrained — wrong, learned end-to-end.

Q6. What are VQ-VAE's three loss terms?
  • Reconstruction: $\|x - \hat{x}\|^2$ (pixel-level)
  • Codebook loss: $\|\text{sg}[z_e] - e\|^2$ (only updates $e$)
  • Commitment loss: $\beta \|z_e - \text{sg}[e]\|^2$ (only updates encoder, $\beta = 0.25$)
  • sg = stop_gradient, avoids gradient coupling and oscillation between both sides

Pitfalls: mixing codebook and commitment loss into one; forgetting the direction of sg.

Q7. What is the Straight-Through Estimator (STE)?
  • Solves the backprop problem of non-differentiable operations (e.g. argmax / round)
  • Forward uses discrete output, backward directly passes gradient to "the previous continuous input layer"
  • Equivalent surrogate: assume quantization layer is identity map
  • PyTorch three lines: z_q = z_e + (z_q_quantized - z_e).detach()

Pitfalls: only saying "backward uses identity," not saying whether forward is real quantization.

Q8. What does VQ-GAN have over VQ-VAE?
  • LPIPS perceptual loss (replaces / supplements L2)
  • PatchGAN adversarial loss + adaptive $\lambda$ weight
  • Transformer prior (replaces PixelCNN)
  • Larger codebook (1k → 16k) + higher compression (8× → 16-32×)
  • It is "Taming Transformers" (Esser et al. CVPR 2021)

Pitfalls: only saying "GAN" without perceptual; or forgetting Transformer prior.

Q9. What is FSQ? Why doesn't it have codebook collapse?
  • Per-dim independent scalar quantize to $L$ fixed levels ($\tanh$ → scale → round)
  • Implicit codebook = $\prod L_i$ (e.g. $L=8, d=6 \Rightarrow 8^6$)
  • Codebook is not a learnable parameter → nothing to "collapse" to useless region
  • Encoder naturally explores the full grid via reconstruction pressure

Pitfalls: treating FSQ as a VQ-VAE codebook optimization trick — wrong, FSQ has no explicit codebook parameter.

Q10. Write the closed form of KL($\mathcal{N}(\mu, \sigma^2 I) \,\|\, \mathcal{N}(0, I)$).
  • $D_\text{KL} = \tfrac{1}{2}\sum_i (\mu_i^2 + \sigma_i^2 - \log \sigma_i^2 - 1)$
  • Only the diagonal covariance has this simple form
  • In implementation, encoder outputs $\log \sigma^2$ (logvar) for stability
  • kl = 0.5 * (mu**2 + logvar.exp() - logvar - 1).sum()

Pitfalls: confusing $\log \sigma^2$ with $\log \sigma$ (off by 2); or forgetting the "-1" term.

L2 Advanced (research-oriented positions)

Q11. Is IWAE tighter than ELBO? How to use it?
  • $K$ importance samples: $\mathcal{L}_K^\text{IWAE} = \mathbb{E}_{z_1,\ldots,z_K}[\log \tfrac{1}{K} \sum_k \tfrac{p(x, z_k)}{q(z_k|x)}]$
  • $\mathcal{L}_1 = $ ELBO (special case)
  • $\mathcal{L}_K \to \log p(x)$ as $K \to \infty$ (Burda et al. ICLR 2016)
  • But encoder's learned posterior no longer pursues approximation of the true posterior

Pitfalls: saying "$K = 1$ is stronger than ELBO" — wrong, the special case is ELBO.

Q12. How to mitigate posterior collapse? List at least 4.
  • KL annealing: $\beta(t) = \min(1, t/T)$ linear growth (Bowman 2016)
  • Free bits: per-dim KL lower bound $\lambda$ nats (Kingma 2016)
  • Weakened decoder: restrict decoder expressiveness (Chen 2017)
  • VQ-VAE / discrete latent: structurally no KL term, bypass
  • Others: auxiliary loss, word dropout, NF posterior, skip connections

Pitfalls: only answering "KL annealing" alone; or treating $\beta$-VAE as a collapse mitigation tool (actually $\beta > 1$ makes collapse easier).

Q13. Derive the closed form of KL($\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0, 1))$.
  • $\log \tfrac{q}{p} = -\tfrac{1}{2}\log \sigma^2 - \tfrac{(z-\mu)^2}{2\sigma^2} + \tfrac{z^2}{2}$
  • Take $\mathbb{E}_q$ (using $\mathbb{E}[z] = \mu$, $\mathbb{E}[z^2] = \mu^2 + \sigma^2$, $\mathbb{E}[(z-\mu)^2] = \sigma^2$)
  • Result: $\tfrac{1}{2}(\mu^2 + \sigma^2 - \log \sigma^2 - 1)$
  • Multi-dim independent case: sum over dims

Pitfalls: skipping the derivation and just memorizing the formula; forgetting the expansion of $\mathbb{E}_q[z^2]$.

Q14. Role of sg in VQ-VAE's codebook vs commitment terms.
  • Codebook loss $\|\text{sg}[z_e] - e\|^2$: gradient only updates $e$ (codebook vector), not encoder
  • Commitment loss $\|z_e - \text{sg}[e]\|^2$: gradient only updates encoder, not $e$
  • Two sg's decouple bidirectional alignment, avoiding mutual interference
  • Without both sg's, it's equivalent to ordinary MSE; effectively the learning rate doubles + both sides pull each other

Pitfalls: thinking sg is to "prevent codebook updates from going too fast" — actually it's for gradient decoupling.

Q15. STE's gradient is equivalent to what kind of surrogate?
  • STE = "forward real quantization, backward identity surrogate"
  • Equivalent to setting the differentiable surrogate of $z_q = \text{quantize}(z_e)$ to $z_q^\text{surr} = z_e$
  • I.e. assume quantization is the identity map
  • It is a biased estimate (biased gradient), but with low variance and works in practice
  • Rigorous analysis: Bengio et al. (2013) "Estimating or Propagating Gradients Through Stochastic Neurons"

Pitfalls: saying STE is unbiased — wrong, it is biased.

Q16. What is the EMA codebook update formula? Why use it?
  • $N_k^{(t)} = \gamma N_k^{(t-1)} + (1-\gamma) n_k^{(t)}$, cluster count
  • $m_k^{(t)} = \gamma m_k^{(t-1)} + (1-\gamma) \sum_{i \to k} z_{e,i}$, cluster vector sum
  • $e_k^{(t)} = m_k^{(t)} / (N_k^{(t)} + \varepsilon)$
  • Pros: codebook updates more stable; can periodically revive dead codes
  • $\gamma \approx 0.99, \varepsilon \approx 10^{-5}$ are common values

Pitfalls: treating EMA as a momentum + Adam SGD variant — essentially it's k-means EMA estimation under mini-batches.

Q17. What is PatchGAN? Why use it in VQ-GAN?
  • Does not output a single scalar but an N×N patch-level real/fake map
  • Each patch is 70×70 receptive field (using stack of strided convs)
  • Suitable for capturing local texture realness, low pressure on global structure
  • Lets the generator focus on texture details rather than full-image discrimination (VQ-GAN's global relies on recon + LPIPS)
  • From Isola et al. CVPR 2017 "pix2pix"

Pitfalls: thinking PatchGAN is attention-based; or saying it's only used in image-to-image translation.

Q18. What is LPIPS? Advantages over MSE?
  • Compute distance using intermediate-layer features of pretrained VGG / AlexNet: $\sum_l w_l \|\phi_l(x) - \phi_l(\hat{x})\|^2$
  • $w_l$ is the learned channel-wise weight (Zhang et al. CVPR 2018)
  • Closer to human perception than pixel-MSE
  • Standard for VQ-GAN / SD / most image GAN / diffusion training
  • Used with distortion-perception tradeoff (Blau & Michaeli ICML 2018)

Pitfalls: only saying "uses VGG features," not learned channel weights / fit to human study.

Q19. How does the Gumbel-Max trick approximate categorical sampling?
  • For logits $\pi$ add independent Gumbel noise $g_k = -\log(-\log u_k)$
  • $\arg\max_k(\log \pi_k + g_k)$ follows categorical(softmax($\pi$))
  • Replace argmax with softmax to get Gumbel-softmax, differentiable
  • Temperature $\tau \to 0$ approaches one-hot; ST version forward argmax / backward softmax gradient
  • Used in dVAE / DALL·E 1

Pitfalls: writing Gumbel(0,1) as normal noise; forgetting that argmax probability is proportional to softmax.

Q20. Where is MaskGIT faster than AR? Why is the quality not worse?
  • Training: BERT-style mask-and-predict (not next-token AR)
  • Sampling: each round parallel unmask a batch of tokens (by confidence ranking), 8-12 rounds to converge
  • About 10x faster than AR, because each round is a parallel forward
  • Quality is not worse because (1) iterative refinement equates to multiple forwards; (2) bidirectional context
  • On ImageNet 256×256, quality comparable to AR; MUSE extends the same idea to text-to-image

Pitfalls: thinking MaskGIT is a discrete diffusion — actually it's a generative extension of BERT MLM.

L3 Advanced Variants (top lab / generative model direction)

Q21. Derive FSQ's implicit codebook size, and why no extra STE-wrapped loss is needed.
  • Per-dim independent quantization to $L_i$ levels: $z_i \to \tanh(z_i) \cdot (L_i-1)/2 \to \text{round}$
  • Combined across $d$ dims: implicit codebook = $\prod_i L_i$
  • Example: $L = (8, 5, 5, 5), d = 4 \Rightarrow K = 1000$
  • No codebook parameter → no codebook collapse; no explicit codebook loss, commitment loss, EMA, dead-code revival
  • Still needs STE for non-differentiable round: z_hat = z + (z.round() - z).detach() one-line solution
  • Loss is only $\|x - \hat{x}\|^2$ (+ optional perceptual + adversarial)

Pitfalls: conflating FSQ with LFQ (LFQ is the $L = 2$ binary special case); or thinking FSQ removes STE (actually round still needs STE, just no extra codebook / commitment loss).

Q22. How to diagnose + mitigate codebook collapse in VQ-VAE?
  • Diagnose: measure perplexity = $\exp(-\sum_k p_k \log p_k)$, where $p_k$ is the usage frequency of the $k$-th code in codebook

    • Healthy perplexity should approach $K$ (uniform usage upper bound)
    • In practice perplexity / K < 50% is common, with some codes nearly unused
  • Mitigate:

    • EMA codebook update (basic)
    • Dead code revival: every $T$ steps, reset $e_k$ with $N_k < \tau$ to some random $z_e$ from the current batch
    • k-means init: before training do k-means initialization of codebook using $z_e$ from the first batch
    • Code dropout: during training randomly drop a portion of the codebook, forcing the following stages to not depend on single codes
    • Switch to FSQ / LFQ: structurally avoid (the simplest "mitigation")

Pitfalls: only answering "use a larger codebook" — wrong, larger codebook is actually more prone to collapse.

Q23. The respective roles of VAE / VQ-VAE / Diffusion / FM in the LDM series?
  • VAE (KL-regularized VAE / VQ-GAN-without-quant): image $\to$ continuous latent map (Stable Diffusion uses 8× downsample, 4-channel latent)
  • VQ-VAE / VQ-GAN: image $\to$ discrete token grid, for AR / MaskGIT prior (Parti / DALL·E / Muse / Cosmos)
  • Diffusion / FM prior: runs reverse process in VAE latent space (LDM / SD / SDXL / SD3 / FLUX)
  • AR / Masked Transformer prior: runs on VQ tokens (Parti / Muse / VideoPoet)
  • Key insight: tokenizer and prior are two stages, tokenizer is frozen after training

Pitfalls: confusing SD's VAE with VQ-VAE — SD's VAE has no quantization.

Q24. How does NVAE stably train a hierarchical VAE? What are the key tricks?
  • Residual normal parameterization: $q(z_l|\cdot) = \mathcal{N}(\mu_p + \Delta\mu_q,\, \sigma_p \cdot \Delta\sigma_q)$, letting posterior be a small perturbation of prior
  • Spectral regularization: controls per-layer Lipschitz constant, avoiding numerical instability
  • Architecture tuning with BatchNorm + Swish + depthwise etc.
  • Per-layer independent free bits, avoiding high-layer collapse
  • Warm-up KL: lower layers trained first, higher layers introduced later
  • Vahdat & Kautz, NeurIPS 2020

Pitfalls: only answering "uses ResNet architecture," without mentioning the probabilistic-level residual normal.

Q25. What is the reconstruction-perception tradeoff? What does it imply for VQ-GAN / SD VAE?
  • Blau & Michaeli (ICML 2018) proved: between MSE / PSNR (distortion) and perceptual distance (perception) there is a strict Pareto boundary
  • Lower distortion → necessarily raises or maintains perception loss, and vice versa
  • No simultaneous optimum: VQ-GAN / SD VAE introducing LPIPS + adversarial is actively sacrificing PSNR to gain perceptual quality
  • Implication: evaluating tokenizers should not only look at PSNR / MSE; FID / IS / KID and other perceptual metrics are more important
  • Industrial practice: at 8×-32× high compression, perceptual loss is the key to VQ-GAN / SD VAE not being blurry

Pitfalls: only answering "PSNR is not a good metric," without saying the underlying strict Pareto bound.

§A Appendix: Complete from-scratch code skeleton + sanity check

The reference from-scratch implementation includes:

Actual sanity-check output (PyTorch 2.x, single-machine GPU):

[a] VAE(MNIST 784→20):   recon=78.4   KL=18.6   loss=97.0    ✓
[b] reparam grad path:   dL/dμ ≠ 0, dL/dlogvar ≠ 0           ✓
[c] VQ-VAE(64×64×3):     recon=0.012  vq=0.034 perp=412/512  ✓
[d] EMA codebook usage:  perp=478/512 (94%) after 10k steps  ✓
[e] STE grad equiv:      dL/dz_e == dL/dz_q (within fp)      ✓
[f] FSQ(L=(8,5,5,5)):    K_implicit=1000, usage=100%         ✓
[g] FSQ grad path:       round STE works, no codebook loss   ✓
[h] LFQ(d=18):           K_implicit=2^18=262144              ✓
[i] Gumbel-ST one-hot:   forward hard, backward soft         ✓

Code has passed independent reviewer static checks + PyTorch sanity checks:

VAE / VQ-VAE / VQ-GAN / FSQ Quick Reference · Main references: Kingma & Welling 2014 (VAE), Higgins et al. 2017 ($\beta$-VAE), van den Oord et al. 2017 (VQ-VAE), Razavi et al. 2019 (VQ-VAE-2), Esser et al. 2021 (VQ-GAN), Ramesh et al. 2021 (DALL·E / dVAE), Chang et al. 2022 (MaskGIT), Mentzer et al. 2024 (FSQ), Yu et al. 2024 (MAGVIT-v2 / LFQ)