Interview Prep · Latent Variable Models & Discrete Tokenizers

VAE / VQ-VAE / VQ-GAN / FSQ 面试 Cheat Sheet

公式推导 + From-Scratch 代码 + 25 高频题(L1 必会 · L2 进阶 · L3 顶级 lab)

Source: docs/tutorials/vae_vqvae_vqgan_tutorial.md SHA256: 826ba28b7bb7 Rendered: 2026-05-19 10:57 UTC

§0 TL;DR Cheat Sheet

8 句话搞定 VAE / VQ-VAE / VQ-GAN / FSQ

一页拿下面试核心要点(详见后文 §2–§9 推导)。

  1. 连续 VAE 目标:最大化 ELBO,$\log p(x) \geq \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] - D_\text{KL}(q_\phi(z|x)\,\|\,p(z))$;reparameterization $z = \mu + \sigma \odot \epsilon$ 让梯度穿过随机采样。
  2. KL 闭式(必考):$D_\text{KL}(\mathcal{N}(\mu,\sigma^2 I)\,\|\,\mathcal{N}(0,I)) = \tfrac{1}{2}\sum_i (\mu_i^2 + \sigma_i^2 - \log \sigma_i^2 - 1)$。
  3. Posterior collapse:KL → 0 → 解码器忽略 $z$;缓解:KL annealing、free bits、$\beta$ schedule、自回归先验。
  4. VQ-VAE:把 encoder 输出 $z_e(x)$ 映射到最近邻 codebook 向量 $e_k$,loss = recon + $\|\text{sg}[z_e] - e\|^2$(codebook)$+ \beta \|z_e - \text{sg}[e]\|^2$(commitment)。
  5. Straight-Through Estimator (STE):argmin / quantize 不可导,前向用量化值,反向直通梯度 $\partial \mathcal{L}/\partial z_q \to \partial \mathcal{L}/\partial z_e$。
  6. VQ-GAN:VQ-VAE + perceptual (LPIPS) + adversarial (PatchGAN) + 后训 Transformer prior;为 LDM / Parti / Muse 等离散 token 模型奠基。
  7. FSQ(2024):每维标量量化到 $\{-L,\ldots,L\}$,隐式 codebook 大小 $\prod_i L_i$(如 $L=8, d=6 \Rightarrow 8^6 = 262{144}$),无需 STE、不会 codebook collapse,rounding 用 STE 即可,loss 只剩 reconstruction。
  8. 生态对比:连续 latent(VAE / KL)适合 LDM 续 diffusion;离散 token(VQ-VAE / VQ-GAN / FSQ / LFQ)适合 AR / MaskGIT Transformer prior,是 Parti / Muse / Cosmos 等的核心组件。

§1 直觉:为什么要 latent variable model

生成模型的核心难题:直接建模 $p(x)$ 很难,但如果引入低维 latent $z$:

$$p(x) = \int p(x|z)\, p(z)\, dz$$

可以把"复杂的图像分布"分解为"简单先验 $p(z)$(如 $\mathcal{N}(0, I)$)"加上"易学的条件分布 $p(x|z)$"。两条路:

训练 vs 推理的不对称

VAE/VQ-VAE 训练时学整套 encoder + decoder(rate-distortion 角度:"压缩-重建");推理时根据应用分两种:

§2 VAE:核心公式与推导

2.1 ELBO 推导(必考、要会逐行推)

模型族 $p_\theta(x, z) = p_\theta(x|z)\, p(z)$,先验 $p(z) = \mathcal{N}(0, I)$,似然 $p_\theta(x|z)$ 由 decoder 给出。Marginal likelihood:

$$\log p_\theta(x) = \log \int p_\theta(x|z) p(z)\, dz$$

任意 distribution $q_\phi(z|x)$(encoder / variational posterior,$q_\phi(z|x) = \mathcal{N}(\mu_\phi(x), \mathrm{diag}(\sigma_\phi^2(x)))$),由 Jensen 不等式 / 直接代入:

$$ \begin{aligned} \log p_\theta(x) &= \log \int q_\phi(z|x) \frac{p_\theta(x|z) p(z)}{q_\phi(z|x)} dz \\ &\geq \mathbb{E}_{q_\phi(z|x)}\!\left[\log \frac{p_\theta(x|z) p(z)}{q_\phi(z|x)}\right] \quad \text{(Jensen)} \\ &= \underbrace{\mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)]}_{\text{reconstruction (negative)}} - \underbrace{D_\text{KL}(q_\phi(z|x)\,\|\,p(z))}_{\text{regularization}} \end{aligned} $$

故 ELBO:

$$\boxed{\;\mathcal{L}_\text{ELBO}(\theta, \phi; x) = \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] - D_\text{KL}\!\big(q_\phi(z|x)\,\|\,p(z)\big)\;}$$

Tight 的代价:ELBO 与真 log-likelihood 的 gap 等于 $D_\text{KL}(q_\phi(z|x)\,\|\,p_\theta(z|x))$。posterior $q_\phi$ 越逼近真后验 $p_\theta(z|x)$,gap 越小。

2.2 KL 项闭式推导(L3 必推)

设 $q_\phi(z|x) = \mathcal{N}(\mu, \mathrm{diag}(\sigma^2))$(对角协方差,每维 $\sigma_i^2$),$p(z) = \mathcal{N}(0, I)$。

对每一维 $i$ 独立:

$$ \begin{aligned} D_\text{KL}(\mathcal{N}(\mu_i, \sigma_i^2) \,\|\, \mathcal{N}(0, 1)) &= \int \mathcal{N}(z; \mu_i, \sigma_i^2) \log \frac{\mathcal{N}(z; \mu_i, \sigma_i^2)}{\mathcal{N}(z; 0, 1)} dz \end{aligned} $$

展开两个 Gaussian 密度的 log:

$$ \log \frac{\mathcal{N}(z; \mu_i, \sigma_i^2)}{\mathcal{N}(z; 0, 1)} = -\tfrac{1}{2}\log \sigma_i^2 - \tfrac{(z-\mu_i)^2}{2\sigma_i^2} + \tfrac{z^2}{2} $$

求期望(利用 $\mathbb{E}_q[z] = \mu_i$, $\mathbb{E}_q[z^2] = \mu_i^2 + \sigma_i^2$, $\mathbb{E}_q[(z-\mu_i)^2] = \sigma_i^2$):

$$ \begin{aligned} D_\text{KL} &= -\tfrac{1}{2}\log \sigma_i^2 - \tfrac{1}{2} + \tfrac{1}{2}(\mu_i^2 + \sigma_i^2) \\ &= \tfrac{1}{2}\big(\mu_i^2 + \sigma_i^2 - \log \sigma_i^2 - 1\big) \end{aligned} $$

对所有维度求和:

$$\boxed{\;D_\text{KL}\big(\mathcal{N}(\mu, \mathrm{diag}(\sigma^2)) \,\|\, \mathcal{N}(0, I)\big) = \tfrac{1}{2}\sum_{i=1}^{d}\big(\mu_i^2 + \sigma_i^2 - \log \sigma_i^2 - 1\big)\;}$$

数值稳定

实现时让 encoder 输出 $\log \sigma^2$(log-variance)而非 $\sigma$,避免 $\sigma$ 取 exp 后 overflow。代码里写 kl = 0.5 * (mu**2 + logvar.exp() - logvar - 1).sum()

2.3 Reparameterization Trick(必考)

ELBO 中 $\mathbb{E}_{q_\phi(z|x)}[\cdot]$ 用 Monte Carlo 估计:采一个 $z \sim q_\phi(z|x)$,求 $\log p_\theta(x|z)$。

问题:直接采样 $z = \text{sample}(\mathcal{N}(\mu, \sigma^2))$ 是不可导操作,梯度无法回传到 $\phi$。

解法:把随机性挪到独立噪声里:

$$\boxed{\;z = \mu_\phi(x) + \sigma_\phi(x) \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)\;}$$

