Generative Models: GAN, VAE, Diffusion
Overview of generative models
A taxonomy of generative models:
├── Likelihood-based
│   ├── VAE (Variational Autoencoder)
│   └── Flow-based Models
│
├── Autoregressive
│   └── PixelCNN, WaveNet, GPT
│
└── Implicit
    ├── GAN (Generative Adversarial Network)
    └── Diffusion models
VAE (Variational Autoencoder)
Core idea
Encoder: x → z_mean, z_logvar
↓
Sampling: z = mean + std * epsilon
↓
Decoder: z → x̂
Objective: make x̂ as close to x as possible while keeping q(z|x) close to the standard normal prior.
Mathematical derivation
Maximizing the evidence lower bound (ELBO), log p(x) ≥ E_{q(z|x)}[log p(x|z)] − KL(q(z|x) ‖ N(0, I)), is equivalent to minimizing reconstruction loss + KL divergence, which the code below implements.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Loss = reconstruction loss + KL divergence
class VAE(nn.Module):
    def __init__(self, input_dim, latent_dim):
        super().__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128)
        )
        self.fc_mean = nn.Linear(128, latent_dim)
        self.fc_logvar = nn.Linear(128, latent_dim)
        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, input_dim)
        )

    def reparameterize(self, mean, logvar):
        # Reparameterization trick: z = mean + std * eps, eps ~ N(0, I)
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mean + eps * std

    def forward(self, x):
        # Encode
        h = self.encoder(x)
        mean = self.fc_mean(h)
        logvar = self.fc_logvar(h)
        # Sample via reparameterization
        z = self.reparameterize(mean, logvar)
        # Decode
        x_recon = self.decoder(z)
        return x_recon, mean, logvar

    def loss_function(self, x, x_recon, mean, logvar):
        # Reconstruction loss (MSE or BCE)
        recon_loss = F.mse_loss(x_recon, x, reduction='sum')
        # KL divergence: KL(N(mean, std) || N(0, I))
        kl_loss = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
        return recon_loss + kl_loss
```
VAE characteristics
- Explicit likelihood objective (trains a likelihood lower bound, the ELBO)
- Learned, structured latent space
- Smooth generation and good latent interpolation (see the sketch below)
- But samples are often blurry (per-pixel reconstruction loss)
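A minimal sketch of latent-space interpolation with the `VAE` class above, to make the "good interpolation" point concrete (assumes a trained `vae` instance and two input batches `x1`, `x2`; all names are illustrative):

```python
import torch

@torch.no_grad()
def interpolate(vae, x1, x2, steps=8):
    """Decode points on the line between the latent means of x1 and x2."""
    h1, h2 = vae.encoder(x1), vae.encoder(x2)
    z1, z2 = vae.fc_mean(h1), vae.fc_mean(h2)   # use the means, no sampling
    outputs = []
    for alpha in torch.linspace(0, 1, steps):
        z = (1 - alpha) * z1 + alpha * z2        # linear interpolation in latent space
        outputs.append(vae.decoder(z))
    return torch.stack(outputs)                  # (steps, batch, input_dim)
```

Because the KL term pulls the posterior toward N(0, I), points between two encoded latents usually decode to plausible intermediate samples.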
GAN (Generative Adversarial Network)
Core idea
Adversarial training:
- Generator: noise z → fake image G(z)
- Discriminator: image x → real/fake (1 = real, 0 = fake)
- Objective: G tries to fool D, D tries to tell real from fake
Game theory: a two-player minimax (zero-sum) game
min_G max_D V(D, G) = E_{x~p_data}[log D(x)] + E_{z~p(z)}[log(1 − D(G(z)))]
A classic GAN (DCGAN-style)
```python
import torch
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, latent_dim, img_channels, features_g):
        super().__init__()
        self.net = nn.Sequential(
            # Input: z of shape (latent_dim, 1, 1)
            nn.ConvTranspose2d(latent_dim, features_g * 16, 4, 1, 0),
            nn.BatchNorm2d(features_g * 16),
            nn.ReLU(),
            # (features_g*16) x 4 x 4
            nn.ConvTranspose2d(features_g * 16, features_g * 8, 4, 2, 1),
            nn.BatchNorm2d(features_g * 8),
            nn.ReLU(),
            # (features_g*8) x 8 x 8
            nn.ConvTranspose2d(features_g * 8, features_g * 4, 4, 2, 1),
            nn.BatchNorm2d(features_g * 4),
            nn.ReLU(),
            # (features_g*4) x 16 x 16
            nn.ConvTranspose2d(features_g * 4, features_g * 2, 4, 2, 1),
            nn.BatchNorm2d(features_g * 2),
            nn.ReLU(),
            # (features_g*2) x 32 x 32
            nn.ConvTranspose2d(features_g * 2, img_channels, 4, 2, 1),
            nn.Tanh()  # img_channels x 64 x 64, output in [-1, 1]
        )

    def forward(self, z):
        return self.net(z)

class Discriminator(nn.Module):
    def __init__(self, img_channels, features_d):
        super().__init__()
        self.net = nn.Sequential(
            # Input: image of shape (img_channels, 64, 64)
            nn.Conv2d(img_channels, features_d, 4, 2, 1),
            nn.LeakyReLU(0.2),
            # features_d x 32 x 32
            nn.Conv2d(features_d, features_d * 2, 4, 2, 1),
            nn.BatchNorm2d(features_d * 2),
            nn.LeakyReLU(0.2),
            # (features_d*2) x 16 x 16
            nn.Conv2d(features_d * 2, features_d * 4, 4, 2, 1),
            nn.BatchNorm2d(features_d * 4),
            nn.LeakyReLU(0.2),
            # (features_d*4) x 8 x 8
            nn.Conv2d(features_d * 4, features_d * 8, 4, 2, 1),
            nn.BatchNorm2d(features_d * 8),
            nn.LeakyReLU(0.2),
            # (features_d*8) x 4 x 4
            nn.Conv2d(features_d * 8, 1, 4, 1, 0),
            nn.Sigmoid()  # probability that the input is real
        )

    def forward(self, x):
        return self.net(x).view(-1, 1).squeeze(1)
```
Training loop
```python
def train_gan(G, D, dataloader, epochs, latent_dim, device):
    criterion = nn.BCELoss()
    opt_G = torch.optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
    opt_D = torch.optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))
    for epoch in range(epochs):
        for real_images, _ in dataloader:
            batch_size = real_images.size(0)
            real_images = real_images.to(device)
            real_labels = torch.ones(batch_size).to(device)
            fake_labels = torch.zeros(batch_size).to(device)

            # Train the discriminator on real and fake batches
            D.zero_grad()
            real_output = D(real_images)
            real_loss = criterion(real_output, real_labels)
            z = torch.randn(batch_size, latent_dim, 1, 1).to(device)  # shaped for the conv generator above
            fake_images = G(z)
            fake_output = D(fake_images.detach())  # detach: no gradient flows into G here
            fake_loss = criterion(fake_output, fake_labels)
            d_loss = real_loss + fake_loss
            d_loss.backward()
            opt_D.step()

            # Train the generator: make D classify fakes as real
            G.zero_grad()
            z = torch.randn(batch_size, latent_dim, 1, 1).to(device)
            fake_images = G(z)
            fake_output = D(fake_images)
            g_loss = criterion(fake_output, real_labels)
            g_loss.backward()
            opt_G.step()
```
Common GAN training problems
| Problem | Symptom | Mitigations |
|---|---|---|
| Mode collapse | Low sample diversity | Mini-batch discrimination, Unrolled GAN |
| Unstable training | Oscillation, exploding gradients | Label smoothing (see sketch after this table), TTUR |
| Vanishing gradients | D becomes too strong | Alternative losses (WGAN), spectral normalization |
| Non-convergence | Oscillation | Separate G/D update schedules, learning-rate scheduling |
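A minimal sketch of two of these stabilization tricks, one-sided label smoothing and spectral normalization, as they would plug into the training loop and discriminator above (the 0.9 target is a common but illustrative choice):

```python
import torch
import torch.nn as nn

# One-sided label smoothing: use 0.9 instead of 1.0 as the "real" target
# when training the discriminator, so D does not become overconfident.
def smoothed_real_labels(batch_size, device, smooth=0.9):
    return torch.full((batch_size,), smooth, device=device)

# Spectral normalization (PyTorch built-in): wrap each discriminator conv so its
# largest singular value is constrained to 1, bounding the Lipschitz constant.
def spectral_conv(in_ch, out_ch, kernel, stride, padding):
    return nn.utils.spectral_norm(nn.Conv2d(in_ch, out_ch, kernel, stride, padding))
```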
GAN variants
| Model | Improvement |
|---|---|
| DCGAN | Deep convolutional architecture + BatchNorm |
| WGAN | Wasserstein distance instead of the JS divergence |
| WGAN-GP | Gradient penalty for stable training (see sketch after this table) |
| CGAN | Conditional generation |
| StyleGAN | Progressive growing + style-based generator |
| BigGAN | Large batches + class conditioning |
| SAGAN | Self-attention |
| BigBiGAN | Adds an encoder (bidirectional) |
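A minimal sketch of the WGAN-GP gradient penalty, assuming a critic `D` without the final Sigmoid; the penalty pushes the critic's gradient norm toward 1 on random interpolates between real and fake samples (`lambda_gp=10` follows the original paper):

```python
import torch

def gradient_penalty(D, real, fake, device, lambda_gp=10.0):
    """WGAN-GP: lambda * E[(||grad_x D(x_hat)||_2 - 1)^2] on interpolates x_hat."""
    batch_size = real.size(0)
    eps = torch.rand(batch_size, 1, 1, 1, device=device)          # per-sample mixing weight
    x_hat = (eps * real + (1 - eps) * fake).requires_grad_(True)
    d_out = D(x_hat)
    grads = torch.autograd.grad(
        outputs=d_out, inputs=x_hat,
        grad_outputs=torch.ones_like(d_out),
        create_graph=True, retain_graph=True
    )[0]
    grads = grads.view(batch_size, -1)
    return lambda_gp * ((grads.norm(2, dim=1) - 1) ** 2).mean()
```

The penalty is added to the critic loss in place of WGAN's weight clipping.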
Diffusion Model
Forward process (adding noise)
x₀ (real image) → x₁ → x₂ → ... → xₜ → ... → x_T (pure noise)
q(xₜ | xₜ₋₁) = N(xₜ; √(1 − βₜ) xₜ₋₁, βₜ I)
β₁, β₂, ..., β_T: the noise schedule (typically β₁ = 10⁻⁴, β_T = 0.02)
With ᾱₜ = ∏ₛ₌₁ᵗ (1 − βₛ), the closed form q(xₜ | x₀) = N(xₜ; √ᾱₜ x₀, (1 − ᾱₜ) I) lets us sample xₜ directly from x₀:
```python
def forward_diffusion(x0, t, betas):
    """Add noise to x0 to get x_t in one step (t may be a batch of timesteps)."""
    noise = torch.randn_like(x0)
    alphas = 1 - betas
    alpha_bar = torch.cumprod(alphas, dim=0)
    # Reshape per-sample coefficients so they broadcast over (C, H, W)
    sqrt_alpha_bar = torch.sqrt(alpha_bar[t]).view(-1, 1, 1, 1)
    sqrt_one_minus_alpha_bar = torch.sqrt(1 - alpha_bar[t]).view(-1, 1, 1, 1)
    x_t = sqrt_alpha_bar * x0 + sqrt_one_minus_alpha_bar * noise
    return x_t, noise
```
Reverse process (denoising)
x_T (pure noise) → x_{T-1} → ... → x₁ → x₀ (generated image)
p_θ(x_{t-1} | xₜ) = N(μ_θ(xₜ, t), Σ_θ(xₜ, t))
A neural network predicts the noise ε_θ(xₜ, t):
```python
class UNet(nn.Module):
    """U-Net that predicts the noise added at step t."""
    def forward(self, x_t, t):
        # x_t: noisy image, t: timestep
        # returns: predicted noise eps_theta (placeholder body)
        ...

def p_sample(model, x_t, t, betas):
    """Sample x_{t-1} given x_t (one reverse step); t is an integer timestep."""
    alphas = 1 - betas
    alpha_bar = torch.cumprod(alphas, dim=0)
    # Predict the noise
    eps = model(x_t, torch.tensor([t], device=x_t.device))
    # Posterior mean mu_theta (DDPM)
    mean = (1 / torch.sqrt(alphas[t])) * (x_t - betas[t] / torch.sqrt(1 - alpha_bar[t]) * eps)
    # Add noise, except at the final step (t == 0)
    if t > 0:
        noise = torch.randn_like(x_t)
        sigma = torch.sqrt(betas[t])
        x_prev = mean + sigma * noise
    else:
        x_prev = mean
    return x_prev

def reverse_process(model, x_T, betas):
    """Full reverse process: from pure noise x_T back to x_0."""
    x_t = x_T
    for t in reversed(range(len(betas))):
        x_t = p_sample(model, x_t, t, betas)
    return x_t  # x_0
```
Training objective
```python
def diffusion_loss(model, x0, betas):
    """DDPM training objective: predict the added noise."""
    batch_size = x0.size(0)
    # Sample a random timestep for each example
    t = torch.randint(0, len(betas), (batch_size,), device=x0.device)
    # Add noise
    x_t, noise = forward_diffusion(x0, t, betas)
    # Predict the noise
    pred_noise = model(x_t, t)
    # Simple MSE between true and predicted noise
    loss = F.mse_loss(pred_noise, noise)
    return loss
```
Noise schedules
```python
import math

# Linear schedule (DDPM), e.g. beta_start = 1e-4, beta_end = 0.02
betas = torch.linspace(beta_start, beta_end, T)

# Cosine schedule (Improved DDPM)
def cosine_beta_schedule(T, s=0.008):
    steps = T + 1
    x = torch.linspace(0, T, steps)
    alpha_bar = torch.cos((x / T + s) / (1 + s) * math.pi / 2) ** 2
    alpha_bar = alpha_bar / alpha_bar[0]        # normalize so alpha_bar[0] = 1
    betas = 1 - alpha_bar[1:] / alpha_bar[:-1]
    return betas.clamp(0.0001, 0.999)

# Stable Diffusion-style pipelines speed up sampling with PLMS / DDIM samplers
```
Classifier-Free Guidance
```python
# Classifier-free guidance: combine conditional and unconditional predictions
def cfg_sample(model, x_t, t, cond, guidance_scale=7.5):
    # Conditional prediction
    eps_cond = model(x_t, t, cond)
    # Unconditional prediction (cond = None, the null condition)
    eps_uncond = model(x_t, t, None)
    # Guided estimate: push the prediction toward the conditional direction
    eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond)
    return eps
```
Mainstream diffusion models
| Model | Key idea | Use |
|---|---|---|
| DDPM | Basic diffusion model | Foundation / theory |
| DDIM | Accelerated sampling (see sketch after this table) | Same training, much faster sampling |
| Stable Diffusion | Latent diffusion | Open-source text-to-image |
| DALL-E 2 | CLIP image-embedding prior + diffusion decoder | Closed-source text-to-image |
| Imagen | Cascaded super-resolution | Text-to-image |
| SDXL | Larger model | High-quality text-to-image |
| Sora | Video diffusion | Video generation |
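A minimal sketch of one deterministic DDIM update (η = 0), to make the "accelerated sampling" row concrete: it reuses the noise predictor ε_θ from DDPM training but steps over a short subsequence of timesteps instead of all T.

```python
import torch

@torch.no_grad()
def ddim_step(model, x_t, t, t_prev, alpha_bar):
    """One deterministic DDIM step from timestep t down to t_prev (t_prev < t)."""
    eps = model(x_t, torch.tensor([t], device=x_t.device))
    # Predict x_0 from x_t and the predicted noise
    x0_pred = (x_t - torch.sqrt(1 - alpha_bar[t]) * eps) / torch.sqrt(alpha_bar[t])
    # Jump directly to the earlier timestep t_prev (eta = 0: no noise is added)
    a_prev = alpha_bar[t_prev] if t_prev >= 0 else torch.tensor(1.0)
    return torch.sqrt(a_prev) * x0_pred + torch.sqrt(1 - a_prev) * eps
```

Calling this over, say, 50 selected timesteps instead of 1000 is what makes DDIM sampling fast while keeping the DDPM-trained model unchanged.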
Latent Diffusion (Stable Diffusion)
```python
# Why not diffuse directly in pixel space?
# 512 x 512 x 3 = 786,432 dimensions
# After VAE compression: 64 x 64 x 4 = 16,384 dimensions (48x smaller!)

class LatentDiffusion:
    """Pseudocode sketch of the Stable Diffusion pipeline."""
    def __init__(self):
        self.vae = load_vae()              # image <-> latent
        self.text_encoder = load_clip()    # text -> embedding
        self.unet = UNet()                 # denoising in latent space
        self.scheduler = DDIMScheduler()

    @torch.no_grad()
    def generate(self, prompt, num_steps=50):
        # Encode the text prompt
        text_emb = self.text_encoder(prompt)
        # Start from random latents
        latents = torch.randn(1, 4, 64, 64)
        # Iterative denoising
        for t in reversed(range(num_steps)):
            # Predict the noise, conditioned on the text embedding
            noise_pred = self.unet(latents, t, text_emb)
            # DDIM sampling step
            latents = self.scheduler.step(noise_pred, t, latents)
        # Decode latents back to an image
        images = self.vae.decode(latents)
        return images
```
Comparison of generative models
| Property | VAE | GAN | Diffusion |
|---|---|---|---|
| Training stability | Good | Poor | Good |
| Mode coverage | Good | Poor | Good |
| Sample quality | Blurry | Sharp | Sharp |
| Sampling speed | Fast | Fast | Slow (DDIM helps) |
| Theoretical basis | Likelihood (ELBO) | Adversarial game | Score matching / variational bound |
| Latent space | Yes, with encoder | z only, no encoder | None by default (LDM adds a VAE latent) |
| Typical use | Interpolation, interpretability | High-quality generation | Text-to-image / video |
Evaluation metrics
```python
# Image generation quality
# (the module names below are illustrative placeholders, not a specific package;
#  see the torchmetrics sketch after this block for a concrete FID computation)
from fid_score import calculate_fid
fid = calculate_fid(real_images, generated_images)   # lower is better

# Inception Score
from inception_score import inception_score
is_score, _ = inception_score(generated_images)      # higher is better

# Other metrics:
# - LPIPS (perceptual similarity)
# - CLIP Score (text-image alignment)
# - Precision / Recall (fidelity vs. mode coverage)
```
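As a concrete alternative, a minimal FID sketch using torchmetrics (assumes `torchmetrics[image]` is installed; by default it expects uint8 image tensors of shape (N, 3, H, W), and the random tensors below are stand-in data):

```python
import torch
from torchmetrics.image.fid import FrechetInceptionDistance

# FID compares InceptionV3 feature statistics of real vs. generated images
fid = FrechetInceptionDistance(feature=2048)
real_images = torch.randint(0, 255, (32, 3, 64, 64), dtype=torch.uint8)
fake_images = torch.randint(0, 255, (32, 3, 64, 64), dtype=torch.uint8)
fid.update(real_images, real=True)
fid.update(fake_images, real=False)
print(fid.compute())   # lower is better
```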
Interview key points
1. GAN vs. Diffusion
- GAN: single-step generation, fast but unstable to train
- Diffusion: multi-step generation, slow but stable to train and high quality
2. Why diffusion models took off
- Stable training (gradual noising and denoising)
- Good mode coverage (no mode collapse)
- High quality (especially with CLIP-based text conditioning)
3. DDPM training objective
- Predict the noise ε_θ(xₜ, t)
- Simple MSE loss
4. How Classifier-Free Guidance works
- Combine conditional and unconditional predictions
- Amplify the conditioning signal
5. Why VAE samples are blurry
- Per-pixel MSE loss
- Imperfect latent space / posterior
- Remedies: adversarial training, perceptual losses
6. GAN training tricks
- Label smoothing
- TTUR (different learning rates for G and D)
- Spectral normalization
- Feature matching