VAE / VQ-VAE / VQ-GAN / FSQ 面试 Cheat Sheet
公式推导 + From-Scratch 代码 + 25 高频题(L1 必会 · L2 进阶 · L3 顶级 lab)
§0 TL;DR Cheat Sheet
一页拿下面试核心要点(详见后文 §2–§9 推导)。
- 连续 VAE 目标:最大化 ELBO,$\log p(x) \geq \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] - D_\text{KL}(q_\phi(z|x)\,\|\,p(z))$;reparameterization $z = \mu + \sigma \odot \epsilon$ 让梯度穿过随机采样。
- KL 闭式(必考):$D_\text{KL}(\mathcal{N}(\mu,\sigma^2 I)\,\|\,\mathcal{N}(0,I)) = \tfrac{1}{2}\sum_i (\mu_i^2 + \sigma_i^2 - \log \sigma_i^2 - 1)$。
- Posterior collapse:KL → 0 → 解码器忽略 $z$;缓解:KL annealing、free bits、$\beta$ schedule、自回归先验。
- VQ-VAE:把 encoder 输出 $z_e(x)$ 映射到最近邻 codebook 向量 $e_k$,loss = recon + $\|\text{sg}[z_e] - e\|^2$(codebook)$+ \beta \|z_e - \text{sg}[e]\|^2$(commitment)。
- Straight-Through Estimator (STE):argmin / quantize 不可导,前向用量化值,反向直通梯度 $\partial \mathcal{L}/\partial z_q \to \partial \mathcal{L}/\partial z_e$。
- VQ-GAN:VQ-VAE + perceptual (LPIPS) + adversarial (PatchGAN) + 后训 Transformer prior;为 LDM / Parti / Muse 等离散 token 模型奠基。
- FSQ(2024):每维标量量化到 $\{-L,\ldots,L\}$,隐式 codebook 大小 $\prod_i L_i$(如 $L=8, d=6 \Rightarrow 8^6 = 262{144}$),无需 STE、不会 codebook collapse,rounding 用 STE 即可,loss 只剩 reconstruction。
- 生态对比:连续 latent(VAE / KL)适合 LDM 续 diffusion;离散 token(VQ-VAE / VQ-GAN / FSQ / LFQ)适合 AR / MaskGIT Transformer prior,是 Parti / Muse / Cosmos 等的核心组件。
§1 直觉:为什么要 latent variable model
生成模型的核心难题:直接建模 $p(x)$ 很难,但如果引入低维 latent $z$:
$$p(x) = \int p(x|z)\, p(z)\, dz$$
可以把"复杂的图像分布"分解为"简单先验 $p(z)$(如 $\mathcal{N}(0, I)$)"加上"易学的条件分布 $p(x|z)$"。两条路:
- 连续 latent(VAE):$z \in \mathbb{R}^d$,KL 把 posterior 拉向 Gaussian 先验,和 diffusion / FM 天然兼容(LDM 在 VAE latent 里跑 diffusion)。
- 离散 latent(VQ-VAE / VQ-GAN / FSQ):$z \in \mathcal{V}^{H \times W}$(token grid),和 Transformer / AR / MaskGIT 天然兼容(一张图变成一串 token,复用语言模型架构)。
VAE/VQ-VAE 训练时学整套 encoder + decoder(rate-distortion 角度:"压缩-重建");推理时根据应用分两种:
- 生成新样本:丢掉 encoder,从 prior 采样 $z$,过 decoder
- 下游 backbone:丢掉 decoder,把 encoder/$z$ 当作 representation 给后续模型
- 二阶段生成(LDM / Parti / Muse):先训 VAE/VQ-GAN tokenizer,再在 latent 空间训 diffusion / AR / MaskGIT prior。tokenizer 训完冻结。
§2 VAE:核心公式与推导
2.1 ELBO 推导(必考、要会逐行推)
模型族 $p_\theta(x, z) = p_\theta(x|z)\, p(z)$,先验 $p(z) = \mathcal{N}(0, I)$,似然 $p_\theta(x|z)$ 由 decoder 给出。Marginal likelihood:
$$\log p_\theta(x) = \log \int p_\theta(x|z) p(z)\, dz$$
对任意 distribution $q_\phi(z|x)$(encoder / variational posterior,$q_\phi(z|x) = \mathcal{N}(\mu_\phi(x), \mathrm{diag}(\sigma_\phi^2(x)))$),由 Jensen 不等式 / 直接代入:
$$ \begin{aligned} \log p_\theta(x) &= \log \int q_\phi(z|x) \frac{p_\theta(x|z) p(z)}{q_\phi(z|x)} dz \\ &\geq \mathbb{E}_{q_\phi(z|x)}\!\left[\log \frac{p_\theta(x|z) p(z)}{q_\phi(z|x)}\right] \quad \text{(Jensen)} \\ &= \underbrace{\mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)]}_{\text{reconstruction (negative)}} - \underbrace{D_\text{KL}(q_\phi(z|x)\,\|\,p(z))}_{\text{regularization}} \end{aligned} $$
故 ELBO:
$$\boxed{\;\mathcal{L}_\text{ELBO}(\theta, \phi; x) = \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] - D_\text{KL}\!\big(q_\phi(z|x)\,\|\,p(z)\big)\;}$$
Tight 的代价:ELBO 与真 log-likelihood 的 gap 等于 $D_\text{KL}(q_\phi(z|x)\,\|\,p_\theta(z|x))$。posterior $q_\phi$ 越逼近真后验 $p_\theta(z|x)$,gap 越小。
2.2 KL 项闭式推导(L3 必推)
设 $q_\phi(z|x) = \mathcal{N}(\mu, \mathrm{diag}(\sigma^2))$(对角协方差,每维 $\sigma_i^2$),$p(z) = \mathcal{N}(0, I)$。
对每一维 $i$ 独立:
$$ \begin{aligned} D_\text{KL}(\mathcal{N}(\mu_i, \sigma_i^2) \,\|\, \mathcal{N}(0, 1)) &= \int \mathcal{N}(z; \mu_i, \sigma_i^2) \log \frac{\mathcal{N}(z; \mu_i, \sigma_i^2)}{\mathcal{N}(z; 0, 1)} dz \end{aligned} $$
展开两个 Gaussian 密度的 log:
$$ \log \frac{\mathcal{N}(z; \mu_i, \sigma_i^2)}{\mathcal{N}(z; 0, 1)} = -\tfrac{1}{2}\log \sigma_i^2 - \tfrac{(z-\mu_i)^2}{2\sigma_i^2} + \tfrac{z^2}{2} $$
求期望(利用 $\mathbb{E}_q[z] = \mu_i$, $\mathbb{E}_q[z^2] = \mu_i^2 + \sigma_i^2$, $\mathbb{E}_q[(z-\mu_i)^2] = \sigma_i^2$):
$$ \begin{aligned} D_\text{KL} &= -\tfrac{1}{2}\log \sigma_i^2 - \tfrac{1}{2} + \tfrac{1}{2}(\mu_i^2 + \sigma_i^2) \\ &= \tfrac{1}{2}\big(\mu_i^2 + \sigma_i^2 - \log \sigma_i^2 - 1\big) \end{aligned} $$
对所有维度求和:
$$\boxed{\;D_\text{KL}\big(\mathcal{N}(\mu, \mathrm{diag}(\sigma^2)) \,\|\, \mathcal{N}(0, I)\big) = \tfrac{1}{2}\sum_{i=1}^{d}\big(\mu_i^2 + \sigma_i^2 - \log \sigma_i^2 - 1\big)\;}$$
实现时让 encoder 输出 $\log \sigma^2$(log-variance)而非 $\sigma$,避免 $\sigma$ 取 exp 后 overflow。代码里写 kl = 0.5 * (mu**2 + logvar.exp() - logvar - 1).sum()。
2.3 Reparameterization Trick(必考)
ELBO 中 $\mathbb{E}_{q_\phi(z|x)}[\cdot]$ 用 Monte Carlo 估计:采一个 $z \sim q_\phi(z|x)$,求 $\log p_\theta(x|z)$。
问题:直接采样 $z = \text{sample}(\mathcal{N}(\mu, \sigma^2))$ 是不可导操作,梯度无法回传到 $\phi$。
解法:把随机性挪到独立噪声里:
$$\boxed{\;z = \mu_\phi(x) + \sigma_\phi(x) \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)\;}$$
现在 $z$ 是 $\phi$ 的确定性函数(条件于 $\epsilon$),$\nabla_\phi \mathcal{L}$ 可以正常反向传播。这是 Kingma & Welling (ICLR 2014) 的核心贡献之一。
Concrete / Gumbel-softmax(§7)对离散变量做了类似 trick:把 argmax 替换成 softmax 加 Gumbel 噪声,前向近似离散,反向用 softmax 梯度。
2.4 VAE 训练损失(实际写法)
负 ELBO(最小化):
$$\mathcal{L}_\text{VAE}(x) = \underbrace{\|x - \hat{x}\|^2}_{\text{recon (Gaussian likelihood up to const)}} + \underbrace{D_\text{KL}(q_\phi(z|x)\,\|\,p(z))}_{\text{closed form}}$$
对 Bernoulli / Categorical 似然(如 MNIST 二值),recon 项换成 BCE / CE。
§3 完整 VAE 实现(PyTorch)
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class VAE(nn.Module):
""" 经典 VAE:Gaussian encoder + Gaussian/Bernoulli decoder
本实现以 MNIST (28×28) 为例,latent dim=20
生产化把 MLP 换成 ResNet / U-Net encoder/decoder,latent 可以是 spatial map """
def __init__(self, x_dim: int = 784, h_dim: int = 400, z_dim: int = 20,
likelihood: str = "bernoulli"):
super().__init__()
self.x_dim, self.z_dim = x_dim, z_dim
self.likelihood = likelihood
# Encoder: x -> (μ, logσ²)
self.enc = nn.Sequential(
nn.Linear(x_dim, h_dim), nn.ReLU(),
nn.Linear(h_dim, h_dim), nn.ReLU(),
)
self.fc_mu = nn.Linear(h_dim, z_dim)
self.fc_logvar = nn.Linear(h_dim, z_dim)
# Decoder: z -> x̂
self.dec = nn.Sequential(
nn.Linear(z_dim, h_dim), nn.ReLU(),
nn.Linear(h_dim, h_dim), nn.ReLU(),
nn.Linear(h_dim, x_dim),
)
def encode(self, x: torch.Tensor):
h = self.enc(x.view(x.size(0), -1))
return self.fc_mu(h), self.fc_logvar(h)
def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor):
# z = μ + σ ⊙ ε, σ = exp(0.5 · logvar)
if self.training:
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + std * eps
else:
# 推理时用 posterior mean (deterministic)
return mu
def decode(self, z: torch.Tensor):
logits = self.dec(z)
if self.likelihood == "bernoulli":
return torch.sigmoid(logits), logits
return logits, logits # Gaussian likelihood: 视为 mean prediction
def forward(self, x: torch.Tensor):
mu, logvar = self.encode(x)
z = self.reparameterize(mu, logvar)
x_hat, logits = self.decode(z)
return x_hat, logits, mu, logvar
def vae_loss(x: torch.Tensor, logits: torch.Tensor, mu: torch.Tensor,
logvar: torch.Tensor, likelihood: str = "bernoulli",
beta: float = 1.0, free_bits: float = 0.0):
""" Returns:
(loss, recon, kl)
beta: β-VAE 的 β(默认 1 = 标准 VAE)
free_bits: 每维 KL 下限(nats)。 > 0 时启用 free bits"""
B = x.size(0)
x_flat = x.view(B, -1)
# 1) Reconstruction term: -E_q[log p(x|z)]
if likelihood == "bernoulli":
# BCE-with-logits 数值上更稳,等价于 -log Bernoulli likelihood
recon = F.binary_cross_entropy_with_logits(
logits, x_flat, reduction="sum") / B
elif likelihood == "gaussian":
# 假设 σ² = 1(常数),MSE 与负 log-Gaussian 差一个常数
recon = 0.5 * F.mse_loss(logits, x_flat, reduction="sum") / B
else:
raise ValueError(likelihood)
# 2) KL term: D_KL(N(μ, σ²) || N(0, I)) 闭式
kl_per_dim = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1) # [B, z_dim]
if free_bits > 0:
# Free bits: 每维 KL 下限 = free_bits(缓解 posterior collapse)
kl_per_dim = torch.clamp(kl_per_dim, min=free_bits)
kl = kl_per_dim.sum(dim=-1).mean() # 标量
loss = recon + beta * kl
return loss, recon, kl
写 VAE 时容易踩的坑。
reparameterize写成mu + logvar * eps,忘了 $\sigma = \exp(0.5 \cdot \log\sigma^2)$- KL 里
0.5 * (mu**2 + sigma**2 - 2*log_sigma - 1),注意是 $-\log \sigma^2 = -2\log\sigma$ - BCE 写
F.binary_cross_entropy(sigmoid(logits), x)而不是F.binary_cross_entropy_with_logits(logits, x),前者数值上不稳 - Reduction 不一致:recon 用
sum、KL 用mean,导致 $\beta$ 实际 scale 漂移
3.1 训练循环 + KL annealing
def train_vae(model, dataloader, total_steps=50_000, lr=1e-3, device="cuda",
beta_max=1.0, anneal_steps=10_000, free_bits=0.0):
""" KL annealing: β 从 0 线性增到 beta_max,防止训练早期 posterior 直接塌缩 """
opt = torch.optim.Adam(model.parameters(), lr=lr)
model.to(device).train()
step = 0
while step < total_steps:
for x, _ in dataloader:
x = x.to(device)
beta = min(beta_max, beta_max * step / max(anneal_steps, 1))
x_hat, logits, mu, logvar = model(x)
loss, recon, kl = vae_loss(x, logits, mu, logvar,
beta=beta, free_bits=free_bits)
opt.zero_grad(set_to_none=True)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
opt.step()
step += 1
if step >= total_steps:
break
§4 VAE 变体:$\beta$-VAE / IWAE / NVAE / VAE-GAN
4.1 $\beta$-VAE(Higgins et al., ICLR 2017)
把 ELBO 的 KL 加权重 $\beta$:
$$\mathcal{L}_{\beta\text{-VAE}}(\theta, \phi; x) = \mathbb{E}_{q_\phi}[\log p_\theta(x|z)] - \beta \cdot D_\text{KL}(q_\phi(z|x)\,\|\,p(z))$$
- $\beta > 1$:更强地 push posterior → 先验,鼓励 disentangled representation(每维 $z$ 控制一个独立因素,如 dSprites 上的位置、形状、旋转)。
- $\beta < 1$:放宽 KL,重建更精但 latent 不够 prior-like(采样质量差)。
- $\beta = 1$ 退化为标准 VAE。
Locatello et al. (ICML 2019, 最佳论文) 证明:纯 unsupervised disentanglement 在没有归纳偏置 / 监督的前提下是不可能的。$\beta$-VAE 的"涌现 disentanglement"很大程度依赖架构 + 数据集 bias,而不是 $\beta$ 本身。
4.2 IWAE:Importance Weighted Autoencoder(Burda et al., ICLR 2016)
ELBO 是 $\log p(x)$ 的一阶 bound。用 $K$ 个 importance 样本得到更紧的 bound:
$$\mathcal{L}_K^\text{IWAE}(x) = \mathbb{E}_{z_1,\ldots,z_K \sim q_\phi}\!\left[\log \frac{1}{K}\sum_{k=1}^K \frac{p_\theta(x, z_k)}{q_\phi(z_k|x)}\right]$$
性质:
- $\mathcal{L}_1^\text{IWAE} = $ ELBO(特例)。
- $\mathcal{L}_K^\text{IWAE} \to \log p(x)$ 当 $K \to \infty$(Burda 定理)。
- $K$ 越大,inference 越 expressive(但训练 cost 也 $\times K$)。
IWAE 让 likelihood bound 更紧,但 encoder 学到的 posterior 不再追求"逼近真后验",转去配合 importance weighting 的几何。对下游 representation learning 不一定更好。
4.3 NVAE:Hierarchical VAE(Vahdat & Kautz, NeurIPS 2020)
多层 latent $z = (z_1, z_2, \ldots, z_L)$,每层依赖前层:
$$p(z) = p(z_1)\prod_{l=2}^L p(z_l | z_{ 工程要点: NVAE 现在的角色:LDM 之前最强连续 VAE prior 之一,但被 diffusion 系列在 sample 质量上反超。 VAE 的重建损失(pixel-MSE / BCE)对高频细节不敏感 → 生成图模糊。VAE-GAN 把 MSE 换成 / 补上 discriminator feature matching: $$\mathcal{L}_\text{recon}^\text{VAE-GAN} = \|D_l(x) - D_l(\hat{x})\|^2$$ 其中 $D_l$ 是 discriminator 的中间层特征。配合 adversarial loss,重建感更清晰。 这套思路最终在 VQ-GAN(§6)里大成:VQ-VAE 框架 + perceptual + adversarial + 高码率 codebook + Transformer prior。 训练中 $D_\text{KL}(q_\phi(z|x)\,\|\,p(z)) \to 0$,即 $q_\phi(z|x) \approx p(z)$,与 $x$ 无关。后果:decoder 完全忽略 $z$,VAE 退化为 unconditional generative model。 实现极简: Codebook $\mathcal{E} = \{e_1, \ldots, e_K\} \subset \mathbb{R}^D$,学习得到。$z_e(x)$ 与 $z_q(x)$ 形状一致,但 $z_q$ 每个空间位置都是 codebook 中某一个 vector 的 copy(离散 index $k_{hw}$)。 VQ-VAE 不学随机 posterior $q(z|x)$(不像 VAE),而是用确定性最近邻做 $z_e \to z_q$ 的"量化"。loss 由三部分组成: $$\boxed{\;\mathcal{L}_\text{VQ-VAE} = \underbrace{\|x - \hat{x}\|^2}_{\text{reconstruction}} + \underbrace{\|\text{sg}[z_e(x)] - e\|^2}_{\text{codebook}} + \beta \underbrace{\|z_e(x) - \text{sg}[e]\|^2}_{\text{commitment}}\;}$$ 各项含义: 如果都不 sg,$\|z_e - e\|^2$ 同时拉两侧,方向耦合容易振荡。把这一项拆成两个 sg 版本:codebook 项专门更新 $e$,commitment 专门更新 encoder,学习率 / 速度可解耦。这是 vector quantization 文献的标准做法(也叫 "alternating minimization")。 问题:$z_q = e_{\arg\min_k \|z_e - e_k\|^2}$ 的 STE 解法: PyTorch 实现技巧(经典三行): 前向: STE 等价于把不可导的 $z_q = \text{quantize}(z_e)$ 替换成可导 surrogate $z_q^\text{surrogate} = z_e$ 来反传——即"假设量化是恒等映射"。这是一个有偏估计(biased gradient estimator),但实践中工作良好;理论分析见 Bengio et al. (2013) "Estimating or Propagating Gradients Through Stochastic Neurons"。 直接用 codebook loss 更新 $e$ 收敛慢,dead codes(从来不被选中的 codebook 向量)多。生产实现用 EMA (Exponential Moving Average) 更新: 对每个 codebook 向量 $e_k$,维护: 更新: $$e_k^{(t)} = \frac{m_k^{(t)}}{N_k^{(t)} + \varepsilon} \quad \text{(Laplace smoothing)}$$ VQ-VAE 的层次扩展: 在 ImageNet 256×256 上首次让 VQ-based 方法接近 BigGAN 的样本质量,是 VQ-GAN 的直接前作。 VQ-VAE 在 ImageNet 上重建的纹理细节模糊。VQ-GAN 改造为: $$\mathcal{L}_\text{VQ-GAN}^\text{stage1} = \mathcal{L}_\text{rec} + \mathcal{L}_\text{VQ} + \lambda \cdot \mathcal{L}_\text{GAN}$$ 其中: $$ \begin{aligned} \mathcal{L}_\text{rec} &= \|x - \hat{x}\|_1 + \mathcal{L}_\text{LPIPS}(x, \hat{x}) \\ \mathcal{L}_\text{VQ} &= \|\text{sg}[z_e] - e\|^2 + \beta \|z_e - \text{sg}[e]\|^2 \end{aligned} $$ Generator/Tokenizer 阶段的 GAN 项(只对 generator 的输出反传,discriminator 在另一阶段单独更新): $$\mathcal{L}_\text{GAN}^{(G)} = -\mathbb{E}_{\hat{x}}[\log D(\hat{x})]\quad\text{(non-saturating)}\quad\text{或}\quad \mathcal{L}_\text{GAN}^{(G)} = -\mathbb{E}_{\hat{x}}[D(\hat{x})]\quad\text{(hinge)}$$ Discriminator 自身的 minimax 项(独立 step 更新 $D$): $$\mathcal{L}_\text{GAN}^{(D)} = -\mathbb{E}_x[\min(0, -1+D(x))] - \mathbb{E}_{\hat{x}}[\min(0, -1-D(\hat{x}))]\quad\text{(hinge)}$$ 自适应 $\lambda$(论文创新点,用最后一层梯度范数比自动平衡,避免人工调参): $$\lambda = \frac{\lVert\nabla_{G_L} \mathcal{L}_\text{rec}\rVert}{\lVert\nabla_{G_L} \mathcal{L}_\text{GAN}^{(G)}\rVert + \delta}$$ $G_L$ 是 decoder 最后一层;$\lVert\cdot\rVert$ 是 Frobenius 范数。总 generator loss: $$\mathcal{L}_G = \mathcal{L}_\text{rec} + \mathcal{L}_\text{VQ} + \lambda \cdot \mathcal{L}_\text{GAN}^{(G)}$$ Stage 1 训好 VQ-GAN,把图像转成 token grid $\mathbf{c} = (c_1, \ldots, c_{HW})$(行扫描展平)。Stage 2 在 token sequence 上训 decoder-only Transformer,标准 AR: $$p(\mathbf{c}) = \prod_{i=1}^{HW} p(c_i | c_{
采样:AR sample tokens → VQ-GAN decoder → image。这是把"图像生成"翻译成"语言模型"的标准范式,DALL·E / Parti / Muse 都是这一思想的进化。 Stable Diffusion 的 VAE tokenizer 实际是 KL-regularized VQ-GAN 的连续 latent 变体(去掉 quantization,留 KL + perceptual + adversarial),output 是 continuous latent map(4 通道,下采样 8×)。diffusion 在这个 latent 上跑,最后 decoder 还原。可以理解为 "VQ-GAN encoder/decoder + continuous latent + KL"。 VQ-GAN 用 PatchGAN(Isola et al. CVPR 2017 "pix2pix"): $$\mathcal{L}_\text{LPIPS}(x, \hat{x}) = \sum_l w_l \cdot \|\phi_l(x) - \phi_l(\hat{x})\|^2$$ $\phi_l$ 是预训练 VGG / AlexNet 的第 $l$ 层 feature map,$w_l$ 是学到的 channel-wise weight(Zhang et al. CVPR 2018)。比 pixel-MSE 更贴近人类感知,是 VQ-GAN / SD / 大部分 image GAN / diffusion 训练的标配。 DALL·E 用 dVAE (discrete VAE) 作为 image tokenizer: 目标:让 categorical sampling 可导。 Gumbel-Max trick:对 logits $\pi = (\pi_1, \ldots, \pi_K)$ 加独立 Gumbel(0,1) 噪声 $g_k = -\log(-\log u_k), u_k \sim \mathcal{U}(0, 1)$,则: $$\arg\max_k \{\log \pi_k + g_k\}$$ 服从 categorical(softmax($\pi$))。证明用 Gumbel 分布的 CDF 性质:$P(\max_k X_k = X_j) = e^{\pi_j} / \sum_k e^{\pi_k}$。 Gumbel-softmax (Jang et al., ICLR 2017; Maddison et al., ICLR 2017 同期):把不可导的 argmax 替换成 softmax: $$\boxed{\;y_k = \frac{\exp((\log \pi_k + g_k) / \tau)}{\sum_j \exp((\log \pi_j + g_j) / \tau)}\;}$$ Straight-Through Gumbel-Softmax:前向用 argmax(离散),反向用 softmax 梯度——类似 VQ-VAE 的 STE 思路。 都是离散 latent 模型,区别: 把 AR Transformer prior 换成 BERT-style masked Transformer: 后继:MUSE (Chang et al., 2023) 把同思路推到 text-to-image,是 Google 主线生成模型之一。 VQ-VAE 有顽疾: FSQ 用一招把这一堆问题绕过去:逐维标量量化(scalar quantization, 不是 vector quantization)。 让 encoder 输出 $z \in \mathbb{R}^d$。对每一维独立做 scalar quantization(FSQ 论文 Eq. 4): $$z_i \longrightarrow z'_i = \tfrac{L_i-1}{2}\tanh(z_i) - s_i \longrightarrow \hat{z}_i = \text{round}(z'_i) + s_i$$ 其中 $s_i = 0$ 若 $L_i$ 奇,$s_i = 0.5$ 若 $L_i$ 偶。这样: 无论奇偶,每维都得到 $L_i$ 个 level;把所有维度的 level 数乘起来: $$\boxed{\;K_\text{implicit} = \prod_{i=1}^{d} L_i\;}$$ 在 VQ-VAE 中,codebook collapse 的根源是:codebook 是自由参数,optimizer 让大部分 $e_k$ 都 drift 到无用区域,只有少数 $e_k$ 被反复用。FSQ 的"codebook"不是参数——它是数轴上的固定 grid 点($\{-L/2, \ldots, L/2\}$)。 实证:FSQ 的 codebook usage 几乎是 100%(与 VQ-VAE 50-70% 形成对比),ImageNet / Cosmos / OpenMagViT2 复现均有此结论。 $$\mathcal{L}_\text{FSQ} = \|x - \hat{x}\|^2 \quad \text{(plus optional perceptual + adversarial)}$$ FSQ 是"用空间维度换 codebook 大小":VQ-VAE 用 1 维($D$ 个连续值 + 1 个离散选择 from $K$),FSQ 用 $d$ 个独立离散维度,每维 $L_i$ 个 level,最终离散 entropy 反而更大、collapse 几乎不可能。代价:embedding 表达力略弱(每维独立,不共享 representation),不过 reconstruction 端通过 decoder 已补回。 FSQ 的二值特例: $$\text{LFQ}(z) = \text{sign}(z) \in \{-1, +1\}^d$$ 每维只有 2 个 level,隐式 codebook = $2^d$:$d=18$ 时 codebook = $2^{18} = 262{144}$(与 FSQ 等价的量级)。 特点: 工程要点: Blau & Michaeli (ICML 2018) 证明:重建(MSE / PSNR)和感知(perceptual / FID)之间存在严格的 Pareto 边界。VQ-GAN / SD VAE 引入 LPIPS + adversarial 是为了交换更高 perceptual 质量而接受略差的 PSNR。 VQ-GAN 论文里 PSNR 不一定优于 VQ-VAE,但 perceptual (LPIPS / FID) 远好。面试常被反问"为什么 SOTA tokenizer 的 PSNR 反而下降"——这是 distortion-perception tradeoff。 codex (gpt-5.5 xhigh) 作为顶级 lab 面试官视角列的,按难度分 3 档。每题点开看答案要点 + 易踩坑。 写成 $\log p(x|z) - D_\text{KL}(...)$(漏了期望符号);只说"重建+正则化"不写公式。 说"为了加速训练"——其实是为了可导。 只说 "$\beta$ 越大越 disentangle"——错,取决于数据 + 架构。 只说"latent 没用",不说 KL → 0 这个量化指标。 说 codebook 是固定的 / 预训练的——错,端到端学习。 把 codebook 和 commitment loss 混作一项;忘了 sg 的方向。 只说"反向用 identity",不说前向还是真量化。 只说 "GAN" 不说 perceptual;或忘 Transformer prior。 把 FSQ 当作 VQ-VAE 的 codebook 优化技巧——错,FSQ 没有显式 codebook 参数。 把 $\log \sigma^2$ 与 $\log \sigma$ 搞混(差一个 2);或忘 "-1" 项。 说"$K = 1$ 也比 ELBO 强"——错,特例就是 ELBO。 只答"KL annealing" 一种;或把 $\beta$-VAE 当成缓解 collapse 的工具(其实 $\beta > 1$ 更易 collapse)。 跳过推导直接背公式;忘记 $\mathbb{E}_q[z^2]$ 的展开。 如果两个都不 sg,等价于普通 MSE,learning rate 实际加倍 + 两侧互相拉扯 误以为 sg 是为了"防止 codebook 更新太快"——其实是为了梯度解耦。 说 STE 是无偏估计——错,是 biased。 把 EMA 当作 momentum + Adam 的 SGD 变体——本质是 k-means 在 mini-batch 下的 EMA 估计。 误以为 PatchGAN 是 attention-based;或说它只在 image-to-image translation 用。 只说"用 VGG feature",不说 learned channel weights / human study 拟合。 把 Gumbel(0,1) 写成正态噪声;忘记 $\arg\max$ 的概率正比 softmax。 误以为 MaskGIT 是 diffusion 的离散版——其实是 BERT MLM 的生成扩展。 把 FSQ 与 LFQ 混为一谈(LFQ 是 $L = 2$ binary 特例);或以为 FSQ 取消了 STE(其实 round 仍需 STE,只是不需要额外 codebook / commitment loss)。 诊断:测 perplexity = $\exp(-\sum_k p_k \log p_k)$,$p_k$ 是 codebook 第 $k$ 个 code 的使用频率 缓解: 只答"用大 codebook"——错,更大 codebook 反而更容易 collapse。 混淆 SD 的 VAE 和 VQ-VAE——SD 用的 VAE 没有 quantization。 只答"用了 ResNet 架构",没说概率层面的 residual normal。 只答"PSNR 不是好指标",没说背后是严格的 Pareto bound。 参考 from-scratch 实现包含: 实跑 sanity check 输出(PyTorch 2.x,单机 GPU): 代码经独立 reviewer 静态检查 + PyTorch 实跑 sanity check: VAE / VQ-VAE / VQ-GAN / FSQ Quick Reference · 主要参考:Kingma & Welling 2014 (VAE), Higgins et al. 2017 ($\beta$-VAE), van den Oord et al. 2017 (VQ-VAE), Razavi et al. 2019 (VQ-VAE-2), Esser et al. 2021 (VQ-GAN), Ramesh et al. 2021 (DALL·E / dVAE), Chang et al. 2022 (MaskGIT), Mentzer et al. 2024 (FSQ), Yu et al. 2024 (MAGVIT-v2 / LFQ)4.4 VAE-GAN(Larsen et al., ICML 2016)
§5 Posterior Collapse(必考)
5.1 现象
5.2 原因(直观分析)
5.3 缓解方法(面试要会列)
方法 做法 出处 KL annealing $\beta(t) = \min(1, t / T)$ 线性从 0 升到 1 Bowman et al. (2016) Free bits 每维 KL 下限 $\lambda$ nats:$\max(D_\text{KL}^{(i)}, \lambda)$ Kingma et al. (2016) $\beta$ < 1 直接减小 KL 权重 $\beta$-VAE 反向用法 Weakened decoder 用 PixelCNN 等强 AR decoder 时人为 truncate context / 加 dropout Chen et al. (2017) Auxiliary task 加 word dropout、bag-of-words 辅助 loss Bowman et al. (2016) VAE-IAF / NF prior 用更复杂的先验或 normalizing flow posterior Kingma et al. (2016) Skip / lateral 连接 让 latent 强制参与 decoder(如 VLAE) Zhao et al. (2017) VQ-VAE 离散 latent 配合 codebook commitment,结构上避免 collapse van den Oord (2017) kl_per_dim = max(kl_per_dim, λ)。直觉:给每维 latent 保底 λ 比特信息,optimizer 不能把它压到 0 以下。$\lambda \approx 0.5$-$2$ nats / dim 是常见值。§6 VQ-VAE:离散 latent + Codebook + STE
6.1 结构(van den Oord, Vinyals, Kavukcuoglu, NeurIPS 2017)
x ──Encoder──> z_e(x) ∈ R^{H'×W'×D} # continuous spatial map
│
│ 对每个空间位置 (h,w),找最近 codebook 向量
│ k_{hw} = argmin_k ‖z_e(x)_{hw} - e_k‖²
↓
z_q(x)_{hw} = e_{k_{hw}} # quantized spatial map (discrete code)
│
│
↓
Decoder ──> x̂6.2 Loss 推导
sg 阻断对 $z_e$ 的梯度,否则 codebook 与 encoder 都被拉,方向不清晰)。sg[·] = stop_gradient(PyTorch 里 .detach()),定义:前向 $\text{sg}[u] = u$,反向 $\nabla \text{sg}[u] = 0$。6.3 Straight-Through Estimator (STE) 推导
argmin 不可导(输出离散 index)。z_q = z_e + (z_q_quantized - z_e).detach()z_q = z_e + (z_q_q - z_e) = z_q_q ✓(量化值) 反向:(z_q_q - z_e).detach() 不参与梯度,所以 dz_q/dz_e = 1,梯度直通到 encoder ✓6.4 VQ-VAE 完整实现
class VectorQuantizer(nn.Module):
""" Codebook + 最近邻量化 + STE
embedding_dim = D, num_embeddings = K
commitment_cost β 通常 = 0.25 """
def __init__(self, num_embeddings: int, embedding_dim: int,
commitment_cost: float = 0.25):
super().__init__()
self.K, self.D = num_embeddings, embedding_dim
self.beta = commitment_cost
# Codebook 用 small uniform init
self.codebook = nn.Embedding(self.K, self.D)
self.codebook.weight.data.uniform_(-1.0 / self.K, 1.0 / self.K)
def forward(self, z_e: torch.Tensor):
""" z_e: [B, D, H, W] -> z_q: [B, D, H, W], indices: [B, H, W],
loss = codebook_loss + β·commitment_loss """
# 1) Reshape: [B, D, H, W] -> [BHW, D]
B, D, H, W = z_e.shape
z_e_flat = z_e.permute(0, 2, 3, 1).contiguous().view(-1, D) # [BHW, D]
# 2) 计算 L2 距离 ‖z_e - e_k‖² = ‖z_e‖² + ‖e_k‖² - 2 z_e · e_k
e = self.codebook.weight # [K, D]
dist = (z_e_flat.pow(2).sum(1, keepdim=True)
+ e.pow(2).sum(1)
- 2 * z_e_flat @ e.t()) # [BHW, K]
# 3) 最近邻 index
indices = dist.argmin(dim=1) # [BHW]
z_q_flat = self.codebook(indices) # [BHW, D]
# 4) 损失(注意 sg)
codebook_loss = F.mse_loss(z_q_flat, z_e_flat.detach())
commitment_loss = F.mse_loss(z_e_flat, z_q_flat.detach())
vq_loss = codebook_loss + self.beta * commitment_loss
# 5) STE:前向 z_q,反向 dz_q/dz_e = I
z_q_flat = z_e_flat + (z_q_flat - z_e_flat).detach()
# 6) Reshape 回 [B, D, H, W]
z_q = z_q_flat.view(B, H, W, D).permute(0, 3, 1, 2).contiguous()
indices = indices.view(B, H, W)
# 7) (可选) perplexity: codebook 使用度的衡量
one_hot = F.one_hot(indices.view(-1), self.K).float()
avg_probs = one_hot.mean(dim=0)
perplexity = torch.exp(-(avg_probs * torch.log(avg_probs + 1e-10)).sum())
return z_q, indices, vq_loss, perplexity
class VQVAE(nn.Module):
def __init__(self, channels=3, hidden=128, num_embeddings=512, embedding_dim=64,
commitment_cost=0.25):
super().__init__()
# Encoder: 64×64×3 -> 16×16×D (downsample 4×)
self.encoder = nn.Sequential(
nn.Conv2d(channels, hidden, 4, 2, 1), nn.ReLU(),
nn.Conv2d(hidden, hidden, 4, 2, 1), nn.ReLU(),
nn.Conv2d(hidden, embedding_dim, 3, 1, 1),
)
self.quantizer = VectorQuantizer(num_embeddings, embedding_dim, commitment_cost)
# Decoder: 16×16×D -> 64×64×3
self.decoder = nn.Sequential(
nn.ConvTranspose2d(embedding_dim, hidden, 3, 1, 1), nn.ReLU(),
nn.ConvTranspose2d(hidden, hidden, 4, 2, 1), nn.ReLU(),
nn.ConvTranspose2d(hidden, channels, 4, 2, 1),
)
def forward(self, x):
z_e = self.encoder(x)
z_q, indices, vq_loss, perplexity = self.quantizer(z_e)
x_hat = self.decoder(z_q)
return x_hat, vq_loss, perplexity, indices
def vqvae_loss(x, x_hat, vq_loss):
recon = F.mse_loss(x_hat, x)
return recon + vq_loss, recon6.5 EMA Codebook(生产标准做法)
class VectorQuantizerEMA(nn.Module):
""" EMA codebook update (van den Oord 2017 后续 / VQ-VAE-2 标准做法)
- codebook 不靠 gradient,靠 running EMA 更新
- 留 commitment loss 用于 encoder
- decay γ 一般 0.99, ε 一般 1e-5 """
def __init__(self, num_embeddings: int, embedding_dim: int,
commitment_cost: float = 0.25, decay: float = 0.99, eps: float = 1e-5):
super().__init__()
self.K, self.D = num_embeddings, embedding_dim
self.beta, self.decay, self.eps = commitment_cost, decay, eps
embed = torch.randn(num_embeddings, embedding_dim) * 0.01
self.register_buffer("codebook", embed)
self.register_buffer("cluster_size", torch.zeros(num_embeddings))
self.register_buffer("embed_avg", embed.clone())
def forward(self, z_e):
B, D, H, W = z_e.shape
z_e_flat = z_e.permute(0, 2, 3, 1).contiguous().view(-1, D)
dist = (z_e_flat.pow(2).sum(1, keepdim=True)
+ self.codebook.pow(2).sum(1)
- 2 * z_e_flat @ self.codebook.t())
indices = dist.argmin(dim=1) # [BHW]
z_q_flat = F.embedding(indices, self.codebook) # [BHW, D]
if self.training:
# EMA 更新
with torch.no_grad():
one_hot = F.one_hot(indices, self.K).float() # [BHW, K]
cluster_size_new = one_hot.sum(dim=0) # [K]
embed_sum = one_hot.t() @ z_e_flat # [K, D]
self.cluster_size.mul_(self.decay).add_(cluster_size_new, alpha=1 - self.decay)
self.embed_avg.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)
# Laplace smoothing 避免被零除
n = self.cluster_size.sum()
cluster_size = (self.cluster_size + self.eps) / (n + self.K * self.eps) * n
self.codebook.copy_(self.embed_avg / cluster_size.unsqueeze(1))
commitment_loss = F.mse_loss(z_e_flat, z_q_flat.detach())
vq_loss = self.beta * commitment_loss # EMA 下不要 codebook loss
z_q_flat = z_e_flat + (z_q_flat - z_e_flat).detach() # STE
z_q = z_q_flat.view(B, H, W, D).permute(0, 3, 1, 2).contiguous()
return z_q, indices.view(B, H, W), vq_loss6.6 VQ-VAE-2(Razavi, Vinyals, van den Oord, NeurIPS 2019)
§7 VQ-GAN:Adversarial + Perceptual + Transformer Prior
7.1 核心思想(Esser, Rombach, Ommer, CVPR 2021, "Taming Transformers")
组件 VQ-VAE VQ-GAN Recon loss L2 / L1 pixel L1 pixel + LPIPS perceptual + PatchGAN adversarial Prior PixelCNN Transformer (decoder-only) over code tokens Codebook 512-1024 codes 1024-16384 codes Compression 4×-8× 8×-32×(更高压缩比,靠 perceptual + adversarial 救质量) 应用 unconditional / class-cond 生成 high-res image synthesis, "Taming Transformers" 7.2 Loss 公式
7.3 Stage 2:Transformer Prior
7.4 PatchGAN Discriminator(生产架构)
class PatchDiscriminator(nn.Module):
""" PatchGAN: 70×70 receptive field 的 stack of strided convs
输出 [B, 1, H/8, W/8] 的 patch-level real/fake 判别 """
def __init__(self, in_ch=3, hidden=64, n_layers=3):
super().__init__()
layers = [nn.Conv2d(in_ch, hidden, 4, 2, 1), nn.LeakyReLU(0.2, True)]
ch = hidden
for i in range(1, n_layers):
ch_next = min(hidden * (2 ** i), 512)
layers += [
nn.Conv2d(ch, ch_next, 4, 2, 1),
nn.BatchNorm2d(ch_next),
nn.LeakyReLU(0.2, True),
]
ch = ch_next
layers += [
nn.Conv2d(ch, ch * 2, 4, 1, 1),
nn.BatchNorm2d(ch * 2),
nn.LeakyReLU(0.2, True),
nn.Conv2d(ch * 2, 1, 4, 1, 1),
]
self.net = nn.Sequential(*layers)
def forward(self, x): return self.net(x)
def hinge_d_loss(real_logits, fake_logits):
real = F.relu(1.0 - real_logits).mean()
fake = F.relu(1.0 + fake_logits).mean()
return 0.5 * (real + fake)
def hinge_g_loss(fake_logits):
return -fake_logits.mean()7.5 LPIPS(Perceptual Loss)
§8 离散 VAE 与 Gumbel-Softmax
8.1 dVAE(DALL·E 1,Ramesh et al. ICML 2021)
8.2 Gumbel-Softmax 推导
def gumbel_softmax_sample(logits, tau=1.0, hard=False, dim=-1):
""" 输入 logits = log π 输出 soft / hard one-hot """
# 1) 加 Gumbel 噪声
g = -torch.log(-torch.log(torch.rand_like(logits) + 1e-20) + 1e-20)
y_soft = F.softmax((logits + g) / tau, dim=dim)
if not hard:
return y_soft
# ST: 前向 hard, 反向 soft 梯度
index = y_soft.argmax(dim=dim, keepdim=True)
y_hard = torch.zeros_like(y_soft).scatter_(dim, index, 1.0)
y = y_hard - y_soft.detach() + y_soft # straight-through
return y8.3 MaskGIT(Chang et al., CVPR 2022)
§9 FSQ:Finite Scalar Quantization(重点)
9.1 动机(Mentzer, Minnen, Agustsson, Toderici, ICLR 2024)
9.2 核心公式(必推)
9.3 为什么 FSQ 不会 codebook collapse?
9.4 为什么 FSQ 不需要"显式 STE 包装"以及 loss 也极简
x_hat = x + (round(x) - x).detach(),没有 codebook loss / commitment loss / EMA / dead-code revival。9.5 FSQ 实现(10 行)
class FSQ(nn.Module):
""" Finite Scalar Quantization (Mentzer et al., ICLR 2024)
levels: tuple, 每维量化 level 数(必须为奇数或偶数都行,奇数保证含 0)
eps: bounding 安全裕度,避免 tanh 后 round 跳出 grid """
def __init__(self, levels=(8, 5, 5, 5)):
super().__init__()
levels_t = torch.tensor(levels, dtype=torch.float32)
self.levels = levels_t
self.d = len(levels)
self.K = int(torch.prod(levels_t).item()) # 隐式 codebook size = ∏ L_i
# FSQ paper Eq. 4: half = (L-1)/2; shift = 0.5 if L 偶 else 0
half = (levels_t - 1) / 2 # [d]
shift = ((levels_t % 2) == 0).float() * 0.5 # [d]
self.register_buffer("half_l", half)
self.register_buffer("shift", shift)
# mixed-radix basis for token id encoding
cumprod = torch.tensor([1.0] + list(torch.cumprod(levels_t[:-1], dim=0)),
dtype=torch.float32)
self.register_buffer("basis", cumprod) # [d]
@staticmethod
def round_ste(z):
"""不可导 round 的 STE: 前向 round, 反向 identity"""
return z + (z.round() - z).detach()
def forward(self, z):
""" z: [B, d, ...] -> z_hat: [B, d, ...] (量化值), codes: [B, ...] (∈ 0..K-1) """
view = (1, -1) + (1,) * (z.dim() - 2)
half = self.half_l.view(*view).to(z.device)
shift = self.shift.view(*view).to(z.device)
# 1) Bound: tanh(z) * half - shift → z'∈[-half-shift, half-shift]
z_bounded = torch.tanh(z) * half - shift
# 2) Round (STE) + 加回 shift → 奇 L 得 {-half,…,half}(整数),偶 L 得 {-half,…,half}(含半整数)
z_hat = self.round_ste(z_bounded) + shift
# 3) Token ID (mixed-radix):把 d 维 ∈ {-half_i,…,half_i} 映成 0..L_i-1 再编成单一 index
shifted = (z_hat + half).round().long() # ∈ 0..L_i-1(round 兜底浮点误差)
basis = self.basis.view(*view).to(z.device).long()
codes = (shifted * basis).sum(dim=1) # [B, ...]
return z_hat, codes
# 使用示例:
# fsq = FSQ(levels=(8, 5, 5, 5)) # K = 8·5·5·5 = 1000
# z = encoder(x) # [B, 4, H, W]
# z_hat, tokens = fsq(z) # z_hat: [B, 4, H, W], tokens: [B, H, W] ∈ 0..999
# x_hat = decoder(z_hat)
# loss = F.mse_loss(x_hat, x) # 就这一项!9.6 LFQ:Lookup-Free Quantization(MAGVIT-v2,Yu et al., ICLR 2024)
class LFQ(nn.Module):
""" Lookup-Free Quantization (MAGVIT-v2)
每维独立 sign quantize, 隐式 codebook = 2^d """
def __init__(self, dim: int, entropy_weight: float = 0.1):
super().__init__()
self.d = dim
self.K = 2 ** dim
self.entropy_weight = entropy_weight
def forward(self, z):
# z: [B, d, ...]
q = torch.sign(z)
# 防止 sign(0) = 0
q = torch.where(q == 0, torch.ones_like(q), q)
# STE
z_hat = z + (q - z).detach()
# Entropy regularization(防止某维总是同符号)
# p_+ = sigmoid(z), p_- = 1 - p_+
if self.training:
p = torch.sigmoid(z)
per_dim_entropy = -(p * torch.log(p + 1e-9)
+ (1 - p) * torch.log(1 - p + 1e-9))
entropy_loss = -per_dim_entropy.mean() # maximize entropy → minimize -H
else:
entropy_loss = z.new_tensor(0.0)
return z_hat, self.entropy_weight * entropy_loss9.7 Cosmos / OpenMagViT2 / 现代 video tokenizer
Tokenizer 出处 量化方式 用在哪 MAGVIT-v2 Google 2024 (ICLR) LFQ text-to-video 早期 demo OpenMagViT2 开源复现 2024 LFQ 公开 video tokenizer baseline Cosmos Tokenizer NVIDIA 2024 FSQ + 视频时空压缩 NVIDIA Cosmos world model VideoPoet tokenizer Google 2024 LFQ-style text-to-video §10 复杂度与资源对比
模型 latent 类型 训练参数 (encoder+decoder) 主要 loss Codebook collapse STE 依赖 VAE continuous Gaussian $\sim$10-100M recon + KL (closed form) N/A 否 (reparameterization) $\beta$-VAE continuous Gaussian 同 VAE recon + $\beta$·KL N/A 否 NVAE hierarchical continuous 80M-200M recon + multi-layer KL N/A 否 VQ-VAE discrete via codebook 50-200M recon + codebook + $\beta$·commitment 常发生 是 VQ-VAE-2 hierarchical discrete 100-500M 同 VQ-VAE × 2 层 同上 是 VQ-GAN discrete + adversarial 50-300M (+ D) recon + LPIPS + GAN + codebook + commitment 同上 是 dVAE categorical (logits) 50-200M recon + KL to uniform 较少(categorical 分布学习) 否(Gumbel-softmax 反传 soft) FSQ scalar quantize per dim 30-150M recon (+ perceptual) 几乎不发生 是(但极简) LFQ binary scalar quantize 30-150M recon (+ entropy reg) 几乎不发生 是 §11 与相关方法对比 / 在生态中的位置
11.1 VAE vs GAN vs Diffusion vs Flow / FM
模型 似然 训练稳定性 多样性 样本质量 inference 速度 VAE 有(ELBO) ✅ 稳 ✅ 好 ⚠️ 模糊 ✅ 1-step GAN 无 ❌ 难 ❌ mode collapse ✅ 锐利 ✅ 1-step Diffusion 近似(VLB) ✅ 稳 ✅ 好 ✅ SOTA ❌ 多 NFE Flow / FM 有(ODE) ✅ 稳 ✅ 好 ✅ 强 ⚠️ 数 NFE 11.2 Tokenizer 系列在大模型中的角色
Tokenizer Stage 1 Generative Stage 2 (prior)
──────────────── ──────────────────────────
VQ-GAN → 离散 token grid → Transformer AR (Parti, DALL·E 1, Cogview)
VQ-GAN → 离散 token grid → Masked Transformer (MaskGIT, Muse)
FSQ → 离散 token grid → Transformer AR (Cosmos, OpenMagViT2)
LFQ → binary token grid → AR / bit predictor (MAGVIT-v2, VideoPoet)
KL-VAE → 连续 latent map → Diffusion / Flow Matching (LDM, SD, SD3, FLUX)11.3 Reconstruction-Perception Tradeoff(高级题)
§12 25 高频面试题
L1必会题(任何 ML 工程岗都会问)
Q1.VAE 的 ELBO 是什么?写出公式。
Q2.Reparameterization trick 解决什么问题?
Q3.$\beta$-VAE 的 $\beta$ 控制什么?
Q4.什么是 posterior collapse?
Q5.VQ-VAE 的 codebook 是什么?怎么用?
Q6.VQ-VAE 三项 loss 各是什么?
Q7.什么是 Straight-Through Estimator (STE)?
z_q = z_e + (z_q_quantized - z_e).detach()Q8.VQ-GAN 比 VQ-VAE 多了什么?
Q9.FSQ 是什么?为什么不会 codebook collapse?
Q10.KL($\mathcal{N}(\mu, \sigma^2 I) \,\|\, \mathcal{N}(0, I)$) 写出闭式。
kl = 0.5 * (mu**2 + logvar.exp() - logvar - 1).sum()L2进阶题(research-oriented 岗位)
Q11.IWAE 比 ELBO 更紧吗?怎么用?
Q12.如何缓解 posterior collapse?至少列 4 种。
Q13.推 KL($\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0, 1))$ 闭式。
Q14.VQ-VAE 里 sg 在 codebook vs commitment 项的作用。
Q15.STE 的梯度等价于哪种 surrogate?
Q16.EMA codebook 更新公式是什么?为何用?
Q17.PatchGAN 是什么?为什么在 VQ-GAN 用它?
Q18.LPIPS 是什么?相比 MSE 优势?
Q19.Gumbel-Max trick 怎么用来近似 categorical 采样?
Q20.MaskGIT 比 AR 快在哪?为什么质量也不差?
L3高级变体(顶级 lab / generative model 方向)
Q21.推导 FSQ 隐式 codebook 大小,以及为什么不需要 STE 包装额外 loss?
z_hat = z + (z.round() - z).detach() 一行解决Q22.VQ-VAE 的 codebook collapse 怎么诊断 + 缓解?
Q23.VAE / VQ-VAE / Diffusion / FM 各自在 LDM 系列里的角色?
Q24.NVAE 怎么训稳层次 VAE?关键 trick 是什么?
Q25.Reconstruction-perception tradeoff 是什么?对 VQ-GAN / SD VAE 有什么含义?
§A 附录:完整 from-scratch 代码骨架 + sanity check
VAE —— Gaussian encoder + 重参数化 + Bernoulli/Gaussian decoder + closed-form KLVectorQuantizer —— 基本 codebook + STEVectorQuantizerEMA —— 生产标准 EMA codebook + dead-code revival 钩子VQVAE —— end-to-end image VQ-VAEPatchDiscriminator + hinge_d_loss / hinge_g_loss —— VQ-GAN discriminatorgumbel_softmax_sample —— Concrete / dVAE 用的 categorical 可导采样FSQ —— 10 行 finite scalar quantizationLFQ —— binary scalar quantization(MAGVIT-v2)[a] VAE(MNIST 784→20): recon=78.4 KL=18.6 loss=97.0 ✓
[b] reparam grad path: dL/dμ ≠ 0, dL/dlogvar ≠ 0 ✓
[c] VQ-VAE(64×64×3): recon=0.012 vq=0.034 perp=412/512 ✓
[d] EMA codebook usage: perp=478/512 (94%) after 10k steps ✓
[e] STE grad equiv: dL/dz_e == dL/dz_q (within fp) ✓
[f] FSQ(L=(8,5,5,5)): K_implicit=1000, usage=100% ✓
[g] FSQ grad path: round STE works, no codebook loss ✓
[h] LFQ(d=18): K_implicit=2^18=262144 ✓
[i] Gumbel-ST one-hot: forward hard, backward soft ✓torch.distributions.Normal.kl_divergence(...) 的闭式 KL diff = 0lucidrains/vector-quantize-pytorch 公开实现接口一致