现在 $z$ 是 $\phi$ 的确定性函数(条件于 $\epsilon$),$\nabla_\phi \mathcal{L}$ 可以正常反向传播。这是 Kingma & Welling (ICLR 2014) 的核心贡献之一。

面试加分:reparameterization 不止 Gaussian

Concrete / Gumbel-softmax(§7)对离散变量做了类似 trick:把 argmax 替换成 softmax 加 Gumbel 噪声,前向近似离散,反向用 softmax 梯度。

2.4 VAE 训练损失(实际写法)

负 ELBO(最小化):

$$\mathcal{L}_\text{VAE}(x) = \underbrace{\|x - \hat{x}\|^2}_{\text{recon (Gaussian likelihood up to const)}} + \underbrace{D_\text{KL}(q_\phi(z|x)\,\|\,p(z))}_{\text{closed form}}$$

对 Bernoulli / Categorical 似然(如 MNIST 二值),recon 项换成 BCE / CE。

§3 完整 VAE 实现(PyTorch)

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class VAE(nn.Module):
    """ 经典 VAE:Gaussian encoder + Gaussian/Bernoulli decoder
        本实现以 MNIST (28×28) 为例,latent dim=20
        生产化把 MLP 换成 ResNet / U-Net encoder/decoder,latent 可以是 spatial map """

    def __init__(self, x_dim: int = 784, h_dim: int = 400, z_dim: int = 20,
                 likelihood: str = "bernoulli"):
        super().__init__()
        self.x_dim, self.z_dim = x_dim, z_dim
        self.likelihood = likelihood

        # Encoder: x -> (μ, logσ²)
        self.enc = nn.Sequential(
            nn.Linear(x_dim, h_dim), nn.ReLU(),
            nn.Linear(h_dim, h_dim), nn.ReLU(),
        )
        self.fc_mu = nn.Linear(h_dim, z_dim)
        self.fc_logvar = nn.Linear(h_dim, z_dim)

        # Decoder: z -> x̂
        self.dec = nn.Sequential(
            nn.Linear(z_dim, h_dim), nn.ReLU(),
            nn.Linear(h_dim, h_dim), nn.ReLU(),
            nn.Linear(h_dim, x_dim),
        )

    def encode(self, x: torch.Tensor):
        h = self.enc(x.view(x.size(0), -1))
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor):
        # z = μ + σ ⊙ ε,  σ = exp(0.5 · logvar)
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mu + std * eps
        else:
            # 推理时用 posterior mean (deterministic)
            return mu

    def decode(self, z: torch.Tensor):
        logits = self.dec(z)
        if self.likelihood == "bernoulli":
            return torch.sigmoid(logits), logits
        return logits, logits  # Gaussian likelihood: 视为 mean prediction

    def forward(self, x: torch.Tensor):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        x_hat, logits = self.decode(z)
        return x_hat, logits, mu, logvar


def vae_loss(x: torch.Tensor, logits: torch.Tensor, mu: torch.Tensor,
             logvar: torch.Tensor, likelihood: str = "bernoulli",
             beta: float = 1.0, free_bits: float = 0.0):
    """ Returns:
            (loss, recon, kl)
        beta:        β-VAE 的 β(默认 1 = 标准 VAE)
        free_bits:   每维 KL 下限(nats)。 > 0 时启用 free bits"""
    B = x.size(0)
    x_flat = x.view(B, -1)

    # 1) Reconstruction term: -E_q[log p(x|z)]
    if likelihood == "bernoulli":
        # BCE-with-logits 数值上更稳,等价于 -log Bernoulli likelihood
        recon = F.binary_cross_entropy_with_logits(
            logits, x_flat, reduction="sum") / B
    elif likelihood == "gaussian":
        # 假设 σ² = 1(常数),MSE 与负 log-Gaussian 差一个常数
        recon = 0.5 * F.mse_loss(logits, x_flat, reduction="sum") / B
    else:
        raise ValueError(likelihood)

    # 2) KL term: D_KL(N(μ, σ²) || N(0, I))   闭式
    kl_per_dim = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1)   # [B, z_dim]

    if free_bits > 0:
        # Free bits: 每维 KL 下限 = free_bits(缓解 posterior collapse)
        kl_per_dim = torch.clamp(kl_per_dim, min=free_bits)

    kl = kl_per_dim.sum(dim=-1).mean()                            # 标量

    loss = recon + beta * kl
    return loss, recon, kl
常见 bug 清单

写 VAE 时容易踩的坑。

3.1 训练循环 + KL annealing

def train_vae(model, dataloader, total_steps=50_000, lr=1e-3, device="cuda",
              beta_max=1.0, anneal_steps=10_000, free_bits=0.0):
    """ KL annealing: β 从 0 线性增到 beta_max,防止训练早期 posterior 直接塌缩 """
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    model.to(device).train()

    step = 0
    while step < total_steps:
        for x, _ in dataloader:
            x = x.to(device)
            beta = min(beta_max, beta_max * step / max(anneal_steps, 1))

            x_hat, logits, mu, logvar = model(x)
            loss, recon, kl = vae_loss(x, logits, mu, logvar,
                                       beta=beta, free_bits=free_bits)
            opt.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            opt.step()
            step += 1
            if step >= total_steps:
                break

§4 VAE 变体:$\beta$-VAE / IWAE / NVAE / VAE-GAN

4.1 $\beta$-VAE(Higgins et al., ICLR 2017)

把 ELBO 的 KL 加权重 $\beta$:

$$\mathcal{L}_{\beta\text{-VAE}}(\theta, \phi; x) = \mathbb{E}_{q_\phi}[\log p_\theta(x|z)] - \beta \cdot D_\text{KL}(q_\phi(z|x)\,\|\,p(z))$$

disentanglement 争议

Locatello et al. (ICML 2019, 最佳论文) 证明:纯 unsupervised disentanglement 在没有归纳偏置 / 监督的前提下是不可能的。$\beta$-VAE 的"涌现 disentanglement"很大程度依赖架构 + 数据集 bias,而不是 $\beta$ 本身。

4.2 IWAE:Importance Weighted Autoencoder(Burda et al., ICLR 2016)

ELBO 是 $\log p(x)$ 的一阶 bound。用 $K$ 个 importance 样本得到更紧的 bound:

$$\mathcal{L}_K^\text{IWAE}(x) = \mathbb{E}_{z_1,\ldots,z_K \sim q_\phi}\!\left[\log \frac{1}{K}\sum_{k=1}^K \frac{p_\theta(x, z_k)}{q_\phi(z_k|x)}\right]$$

性质:

Tradeoff

IWAE 让 likelihood bound 更紧,但 encoder 学到的 posterior 不再追求"逼近真后验",转去配合 importance weighting 的几何。对下游 representation learning 不一定更好。

4.3 NVAE:Hierarchical VAE(Vahdat & Kautz, NeurIPS 2020)

多层 latent $z = (z_1, z_2, \ldots, z_L)$,每层依赖前层:

$$p(z) = p(z_1)\prod_{l=2}^L p(z_l | z_{

工程要点:

NVAE 现在的角色:LDM 之前最强连续 VAE prior 之一,但被 diffusion 系列在 sample 质量上反超。

4.4 VAE-GAN(Larsen et al., ICML 2016)

VAE 的重建损失(pixel-MSE / BCE)对高频细节不敏感 → 生成图模糊。VAE-GAN 把 MSE 换成 / 补上 discriminator feature matching

$$\mathcal{L}_\text{recon}^\text{VAE-GAN} = \|D_l(x) - D_l(\hat{x})\|^2$$

其中 $D_l$ 是 discriminator 的中间层特征。配合 adversarial loss,重建感更清晰。

这套思路最终在 VQ-GAN(§6)里大成:VQ-VAE 框架 + perceptual + adversarial + 高码率 codebook + Transformer prior。

§5 Posterior Collapse(必考)

5.1 现象

训练中 $D_\text{KL}(q_\phi(z|x)\,\|\,p(z)) \to 0$,即 $q_\phi(z|x) \approx p(z)$,与 $x$ 无关。后果:decoder 完全忽略 $z$,VAE 退化为 unconditional generative model。

5.2 原因(直观分析)

5.3 缓解方法(面试要会列)

方法做法出处
KL annealing$\beta(t) = \min(1, t / T)$ 线性从 0 升到 1Bowman et al. (2016)
Free bits每维 KL 下限 $\lambda$ nats:$\max(D_\text{KL}^{(i)}, \lambda)$Kingma et al. (2016)
$\beta$ < 1直接减小 KL 权重$\beta$-VAE 反向用法
Weakened decoder用 PixelCNN 等强 AR decoder 时人为 truncate context / 加 dropoutChen et al. (2017)
Auxiliary task加 word dropout、bag-of-words 辅助 lossBowman et al. (2016)
VAE-IAF / NF prior用更复杂的先验或 normalizing flow posteriorKingma et al. (2016)
Skip / lateral 连接让 latent 强制参与 decoder(如 VLAE)Zhao et al. (2017)
VQ-VAE离散 latent 配合 codebook commitment,结构上避免 collapsevan den Oord (2017)
Free bits 公式

实现极简:kl_per_dim = max(kl_per_dim, λ)。直觉:给每维 latent 保底 λ 比特信息,optimizer 不能把它压到 0 以下。$\lambda \approx 0.5$-$2$ nats / dim 是常见值。

§6 VQ-VAE:离散 latent + Codebook + STE

6.1 结构(van den Oord, Vinyals, Kavukcuoglu, NeurIPS 2017)


x ──Encoder──> z_e(x) ∈ R^{H'×W'×D}    # continuous spatial map
                       │
                       │   对每个空间位置 (h,w),找最近 codebook 向量
                       │   k_{hw} = argmin_k ‖z_e(x)_{hw} - e_k‖²
                       ↓
            z_q(x)_{hw} = e_{k_{hw}}    # quantized spatial map (discrete code)
                       │
                       │
                       ↓
                   Decoder ──> x̂

Codebook $\mathcal{E} = \{e_1, \ldots, e_K\} \subset \mathbb{R}^D$,学习得到。$z_e(x)$ 与 $z_q(x)$ 形状一致,但 $z_q$ 每个空间位置都是 codebook 中某一个 vector 的 copy(离散 index $k_{hw}$)。

6.2 Loss 推导

VQ-VAE 不学随机 posterior $q(z|x)$(不像 VAE),而是用确定性最近邻做 $z_e \to z_q$ 的"量化"。loss 由三部分组成:

$$\boxed{\;\mathcal{L}_\text{VQ-VAE} = \underbrace{\|x - \hat{x}\|^2}_{\text{reconstruction}} + \underbrace{\|\text{sg}[z_e(x)] - e\|^2}_{\text{codebook}} + \beta \underbrace{\|z_e(x) - \text{sg}[e]\|^2}_{\text{commitment}}\;}$$

各项含义:

sg[·] = stop_gradient(PyTorch 里 .detach()),定义:前向 $\text{sg}[u] = u$,反向 $\nabla \text{sg}[u] = 0$。

为什么 codebook 和 commitment 都要 sg

如果都不 sg,$\|z_e - e\|^2$ 同时拉两侧,方向耦合容易振荡。把这一项拆成两个 sg 版本:codebook 项专门更新 $e$,commitment 专门更新 encoder,学习率 / 速度可解耦。这是 vector quantization 文献的标准做法(也叫 "alternating minimization")。

6.3 Straight-Through Estimator (STE) 推导

问题:$z_q = e_{\arg\min_k \|z_e - e_k\|^2}$ 的 argmin 不可导(输出离散 index)。

STE 解法

PyTorch 实现技巧(经典三行):

z_q = z_e + (z_q_quantized - z_e).detach()

前向:z_q = z_e + (z_q_q - z_e) = z_q_q ✓(量化值) 反向:(z_q_q - z_e).detach() 不参与梯度,所以 dz_q/dz_e = 1,梯度直通到 encoder ✓

STE 的等价 surrogate

STE 等价于把不可导的 $z_q = \text{quantize}(z_e)$ 替换成可导 surrogate $z_q^\text{surrogate} = z_e$ 来反传——即"假设量化是恒等映射"。这是一个有偏估计(biased gradient estimator),但实践中工作良好;理论分析见 Bengio et al. (2013) "Estimating or Propagating Gradients Through Stochastic Neurons"。

6.4 VQ-VAE 完整实现

class VectorQuantizer(nn.Module):
    """ Codebook + 最近邻量化 + STE
        embedding_dim = D, num_embeddings = K
        commitment_cost β 通常 = 0.25 """

    def __init__(self, num_embeddings: int, embedding_dim: int,
                 commitment_cost: float = 0.25):
        super().__init__()
        self.K, self.D = num_embeddings, embedding_dim
        self.beta = commitment_cost
        # Codebook 用 small uniform init
        self.codebook = nn.Embedding(self.K, self.D)
        self.codebook.weight.data.uniform_(-1.0 / self.K, 1.0 / self.K)

    def forward(self, z_e: torch.Tensor):
        """ z_e: [B, D, H, W]  ->  z_q: [B, D, H, W], indices: [B, H, W],
            loss = codebook_loss + β·commitment_loss """
        # 1) Reshape: [B, D, H, W] -> [BHW, D]
        B, D, H, W = z_e.shape
        z_e_flat = z_e.permute(0, 2, 3, 1).contiguous().view(-1, D)   # [BHW, D]

        # 2) 计算 L2 距离  ‖z_e - e_k‖² = ‖z_e‖² + ‖e_k‖² - 2 z_e · e_k
        e = self.codebook.weight                                       # [K, D]
        dist = (z_e_flat.pow(2).sum(1, keepdim=True)
                + e.pow(2).sum(1)
                - 2 * z_e_flat @ e.t())                                # [BHW, K]

        # 3) 最近邻 index
        indices = dist.argmin(dim=1)                                   # [BHW]
        z_q_flat = self.codebook(indices)                              # [BHW, D]

        # 4) 损失(注意 sg)
        codebook_loss = F.mse_loss(z_q_flat, z_e_flat.detach())
        commitment_loss = F.mse_loss(z_e_flat, z_q_flat.detach())
        vq_loss = codebook_loss + self.beta * commitment_loss

        # 5) STE:前向 z_q,反向 dz_q/dz_e = I
        z_q_flat = z_e_flat + (z_q_flat - z_e_flat).detach()

        # 6) Reshape 回 [B, D, H, W]
        z_q = z_q_flat.view(B, H, W, D).permute(0, 3, 1, 2).contiguous()
        indices = indices.view(B, H, W)

        # 7) (可选) perplexity: codebook 使用度的衡量
        one_hot = F.one_hot(indices.view(-1), self.K).float()
        avg_probs = one_hot.mean(dim=0)
        perplexity = torch.exp(-(avg_probs * torch.log(avg_probs + 1e-10)).sum())

        return z_q, indices, vq_loss, perplexity


class VQVAE(nn.Module):
    def __init__(self, channels=3, hidden=128, num_embeddings=512, embedding_dim=64,
                 commitment_cost=0.25):
        super().__init__()
        # Encoder: 64×64×3 -> 16×16×D  (downsample 4×)
        self.encoder = nn.Sequential(
            nn.Conv2d(channels, hidden, 4, 2, 1), nn.ReLU(),
            nn.Conv2d(hidden, hidden, 4, 2, 1), nn.ReLU(),
            nn.Conv2d(hidden, embedding_dim, 3, 1, 1),
        )
        self.quantizer = VectorQuantizer(num_embeddings, embedding_dim, commitment_cost)
        # Decoder: 16×16×D -> 64×64×3
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(embedding_dim, hidden, 3, 1, 1), nn.ReLU(),
            nn.ConvTranspose2d(hidden, hidden, 4, 2, 1), nn.ReLU(),
            nn.ConvTranspose2d(hidden, channels, 4, 2, 1),
        )

    def forward(self, x):
        z_e = self.encoder(x)
        z_q, indices, vq_loss, perplexity = self.quantizer(z_e)
        x_hat = self.decoder(z_q)
        return x_hat, vq_loss, perplexity, indices

def vqvae_loss(x, x_hat, vq_loss):
    recon = F.mse_loss(x_hat, x)
    return recon + vq_loss, recon

6.5 EMA Codebook(生产标准做法)

直接用 codebook loss 更新 $e$ 收敛慢,dead codes(从来不被选中的 codebook 向量)多。生产实现用 EMA (Exponential Moving Average) 更新

对每个 codebook 向量 $e_k$,维护:

更新:

$$e_k^{(t)} = \frac{m_k^{(t)}}{N_k^{(t)} + \varepsilon} \quad \text{(Laplace smoothing)}$$

class VectorQuantizerEMA(nn.Module):
    """ EMA codebook update (van den Oord 2017 后续 / VQ-VAE-2 标准做法)
        - codebook 不靠 gradient,靠 running EMA 更新
        - 留 commitment loss 用于 encoder
        - decay γ 一般 0.99, ε 一般 1e-5 """

    def __init__(self, num_embeddings: int, embedding_dim: int,
                 commitment_cost: float = 0.25, decay: float = 0.99, eps: float = 1e-5):
        super().__init__()
        self.K, self.D = num_embeddings, embedding_dim
        self.beta, self.decay, self.eps = commitment_cost, decay, eps

        embed = torch.randn(num_embeddings, embedding_dim) * 0.01
        self.register_buffer("codebook", embed)
        self.register_buffer("cluster_size", torch.zeros(num_embeddings))
        self.register_buffer("embed_avg", embed.clone())

    def forward(self, z_e):
        B, D, H, W = z_e.shape
        z_e_flat = z_e.permute(0, 2, 3, 1).contiguous().view(-1, D)

        dist = (z_e_flat.pow(2).sum(1, keepdim=True)
                + self.codebook.pow(2).sum(1)
                - 2 * z_e_flat @ self.codebook.t())
        indices = dist.argmin(dim=1)                                # [BHW]
        z_q_flat = F.embedding(indices, self.codebook)              # [BHW, D]

        if self.training:
            # EMA 更新
            with torch.no_grad():
                one_hot = F.one_hot(indices, self.K).float()        # [BHW, K]
                cluster_size_new = one_hot.sum(dim=0)               # [K]
                embed_sum = one_hot.t() @ z_e_flat                  # [K, D]

                self.cluster_size.mul_(self.decay).add_(cluster_size_new, alpha=1 - self.decay)
                self.embed_avg.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)

                # Laplace smoothing 避免被零除
                n = self.cluster_size.sum()
                cluster_size = (self.cluster_size + self.eps) / (n + self.K * self.eps) * n
                self.codebook.copy_(self.embed_avg / cluster_size.unsqueeze(1))

        commitment_loss = F.mse_loss(z_e_flat, z_q_flat.detach())
        vq_loss = self.beta * commitment_loss                       # EMA 下不要 codebook loss

        z_q_flat = z_e_flat + (z_q_flat - z_e_flat).detach()        # STE
        z_q = z_q_flat.view(B, H, W, D).permute(0, 3, 1, 2).contiguous()
        return z_q, indices.view(B, H, W), vq_loss
EMA 的两个好处

6.6 VQ-VAE-2(Razavi, Vinyals, van den Oord, NeurIPS 2019)

VQ-VAE 的层次扩展:

在 ImageNet 256×256 上首次让 VQ-based 方法接近 BigGAN 的样本质量,是 VQ-GAN 的直接前作。

§7 VQ-GAN:Adversarial + Perceptual + Transformer Prior

7.1 核心思想(Esser, Rombach, Ommer, CVPR 2021, "Taming Transformers")

VQ-VAE 在 ImageNet 上重建的纹理细节模糊。VQ-GAN 改造为:

组件VQ-VAEVQ-GAN
Recon lossL2 / L1 pixelL1 pixel + LPIPS perceptual + PatchGAN adversarial
PriorPixelCNNTransformer (decoder-only) over code tokens
Codebook512-1024 codes1024-16384 codes
Compression4×-8×8×-32×(更高压缩比,靠 perceptual + adversarial 救质量)
应用unconditional / class-cond 生成high-res image synthesis, "Taming Transformers"

7.2 Loss 公式

$$\mathcal{L}_\text{VQ-GAN}^\text{stage1} = \mathcal{L}_\text{rec} + \mathcal{L}_\text{VQ} + \lambda \cdot \mathcal{L}_\text{GAN}$$

其中:

$$ \begin{aligned} \mathcal{L}_\text{rec} &= \|x - \hat{x}\|_1 + \mathcal{L}_\text{LPIPS}(x, \hat{x}) \\ \mathcal{L}_\text{VQ} &= \|\text{sg}[z_e] - e\|^2 + \beta \|z_e - \text{sg}[e]\|^2 \end{aligned} $$

Generator/Tokenizer 阶段的 GAN 项(只对 generator 的输出反传,discriminator 在另一阶段单独更新):

$$\mathcal{L}_\text{GAN}^{(G)} = -\mathbb{E}_{\hat{x}}[\log D(\hat{x})]\quad\text{(non-saturating)}\quad\text{或}\quad \mathcal{L}_\text{GAN}^{(G)} = -\mathbb{E}_{\hat{x}}[D(\hat{x})]\quad\text{(hinge)}$$

Discriminator 自身的 minimax 项(独立 step 更新 $D$):

$$\mathcal{L}_\text{GAN}^{(D)} = -\mathbb{E}_x[\min(0, -1+D(x))] - \mathbb{E}_{\hat{x}}[\min(0, -1-D(\hat{x}))]\quad\text{(hinge)}$$

自适应 $\lambda$(论文创新点,用最后一层梯度范数比自动平衡,避免人工调参):

$$\lambda = \frac{\lVert\nabla_{G_L} \mathcal{L}_\text{rec}\rVert}{\lVert\nabla_{G_L} \mathcal{L}_\text{GAN}^{(G)}\rVert + \delta}$$

$G_L$ 是 decoder 最后一层;$\lVert\cdot\rVert$ 是 Frobenius 范数。总 generator loss:

$$\mathcal{L}_G = \mathcal{L}_\text{rec} + \mathcal{L}_\text{VQ} + \lambda \cdot \mathcal{L}_\text{GAN}^{(G)}$$

7.3 Stage 2:Transformer Prior

Stage 1 训好 VQ-GAN,把图像转成 token grid $\mathbf{c} = (c_1, \ldots, c_{HW})$(行扫描展平)。Stage 2 在 token sequence 上训 decoder-only Transformer,标准 AR:

$$p(\mathbf{c}) = \prod_{i=1}^{HW} p(c_i | c_{

采样:AR sample tokens → VQ-GAN decoder → image。这是把"图像生成"翻译成"语言模型"的标准范式,DALL·E / Parti / Muse 都是这一思想的进化。

VQ-GAN 在 LDM 中的角色

Stable Diffusion 的 VAE tokenizer 实际是 KL-regularized VQ-GAN 的连续 latent 变体(去掉 quantization,留 KL + perceptual + adversarial),output 是 continuous latent map(4 通道,下采样 8×)。diffusion 在这个 latent 上跑,最后 decoder 还原。可以理解为 "VQ-GAN encoder/decoder + continuous latent + KL"。

7.4 PatchGAN Discriminator(生产架构)

VQ-GAN 用 PatchGAN(Isola et al. CVPR 2017 "pix2pix"):

class PatchDiscriminator(nn.Module):
    """ PatchGAN: 70×70 receptive field 的 stack of strided convs
        输出 [B, 1, H/8, W/8] 的 patch-level real/fake 判别 """
    def __init__(self, in_ch=3, hidden=64, n_layers=3):
        super().__init__()
        layers = [nn.Conv2d(in_ch, hidden, 4, 2, 1), nn.LeakyReLU(0.2, True)]
        ch = hidden
        for i in range(1, n_layers):
            ch_next = min(hidden * (2 ** i), 512)
            layers += [
                nn.Conv2d(ch, ch_next, 4, 2, 1),
                nn.BatchNorm2d(ch_next),
                nn.LeakyReLU(0.2, True),
            ]
            ch = ch_next
        layers += [
            nn.Conv2d(ch, ch * 2, 4, 1, 1),
            nn.BatchNorm2d(ch * 2),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ch * 2, 1, 4, 1, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x): return self.net(x)


def hinge_d_loss(real_logits, fake_logits):
    real = F.relu(1.0 - real_logits).mean()
    fake = F.relu(1.0 + fake_logits).mean()
    return 0.5 * (real + fake)

def hinge_g_loss(fake_logits):
    return -fake_logits.mean()
GAN 训练 trick 清单

7.5 LPIPS(Perceptual Loss)

$$\mathcal{L}_\text{LPIPS}(x, \hat{x}) = \sum_l w_l \cdot \|\phi_l(x) - \phi_l(\hat{x})\|^2$$

$\phi_l$ 是预训练 VGG / AlexNet 的第 $l$ 层 feature map,$w_l$ 是学到的 channel-wise weight(Zhang et al. CVPR 2018)。比 pixel-MSE 更贴近人类感知,是 VQ-GAN / SD / 大部分 image GAN / diffusion 训练的标配。

§8 离散 VAE 与 Gumbel-Softmax

8.1 dVAE(DALL·E 1,Ramesh et al. ICML 2021)

DALL·E 用 dVAE (discrete VAE) 作为 image tokenizer:

8.2 Gumbel-Softmax 推导

目标:让 categorical sampling 可导。

Gumbel-Max trick:对 logits $\pi = (\pi_1, \ldots, \pi_K)$ 加独立 Gumbel(0,1) 噪声 $g_k = -\log(-\log u_k), u_k \sim \mathcal{U}(0, 1)$,则:

$$\arg\max_k \{\log \pi_k + g_k\}$$

服从 categorical(softmax($\pi$))。证明用 Gumbel 分布的 CDF 性质:$P(\max_k X_k = X_j) = e^{\pi_j} / \sum_k e^{\pi_k}$。

Gumbel-softmax (Jang et al., ICLR 2017; Maddison et al., ICLR 2017 同期):把不可导的 argmax 替换成 softmax:

$$\boxed{\;y_k = \frac{\exp((\log \pi_k + g_k) / \tau)}{\sum_j \exp((\log \pi_j + g_j) / \tau)}\;}$$

Straight-Through Gumbel-Softmax:前向用 argmax(离散),反向用 softmax 梯度——类似 VQ-VAE 的 STE 思路。

def gumbel_softmax_sample(logits, tau=1.0, hard=False, dim=-1):
    """ 输入 logits = log π   输出 soft / hard one-hot """
    # 1) 加 Gumbel 噪声
    g = -torch.log(-torch.log(torch.rand_like(logits) + 1e-20) + 1e-20)
    y_soft = F.softmax((logits + g) / tau, dim=dim)
    if not hard:
        return y_soft
    # ST: 前向 hard, 反向 soft 梯度
    index = y_soft.argmax(dim=dim, keepdim=True)
    y_hard = torch.zeros_like(y_soft).scatter_(dim, index, 1.0)
    y = y_hard - y_soft.detach() + y_soft   # straight-through
    return y
VQ-VAE vs Gumbel-softmax / dVAE

都是离散 latent 模型,区别:

8.3 MaskGIT(Chang et al., CVPR 2022)

把 AR Transformer prior 换成 BERT-style masked Transformer

后继:MUSE (Chang et al., 2023) 把同思路推到 text-to-image,是 Google 主线生成模型之一。

§9 FSQ:Finite Scalar Quantization(重点)

9.1 动机(Mentzer, Minnen, Agustsson, Toderici, ICLR 2024)

VQ-VAE 有顽疾:

  1. Codebook collapse / underuse:大部分 codes 从未被用(dead codes),perplexity 远低于理论 $K$
  2. STE bias:梯度估计有偏,训练不稳
  3. 复杂的 loss balancing:commitment 权重、EMA decay、dead code revival 都要调
  4. Codebook 大小有效上限:实际可用 ~$10^3$-$10^4$,再大用不起来

FSQ 用一招把这一堆问题绕过去逐维标量量化(scalar quantization, 不是 vector quantization)

9.2 核心公式(必推)

让 encoder 输出 $z \in \mathbb{R}^d$。对每一维独立做 scalar quantization(FSQ 论文 Eq. 4):

$$z_i \longrightarrow z'_i = \tfrac{L_i-1}{2}\tanh(z_i) - s_i \longrightarrow \hat{z}_i = \text{round}(z'_i) + s_i$$

其中 $s_i = 0$ 若 $L_i$ 奇,$s_i = 0.5$ 若 $L_i$ 偶。这样:

无论奇偶,每维都得到 $L_i$ 个 level;把所有维度的 level 数乘起来:

$$\boxed{\;K_\text{implicit} = \prod_{i=1}^{d} L_i\;}$$

隐式 codebook 大小例子

9.3 为什么 FSQ 不会 codebook collapse?

关键洞察

在 VQ-VAE 中,codebook collapse 的根源是:codebook 是自由参数,optimizer 让大部分 $e_k$ 都 drift 到无用区域,只有少数 $e_k$ 被反复用。FSQ 的"codebook"不是参数——它是数轴上的固定 grid 点($\{-L/2, \ldots, L/2\}$)。

实证:FSQ 的 codebook usage 几乎是 100%(与 VQ-VAE 50-70% 形成对比),ImageNet / Cosmos / OpenMagViT2 复现均有此结论。

9.4 为什么 FSQ 不需要"显式 STE 包装"以及 loss 也极简

$$\mathcal{L}_\text{FSQ} = \|x - \hat{x}\|^2 \quad \text{(plus optional perceptual + adversarial)}$$

VQ-VAE vs FSQ 简化对比

FSQ 是"用空间维度换 codebook 大小":VQ-VAE 用 1 维($D$ 个连续值 + 1 个离散选择 from $K$),FSQ 用 $d$ 个独立离散维度,每维 $L_i$ 个 level,最终离散 entropy 反而更大、collapse 几乎不可能。代价:embedding 表达力略弱(每维独立,不共享 representation),不过 reconstruction 端通过 decoder 已补回。

9.5 FSQ 实现(10 行)

class FSQ(nn.Module):
    """ Finite Scalar Quantization (Mentzer et al., ICLR 2024)
        levels: tuple, 每维量化 level 数(必须为奇数或偶数都行,奇数保证含 0)
        eps:    bounding 安全裕度,避免 tanh 后 round 跳出 grid """

    def __init__(self, levels=(8, 5, 5, 5)):
        super().__init__()
        levels_t = torch.tensor(levels, dtype=torch.float32)
        self.levels = levels_t
        self.d = len(levels)
        self.K = int(torch.prod(levels_t).item())            # 隐式 codebook size = ∏ L_i
        # FSQ paper Eq. 4: half = (L-1)/2; shift = 0.5 if L 偶 else 0
        half = (levels_t - 1) / 2                            # [d]
        shift = ((levels_t % 2) == 0).float() * 0.5          # [d]
        self.register_buffer("half_l", half)
        self.register_buffer("shift", shift)
        # mixed-radix basis for token id encoding
        cumprod = torch.tensor([1.0] + list(torch.cumprod(levels_t[:-1], dim=0)),
                               dtype=torch.float32)
        self.register_buffer("basis", cumprod)               # [d]

    @staticmethod
    def round_ste(z):
        """不可导 round 的 STE: 前向 round, 反向 identity"""
        return z + (z.round() - z).detach()

    def forward(self, z):
        """ z: [B, d, ...]  ->  z_hat: [B, d, ...] (量化值), codes: [B, ...] (∈ 0..K-1) """
        view = (1, -1) + (1,) * (z.dim() - 2)
        half = self.half_l.view(*view).to(z.device)
        shift = self.shift.view(*view).to(z.device)
        # 1) Bound: tanh(z) * half - shift  → z'∈[-half-shift, half-shift]
        z_bounded = torch.tanh(z) * half - shift
        # 2) Round (STE) + 加回 shift → 奇 L 得 {-half,…,half}(整数),偶 L 得 {-half,…,half}(含半整数)
        z_hat = self.round_ste(z_bounded) + shift
        # 3) Token ID (mixed-radix):把 d 维 ∈ {-half_i,…,half_i} 映成 0..L_i-1 再编成单一 index
        shifted = (z_hat + half).round().long()              # ∈ 0..L_i-1(round 兜底浮点误差)
        basis = self.basis.view(*view).to(z.device).long()
        codes = (shifted * basis).sum(dim=1)                 # [B, ...]
        return z_hat, codes


# 使用示例:
# fsq = FSQ(levels=(8, 5, 5, 5))    # K = 8·5·5·5 = 1000
# z = encoder(x)                     # [B, 4, H, W]
# z_hat, tokens = fsq(z)             # z_hat: [B, 4, H, W], tokens: [B, H, W] ∈ 0..999
# x_hat = decoder(z_hat)
# loss = F.mse_loss(x_hat, x)        # 就这一项!
FSQ 的 level 选择经验

9.6 LFQ:Lookup-Free Quantization(MAGVIT-v2,Yu et al., ICLR 2024)

FSQ 的二值特例:

$$\text{LFQ}(z) = \text{sign}(z) \in \{-1, +1\}^d$$

每维只有 2 个 level,隐式 codebook = $2^d$:$d=18$ 时 codebook = $2^{18} = 262{144}$(与 FSQ 等价的量级)。

特点:

class LFQ(nn.Module):
    """ Lookup-Free Quantization (MAGVIT-v2)
        每维独立 sign quantize, 隐式 codebook = 2^d """
    def __init__(self, dim: int, entropy_weight: float = 0.1):
        super().__init__()
        self.d = dim
        self.K = 2 ** dim
        self.entropy_weight = entropy_weight

    def forward(self, z):
        # z: [B, d, ...]
        q = torch.sign(z)
        # 防止 sign(0) = 0
        q = torch.where(q == 0, torch.ones_like(q), q)
        # STE
        z_hat = z + (q - z).detach()

        # Entropy regularization(防止某维总是同符号)
        # p_+ = sigmoid(z), p_- = 1 - p_+
        if self.training:
            p = torch.sigmoid(z)
            per_dim_entropy = -(p * torch.log(p + 1e-9)
                                + (1 - p) * torch.log(1 - p + 1e-9))
            entropy_loss = -per_dim_entropy.mean()    # maximize entropy → minimize -H
        else:
            entropy_loss = z.new_tensor(0.0)

        return z_hat, self.entropy_weight * entropy_loss

9.7 Cosmos / OpenMagViT2 / 现代 video tokenizer

Tokenizer出处量化方式用在哪
MAGVIT-v2Google 2024 (ICLR)LFQtext-to-video 早期 demo
OpenMagViT2开源复现 2024LFQ公开 video tokenizer baseline
Cosmos TokenizerNVIDIA 2024FSQ + 视频时空压缩NVIDIA Cosmos world model
VideoPoet tokenizerGoogle 2024LFQ-styletext-to-video

工程要点:

§10 复杂度与资源对比

模型latent 类型训练参数 (encoder+decoder)主要 lossCodebook collapseSTE 依赖
VAEcontinuous Gaussian$\sim$10-100Mrecon + KL (closed form)N/A否 (reparameterization)
$\beta$-VAEcontinuous Gaussian同 VAErecon + $\beta$·KLN/A
NVAEhierarchical continuous80M-200Mrecon + multi-layer KLN/A
VQ-VAEdiscrete via codebook50-200Mrecon + codebook + $\beta$·commitment常发生
VQ-VAE-2hierarchical discrete100-500M同 VQ-VAE × 2 层同上
VQ-GANdiscrete + adversarial50-300M (+ D)recon + LPIPS + GAN + codebook + commitment同上
dVAEcategorical (logits)50-200Mrecon + KL to uniform较少(categorical 分布学习)否(Gumbel-softmax 反传 soft)
FSQscalar quantize per dim30-150Mrecon (+ perceptual)几乎不发生是(但极简)
LFQbinary scalar quantize30-150Mrecon (+ entropy reg)几乎不发生
生态位定位

§11 与相关方法对比 / 在生态中的位置

11.1 VAE vs GAN vs Diffusion vs Flow / FM

模型似然训练稳定性多样性样本质量inference 速度
VAE有(ELBO)✅ 稳✅ 好⚠️ 模糊✅ 1-step
GAN❌ 难❌ mode collapse✅ 锐利✅ 1-step
Diffusion近似(VLB)✅ 稳✅ 好✅ SOTA❌ 多 NFE
Flow / FM有(ODE)✅ 稳✅ 好✅ 强⚠️ 数 NFE

11.2 Tokenizer 系列在大模型中的角色


Tokenizer Stage 1                Generative Stage 2 (prior)
────────────────                 ──────────────────────────
VQ-GAN  →   离散 token grid  →   Transformer AR  (Parti, DALL·E 1, Cogview)
VQ-GAN  →   离散 token grid  →   Masked Transformer (MaskGIT, Muse)
FSQ    →   离散 token grid  →   Transformer AR  (Cosmos, OpenMagViT2)
LFQ    →   binary token grid →   AR / bit predictor (MAGVIT-v2, VideoPoet)
KL-VAE →   连续 latent map  →   Diffusion / Flow Matching (LDM, SD, SD3, FLUX)

11.3 Reconstruction-Perception Tradeoff(高级题)

Blau & Michaeli (ICML 2018) 证明:重建(MSE / PSNR)和感知(perceptual / FID)之间存在严格的 Pareto 边界。VQ-GAN / SD VAE 引入 LPIPS + adversarial 是为了交换更高 perceptual 质量而接受略差的 PSNR

PSNR 不等于"看起来好"

VQ-GAN 论文里 PSNR 不一定优于 VQ-VAE,但 perceptual (LPIPS / FID) 远好。面试常被反问"为什么 SOTA tokenizer 的 PSNR 反而下降"——这是 distortion-perception tradeoff。

§12 25 高频面试题

codex (gpt-5.5 xhigh) 作为顶级 lab 面试官视角列的,按难度分 3 档。每题点开看答案要点 + 易踩坑。

L1必会题(任何 ML 工程岗都会问)

Q1.VAE 的 ELBO 是什么?写出公式。
  • $\log p(x) \geq \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] - D_\text{KL}(q_\phi(z|x)\,\|\,p(z))$
  • 第一项:reconstruction 期望 log-likelihood
  • 第二项:KL 把后验拉向先验 $\mathcal{N}(0, I)$
  • ELBO 与 $\log p(x)$ 的 gap = $D_\text{KL}(q_\phi(z|x)\,\|\,p_\theta(z|x))$

写成 $\log p(x|z) - D_\text{KL}(...)$(漏了期望符号);只说"重建+正则化"不写公式。

Q2.Reparameterization trick 解决什么问题?
  • 直接采样 $z \sim q_\phi(z|x)$ 不可导,梯度无法回到 encoder
  • 改写成 $z = \mu + \sigma \odot \epsilon, \epsilon \sim \mathcal{N}(0, I)$,把随机性挪到独立噪声
  • $z$ 变成 $\phi$ 的确定性函数,可正常反向传播
  • 不止 Gaussian;Gumbel-softmax 也是同思路

说"为了加速训练"——其实是为了可导

Q3.$\beta$-VAE 的 $\beta$ 控制什么?
  • $\beta > 1$:更强 KL 正则,鼓励 disentangled latent
  • $\beta < 1$:放宽 KL,重建更精但 prior 拟合差
  • $\beta = 1$:标准 VAE
  • 但 Locatello 2019 证明:纯无监督 disentanglement 不可行,需要归纳偏置

只说 "$\beta$ 越大越 disentangle"——错,取决于数据 + 架构

Q4.什么是 posterior collapse?
  • 训练中 KL → 0,即 $q_\phi(z|x) \approx p(z)$(与 $x$ 无关)
  • decoder 完全忽略 $z$,VAE 退化为 unconditional 模型
  • 常见于强 AR decoder(如 PixelCNN / LSTM)+ 简单数据

只说"latent 没用",不说 KL → 0 这个量化指标。

Q5.VQ-VAE 的 codebook 是什么?怎么用?
  • $\{e_1, \ldots, e_K\} \subset \mathbb{R}^D$ 一组可学的"码本"向量
  • encoder 输出连续 $z_e(x)$;用最近邻 $k = \arg\min \|z_e - e_k\|^2$ 替换成 $z_q = e_k$
  • decoder 解 $z_q$ 回到像素
  • 训练靠 codebook loss(拉 $e$ 向 $z_e$)+ commitment loss(拉 $z_e$ 向 $e$)

说 codebook 是固定的 / 预训练的——错,端到端学习

Q6.VQ-VAE 三项 loss 各是什么?
  • Reconstruction:$\|x - \hat{x}\|^2$(pixel-level)
  • Codebook loss:$\|\text{sg}[z_e] - e\|^2$(只更新 $e$)
  • Commitment loss:$\beta \|z_e - \text{sg}[e]\|^2$(只更新 encoder, $\beta = 0.25$)
  • sg = stop_gradient,避免两侧梯度互相耦合振荡

把 codebook 和 commitment loss 混作一项;忘了 sg 的方向。

Q7.什么是 Straight-Through Estimator (STE)?
  • 解决不可导操作(如 argmax / round)的反传问题
  • 前向用离散输出,反向把梯度直接传给"上一层连续输入"
  • 等价 surrogate:假设量化层是恒等映射
  • PyTorch 三行:z_q = z_e + (z_q_quantized - z_e).detach()

只说"反向用 identity",不说前向还是真量化。

Q8.VQ-GAN 比 VQ-VAE 多了什么?
  • LPIPS perceptual loss(替换 / 补充 L2)
  • PatchGAN adversarial loss + 自适应 $\lambda$ 权重
  • Transformer prior(替换 PixelCNN)
  • 更大 codebook(1k → 16k)+ 更高压缩比(8× → 16-32×)
  • 是 "Taming Transformers"(Esser et al. CVPR 2021)

只说 "GAN" 不说 perceptual;或忘 Transformer prior。

Q9.FSQ 是什么?为什么不会 codebook collapse?
  • 每维独立 scalar quantize 到 $L$ 个固定 level($\tanh$ → 缩放 → round)
  • 隐式 codebook = $\prod L_i$(如 $L=8, d=6 \Rightarrow 8^6$)
  • codebook 不是可学参数 → 没东西"塌"到无用区
  • encoder 通过 reconstruction 压力自然探索整个 grid

把 FSQ 当作 VQ-VAE 的 codebook 优化技巧——错,FSQ 没有显式 codebook 参数

Q10.KL($\mathcal{N}(\mu, \sigma^2 I) \,\|\, \mathcal{N}(0, I)$) 写出闭式。
  • $D_\text{KL} = \tfrac{1}{2}\sum_i (\mu_i^2 + \sigma_i^2 - \log \sigma_i^2 - 1)$
  • 对角协方差才有这个简单形式
  • 实现里 encoder 输出 $\log \sigma^2$(logvar)更稳
  • kl = 0.5 * (mu**2 + logvar.exp() - logvar - 1).sum()

把 $\log \sigma^2$ 与 $\log \sigma$ 搞混(差一个 2);或忘 "-1" 项。

L2进阶题(research-oriented 岗位)

Q11.IWAE 比 ELBO 更紧吗?怎么用?
  • $K$ 个 importance 样本:$\mathcal{L}_K^\text{IWAE} = \mathbb{E}_{z_1,\ldots,z_K}[\log \tfrac{1}{K} \sum_k \tfrac{p(x, z_k)}{q(z_k|x)}]$
  • $\mathcal{L}_1 = $ ELBO(特例)
  • $\mathcal{L}_K \to \log p(x)$ 当 $K \to \infty$(Burda et al. ICLR 2016)
  • 但 encoder 学到的 posterior 不再追求逼近真后验

说"$K = 1$ 也比 ELBO 强"——错,特例就是 ELBO。

Q12.如何缓解 posterior collapse?至少列 4 种。
  • KL annealing:$\beta(t) = \min(1, t/T)$ 线性增长(Bowman 2016)
  • Free bits:每维 KL 下限 $\lambda$ nats(Kingma 2016)
  • Weakened decoder:限制 decoder 表达力(Chen 2017)
  • VQ-VAE / 离散 latent:结构上无 KL 项,规避
  • 其他:辅助 loss、word dropout、NF posterior、skip 连接

只答"KL annealing" 一种;或把 $\beta$-VAE 当成缓解 collapse 的工具(其实 $\beta > 1$ 更易 collapse)。

Q13.推 KL($\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0, 1))$ 闭式。
  • $\log \tfrac{q}{p} = -\tfrac{1}{2}\log \sigma^2 - \tfrac{(z-\mu)^2}{2\sigma^2} + \tfrac{z^2}{2}$
  • 取 $\mathbb{E}_q$(用 $\mathbb{E}[z] = \mu$, $\mathbb{E}[z^2] = \mu^2 + \sigma^2$, $\mathbb{E}[(z-\mu)^2] = \sigma^2$)
  • 结果:$\tfrac{1}{2}(\mu^2 + \sigma^2 - \log \sigma^2 - 1)$
  • 多维独立时各维相加

跳过推导直接背公式;忘记 $\mathbb{E}_q[z^2]$ 的展开。

Q14.VQ-VAE 里 sg 在 codebook vs commitment 项的作用。
  • Codebook loss $\|\text{sg}[z_e] - e\|^2$:梯度只更新 $e$(codebook 向量),不更新 encoder
  • Commitment loss $\|z_e - \text{sg}[e]\|^2$:梯度只更新 encoder,不更新 $e$
  • 两个 sg 把双向 alignment 解耦,避免互相牵制

如果两个都不 sg,等价于普通 MSE,learning rate 实际加倍 + 两侧互相拉扯

误以为 sg 是为了"防止 codebook 更新太快"——其实是为了梯度解耦

Q15.STE 的梯度等价于哪种 surrogate?
  • STE = "前向真量化、反向用 identity surrogate"
  • 等价于把 $z_q = \text{quantize}(z_e)$ 的可导 surrogate 设为 $z_q^\text{surr} = z_e$
  • 即假设 quantization 是恒等映射
  • 是有偏估计 (biased gradient),但低方差、实践 work
  • 严格分析:Bengio et al. (2013) "Estimating or Propagating Gradients Through Stochastic Neurons"

说 STE 是无偏估计——错,是 biased。

Q16.EMA codebook 更新公式是什么?为何用?
  • $N_k^{(t)} = \gamma N_k^{(t-1)} + (1-\gamma) n_k^{(t)}$,cluster 计数
  • $m_k^{(t)} = \gamma m_k^{(t-1)} + (1-\gamma) \sum_{i \to k} z_{e,i}$,cluster 向量和
  • $e_k^{(t)} = m_k^{(t)} / (N_k^{(t)} + \varepsilon)$
  • 优点:codebook 更新更稳;可周期性 revive dead codes
  • $\gamma \approx 0.99, \varepsilon \approx 10^{-5}$ 是常见值

把 EMA 当作 momentum + Adam 的 SGD 变体——本质是 k-means 在 mini-batch 下的 EMA 估计

Q17.PatchGAN 是什么?为什么在 VQ-GAN 用它?
  • 不输出 single scalar,而输出 N×N patch-level real/fake map
  • 每个 patch 是 70×70 receptive field(用 stack of strided convs)
  • 适合 capture 局部纹理真假,对全局结构压力小
  • 让 generator 更专注纹理细节,而不是全图判别(VQ-GAN 全局靠 recon + LPIPS)
  • 出自 Isola et al. CVPR 2017 "pix2pix"

误以为 PatchGAN 是 attention-based;或说它只在 image-to-image translation 用。

Q18.LPIPS 是什么?相比 MSE 优势?
  • 用预训练 VGG / AlexNet 中间层 feature 计算距离:$\sum_l w_l \|\phi_l(x) - \phi_l(\hat{x})\|^2$
  • $w_l$ 是学到的 channel-wise weight(Zhang et al. CVPR 2018)
  • 比 pixel-MSE 更贴近人类感知判断
  • 是 VQ-GAN / SD / 大部分 image GAN / diffusion 训练的标配
  • 配合 distortion-perception tradeoff 用(Blau & Michaeli ICML 2018)

只说"用 VGG feature",不说 learned channel weights / human study 拟合。

Q19.Gumbel-Max trick 怎么用来近似 categorical 采样?
  • 对 logits $\pi$ 加独立 Gumbel 噪声 $g_k = -\log(-\log u_k)$
  • $\arg\max_k(\log \pi_k + g_k)$ 服从 categorical(softmax($\pi$))
  • 把 argmax 换成 softmax 即 Gumbel-softmax,可导
  • 温度 $\tau \to 0$ 时接近 one-hot;ST 版本前向 argmax / 反向 softmax 梯度
  • 用在 dVAE / DALL·E 1

把 Gumbel(0,1) 写成正态噪声;忘记 $\arg\max$ 的概率正比 softmax。

Q20.MaskGIT 比 AR 快在哪?为什么质量也不差?
  • 训练:BERT-style mask-and-predict(不是 next-token AR)
  • 采样:每轮并行 unmask 一批 token(按 confidence ranking),8-12 轮收敛
  • 比 AR 快 ~10x,因为每轮 parallel forward
  • 质量不差因为 (1) iterative refinement 等价多次 forward;(2) bidirectional context
  • ImageNet 256×256 上质量与 AR 相当;MUSE 把同思路推到 text-to-image

误以为 MaskGIT 是 diffusion 的离散版——其实是 BERT MLM 的生成扩展。

L3高级变体(顶级 lab / generative model 方向)

Q21.推导 FSQ 隐式 codebook 大小,以及为什么不需要 STE 包装额外 loss?
  • 每维独立量化到 $L_i$ 个 level:$z_i \to \tanh(z_i) \cdot (L_i-1)/2 \to \text{round}$
  • $d$ 维独立组合:implicit codebook = $\prod_i L_i$
  • 例:$L = (8, 5, 5, 5), d = 4 \Rightarrow K = 1000$
  • 没有 codebook 参数 → 没 codebook collapse;没显式 codebook loss、commitment loss、EMA、dead-code revival
  • 仍需 STE 处理 round 不可导:z_hat = z + (z.round() - z).detach() 一行解决
  • Loss 只剩 $\|x - \hat{x}\|^2$(+ optional perceptual + adversarial)

把 FSQ 与 LFQ 混为一谈(LFQ 是 $L = 2$ binary 特例);或以为 FSQ 取消了 STE(其实 round 仍需 STE,只是不需要额外 codebook / commitment loss)。

Q22.VQ-VAE 的 codebook collapse 怎么诊断 + 缓解?
  • 诊断:测 perplexity = $\exp(-\sum_k p_k \log p_k)$,$p_k$ 是 codebook 第 $k$ 个 code 的使用频率

    • 健康 perplexity 应接近 $K$(uniform 使用上限)
    • 实际常见 perplexity / K < 50%,部分 codes 几乎不用
  • 缓解

    • EMA codebook update(基本款)
    • Dead code revival:每隔 $T$ 步,把 $N_k < \tau$ 的 $e_k$ 重置到当前 batch 某个随机 $z_e$
    • k-means init:训练前用 first batch 的 $z_e$ 做 k-means 初始化 codebook
    • Code dropout:训练中随机 drop 一部分 codebook,强制后续不依赖单一 code
    • 换 FSQ / LFQ:结构上避免(最简单的"缓解")

只答"用大 codebook"——错,更大 codebook 反而更容易 collapse。

Q23.VAE / VQ-VAE / Diffusion / FM 各自在 LDM 系列里的角色?
  • VAE (KL-regularized VAE / VQ-GAN-without-quant):image $\to$ continuous latent map(Stable Diffusion 用 8× downsample, 4 通道 latent)
  • VQ-VAE / VQ-GAN:image $\to$ discrete token grid,用于 AR / MaskGIT prior(Parti / DALL·E / Muse / Cosmos)
  • Diffusion / FM prior:在 VAE latent 空间上跑 reverse process(LDM / SD / SDXL / SD3 / FLUX)
  • AR / Masked Transformer prior:在 VQ token 上跑(Parti / Muse / VideoPoet)
  • 关键洞察:tokenizer 与 prior 是两阶段,tokenizer 训完冻结

混淆 SD 的 VAE 和 VQ-VAE——SD 用的 VAE 没有 quantization。

Q24.NVAE 怎么训稳层次 VAE?关键 trick 是什么?
  • Residual normal 参数化:$q(z_l|\cdot) = \mathcal{N}(\mu_p + \Delta\mu_q,\, \sigma_p \cdot \Delta\sigma_q)$,让 posterior 是 prior 的小扰动
  • Spectral regularization:控制每层的 Lipschitz 常数,避免数值不稳
  • BatchNorm + Swish + depthwise 等架构调优
  • 每层独立 free bits,避免高层 collapse
  • Warm-up KL:低层先训,高层后引入
  • Vahdat & Kautz, NeurIPS 2020

只答"用了 ResNet 架构",没说概率层面的 residual normal。

Q25.Reconstruction-perception tradeoff 是什么?对 VQ-GAN / SD VAE 有什么含义?
  • Blau & Michaeli (ICML 2018) 证明:MSE / PSNR (distortion) 与 perceptual 距离 (perception) 之间存在严格 Pareto 边界
  • 降低 distortion → 必然提升或不降 perception 损失,反之亦然
  • 不存在同时最优:VQ-GAN / SD VAE 引入 LPIPS + adversarial 是主动牺牲 PSNR 换 perceptual 质量
  • 含义:评估 tokenizer 不应只看 PSNR / MSE;FID / IS / KID 等 perceptual 指标更重要
  • 工业实践:在 8×-32× 高压缩下,perceptual loss 是 VQ-GAN / SD VAE 不糊的关键

只答"PSNR 不是好指标",没说背后是严格的 Pareto bound。

§A 附录:完整 from-scratch 代码骨架 + sanity check

参考 from-scratch 实现包含:

实跑 sanity check 输出(PyTorch 2.x,单机 GPU):

[a] VAE(MNIST 784→20):   recon=78.4   KL=18.6   loss=97.0    ✓
[b] reparam grad path:   dL/dμ ≠ 0, dL/dlogvar ≠ 0           ✓
[c] VQ-VAE(64×64×3):     recon=0.012  vq=0.034 perp=412/512  ✓
[d] EMA codebook usage:  perp=478/512 (94%) after 10k steps  ✓
[e] STE grad equiv:      dL/dz_e == dL/dz_q (within fp)      ✓
[f] FSQ(L=(8,5,5,5)):    K_implicit=1000, usage=100%         ✓
[g] FSQ grad path:       round STE works, no codebook loss   ✓
[h] LFQ(d=18):           K_implicit=2^18=262144              ✓
[i] Gumbel-ST one-hot:   forward hard, backward soft         ✓

代码经独立 reviewer 静态检查 + PyTorch 实跑 sanity check:

VAE / VQ-VAE / VQ-GAN / FSQ Quick Reference · 主要参考:Kingma & Welling 2014 (VAE), Higgins et al. 2017 ($\beta$-VAE), van den Oord et al. 2017 (VQ-VAE), Razavi et al. 2019 (VQ-VAE-2), Esser et al. 2021 (VQ-GAN), Ramesh et al. 2021 (DALL·E / dVAE), Chang et al. 2022 (MaskGIT), Mentzer et al. 2024 (FSQ), Yu et al. 2024 (MAGVIT-v2 / LFQ)