Generative Models: GAN, VAE, Diffusion

Overview of Generative Models

Generative model taxonomy:
├── Likelihood-based
│   ├── VAE (Variational Autoencoder)
│   ├── Flow-based Models
│   └── Diffusion Models (optimize a variational bound / score matching)

├── Autoregressive
│   └── PixelCNN, WaveNet, GPT

└── Implicit (no explicit density)
    └── GAN (Generative Adversarial Network)

VAE (Variational Autoencoder)

Core Idea

Encoder: x → z_mean, z_logvar

Sampling: z = mean + std * epsilon (the reparameterization trick)

Decoder: z → x̂

Goal: make x̂ as close to x as possible while keeping z close to a standard normal distribution.

Mathematical Derivation
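
The loss below follows from the evidence lower bound (ELBO):

log p(x) ≥ E_{q(z|x)}[log p(x|z)] − KL(q(z|x) ‖ p(z))

Maximizing the ELBO is equivalent to minimizing the reconstruction loss plus the KL divergence.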

python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Loss = reconstruction loss + KL divergence
class VAE(nn.Module):
    def __init__(self, input_dim, latent_dim):
        super().__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128)
        )
        self.fc_mean = nn.Linear(128, latent_dim)
        self.fc_logvar = nn.Linear(128, latent_dim)

        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, input_dim)
        )

    def reparameterize(self, mean, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mean + eps * std

    def forward(self, x):
        # Encode
        h = self.encoder(x)
        mean = self.fc_mean(h)
        logvar = self.fc_logvar(h)

        # Sample via the reparameterization trick
        z = self.reparameterize(mean, logvar)

        # Decode
        x_recon = self.decoder(z)

        return x_recon, mean, logvar

    def loss_function(self, x, x_recon, mean, logvar):
        # Reconstruction loss (MSE or BCE)
        recon_loss = F.mse_loss(x_recon, x, reduction='sum')

        # KL divergence: KL(N(μ, σ²) || N(0, I))
        kl_loss = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())

        return recon_loss + kl_loss
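
A minimal training-step sketch for the VAE above (the 784-dim input, Adam settings, and dataloader are illustrative assumptions, not from the source):

python
model = VAE(input_dim=784, latent_dim=32)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

for x, _ in dataloader:  # e.g., flattened MNIST batches
    x = x.view(x.size(0), -1)
    x_recon, mean, logvar = model(x)
    loss = model.loss_function(x, x_recon, mean, logvar)
    opt.zero_grad()
    loss.backward()
    opt.step()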

VAE Characteristics

  • Explicit likelihood (optimizes a lower bound)
  • Learned, structured latent space
  • Smooth generation (interpolates well)
  • But generated images are often blurry (per-pixel loss)

GAN (Generative Adversarial Network)

Core Idea

Adversarial training:
- Generator: noise z → fake image G(z)
- Discriminator: image x → real/fake (1/0)
- Goal: G tries to fool D; D tries to tell real from fake

Game theory: a zero-sum (minimax) game
min_G max_D L(D, G) = E_x[log D(x)] + E_z[log(1 − D(G(z)))]

Classic GAN (DCGAN-style)

python
class Generator(nn.Module):
    def __init__(self, latent_dim, img_channels, features_g):
        super().__init__()
        self.net = nn.Sequential(
            # Input: z of shape (latent_dim, 1, 1)
            nn.ConvTranspose2d(latent_dim, features_g*16, 4, 1, 0),
            nn.BatchNorm2d(features_g*16),
            nn.ReLU(),
            # (features_g*16) x 4 x 4
            nn.ConvTranspose2d(features_g*16, features_g*8, 4, 2, 1),
            nn.BatchNorm2d(features_g*8),
            nn.ReLU(),
            # (features_g*8) x 8 x 8
            nn.ConvTranspose2d(features_g*8, features_g*4, 4, 2, 1),
            nn.BatchNorm2d(features_g*4),
            nn.ReLU(),
            # (features_g*4) x 16 x 16
            nn.ConvTranspose2d(features_g*4, features_g*2, 4, 2, 1),
            nn.BatchNorm2d(features_g*2),
            nn.ReLU(),
            # (features_g*2) x 32 x 32
            nn.ConvTranspose2d(features_g*2, img_channels, 4, 2, 1),
            nn.Tanh()  # output in [-1, 1]
        )

    def forward(self, z):
        return self.net(z)


class Discriminator(nn.Module):
    def __init__(self, img_channels, features_d):
        super().__init__()
        self.net = nn.Sequential(
            # Input: image of shape (img_channels, 64, 64)
            nn.Conv2d(img_channels, features_d, 4, 2, 1),
            nn.LeakyReLU(0.2),
            # features_d x 32 x 32
            nn.Conv2d(features_d, features_d*2, 4, 2, 1),
            nn.BatchNorm2d(features_d*2),
            nn.LeakyReLU(0.2),
            # two more stride-2 blocks, filled in from the pattern above
            nn.Conv2d(features_d*2, features_d*4, 4, 2, 1),
            nn.BatchNorm2d(features_d*4),
            nn.LeakyReLU(0.2),
            nn.Conv2d(features_d*4, features_d*8, 4, 2, 1),
            nn.BatchNorm2d(features_d*8),
            nn.LeakyReLU(0.2),
            # (features_d*8) x 4 x 4
            nn.Conv2d(features_d*8, 1, 4, 1, 0),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.net(x).view(-1, 1).squeeze(1)

Training Loop

python
def train_gan(G, D, dataloader, epochs, latent_dim, device):
    criterion = nn.BCELoss()
    opt_G = torch.optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
    opt_D = torch.optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))

    for epoch in range(epochs):
        for real_images, _ in dataloader:
            batch_size = real_images.size(0)
            real_images = real_images.to(device)

            real_labels = torch.ones(batch_size).to(device)
            fake_labels = torch.zeros(batch_size).to(device)

            # Train the discriminator
            D.zero_grad()
            real_output = D(real_images)
            real_loss = criterion(real_output, real_labels)

            # z must be 4-D for the ConvTranspose2d generator
            z = torch.randn(batch_size, latent_dim, 1, 1).to(device)
            fake_images = G(z)
            fake_output = D(fake_images.detach())  # detach: don't backprop into G here
            fake_loss = criterion(fake_output, fake_labels)

            d_loss = real_loss + fake_loss
            d_loss.backward()
            opt_D.step()

            # Train the generator
            G.zero_grad()
            z = torch.randn(batch_size, latent_dim, 1, 1).to(device)
            fake_images = G(z)
            fake_output = D(fake_images)
            g_loss = criterion(fake_output, real_labels)  # non-saturating loss
            g_loss.backward()
            opt_G.step()

GAN Training Problems

Problem | Symptom | Mitigation
Mode collapse | Poor sample diversity | Mini-batch discrimination, Unrolled GAN
Unstable training | Oscillation, exploding gradients | Label smoothing, TTUR
Vanishing gradients | D becomes too strong | Alternative losses (WGAN), spectral normalization
Non-convergence | Oscillation | Alternate G/D updates, learning-rate schedules

A one-line example of label smoothing in the training loop above is shown below.
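
One-sided label smoothing replaces the hard real labels fed to the discriminator (the 0.9 target is a common choice, an assumption rather than something fixed by the source):

python
# Soften the real labels for D; fake labels stay at 0
real_labels = torch.full((batch_size,), 0.9, device=device)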

GAN Variants

Model | Improvement
DCGAN | Deep convolutions + BatchNorm
WGAN | Wasserstein distance instead of JS divergence
WGAN-GP | Gradient penalty for stable training
CGAN | Conditional generation
StyleGAN | Progressive growing + style-based generator
BigGAN | Large batches + class conditioning
SAGAN | Self-attention
BigBiGAN | Adds an encoder
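
A minimal sketch of the WGAN-GP gradient penalty from the table, assuming a critic D that outputs unnormalized scores (no Sigmoid):

python
def gradient_penalty(D, real, fake, device):
    """Penalize the critic's gradient norm on samples interpolated between real and fake."""
    batch_size = real.size(0)
    eps = torch.rand(batch_size, 1, 1, 1, device=device)
    interp = (eps * real + (1 - eps) * fake).requires_grad_(True)
    scores = D(interp)
    grads = torch.autograd.grad(
        outputs=scores, inputs=interp,
        grad_outputs=torch.ones_like(scores),
        create_graph=True
    )[0]
    grads = grads.view(batch_size, -1)
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()

# Critic loss sketch: fake_scores.mean() - real_scores.mean() + 10 * gradient_penalty(...)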

Diffusion Models

Forward Process (Adding Noise)

x₀ (real image) → x₁ → x₂ → ... → xₜ → ... → x_T (pure noise)

q(xₜ|xₜ₋₁) = N(xₜ; √(1−βₜ) xₜ₋₁, βₜI)

In closed form, with αₜ = 1 − βₜ and ᾱₜ = ∏ₛ₌₁ᵗ αₛ: q(xₜ|x₀) = N(xₜ; √ᾱₜ x₀, (1−ᾱₜ)I)

β₁, β₂, ..., β_T: the noise schedule (typically β₁ = 10⁻⁴, β_T = 0.02)
python
def forward_diffusion(x0, t, betas):
    """Noise an image directly to timestep t via the closed form q(x_t | x_0)."""
    noise = torch.randn_like(x0)
    alphas = 1 - betas
    alpha_bar = torch.cumprod(alphas, dim=0)

    # Reshape per-sample coefficients to broadcast over (batch, C, H, W)
    sqrt_alpha_bar = torch.sqrt(alpha_bar[t]).view(-1, 1, 1, 1)
    sqrt_one_minus_alpha_bar = torch.sqrt(1 - alpha_bar[t]).view(-1, 1, 1, 1)

    x_t = sqrt_alpha_bar * x0 + sqrt_one_minus_alpha_bar * noise
    return x_t, noise
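
Usage sketch with the linear schedule values quoted above (the batch shape is an arbitrary stand-in):

python
betas = torch.linspace(1e-4, 0.02, 1000)
x0 = torch.randn(8, 3, 64, 64)      # stand-in image batch in [-1, 1]
t = torch.randint(0, 1000, (8,))    # one random timestep per sample
x_t, noise = forward_diffusion(x0, t, betas)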

Reverse Process (Denoising)

x_T (pure noise) → x_{T-1} → ... → x₁ → x₀ (generated image)

p_θ(x_{t-1}|xₜ) = N(x_{t-1}; μ_θ(xₜ, t), Σ_θ(xₜ, t))

A neural network predicts the noise ε_θ(xₜ, t).
python
class UNet(nn.Module):
    """U-Net that predicts the noise."""
    def forward(self, x_t, t):
        # x_t: noised image
        # t: timestep
        # returns: predicted noise ε_θ(x_t, t)
        ...


def p_sample(model, x_t, t, betas):
    """Sample x_{t-1} given x_t (t is a Python int here)."""
    # Predict the noise
    eps = model(x_t, torch.tensor([t], device=x_t.device))

    # Compute the posterior mean
    alphas = 1 - betas
    alpha_bar = torch.cumprod(alphas, dim=0)
    mean = (1 / torch.sqrt(alphas[t])) * (
        x_t - betas[t] / torch.sqrt(1 - alpha_bar[t]) * eps
    )

    # Add noise (except at the final step)
    if t > 0:
        noise = torch.randn_like(x_t)
        sigma = torch.sqrt(betas[t])
        return mean + sigma * noise
    return mean


def reverse_process(model, x_T, betas):
    """Full reverse process: pure noise x_T → sample x_0."""
    x_t = x_T
    for t in reversed(range(len(betas))):
        x_t = p_sample(model, x_t, t, betas)
    return x_t  # x_0

Training Objective
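
In symbols, the simplified DDPM objective that the code below implements is:

L_simple = E_{t, x₀, ε} ‖ε − ε_θ(xₜ, t)‖²

where t is uniform over {1, ..., T}, ε ~ N(0, I), and xₜ comes from the closed-form forward process.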

python
def diffusion_loss(model, x0, betas):
    """DDPM training objective: predict the noise."""
    batch_size = x0.size(0)

    # Sample a random timestep per example
    t = torch.randint(0, len(betas), (batch_size,)).to(x0.device)

    # Add noise
    x_t, noise = forward_diffusion(x0, t, betas)

    # Predict the noise
    pred_noise = model(x_t, t)

    # MSE loss
    loss = F.mse_loss(pred_noise, noise)

    return loss
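
A training-loop sketch around this loss (the optimizer settings and dataloader are assumptions, not from the source):

python
betas = torch.linspace(1e-4, 0.02, 1000)
model = UNet()
opt = torch.optim.Adam(model.parameters(), lr=2e-4)

for x0, _ in dataloader:
    loss = diffusion_loss(model, x0, betas)
    opt.zero_grad()
    loss.backward()
    opt.step()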

Noise Schedules

python
import math

# Linear schedule (DDPM)
betas = torch.linspace(beta_start, beta_end, T)

# Cosine schedule (improved DDPM)
def cosine_beta_schedule(T, s=0.008):
    steps = T + 1
    x = torch.linspace(0, T, steps)
    alpha_bar = torch.cos((x / T + s) / (1 + s) * math.pi / 2) ** 2
    alpha_bar = alpha_bar / alpha_bar[0]  # normalize so ᾱ₀ = 1
    betas = 1 - alpha_bar[1:] / alpha_bar[:-1]
    return betas.clamp(0.0001, 0.999)  # clip extreme values

# Stable Diffusion uses PLMS / DDIM samplers to speed up sampling
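
For illustration, one deterministic DDIM step (η = 0): predict x₀ from the current noise estimate, then jump to an earlier timestep (a sketch, assuming a precomputed alpha_bar tensor):

python
def ddim_step(x_t, eps, t, t_prev, alpha_bar):
    """One deterministic DDIM update from timestep t to t_prev."""
    ab_t = alpha_bar[t]
    ab_prev = alpha_bar[t_prev] if t_prev >= 0 else torch.tensor(1.0)
    # Predict x0 from the noise estimate, then re-noise to t_prev
    x0_pred = (x_t - torch.sqrt(1 - ab_t) * eps) / torch.sqrt(ab_t)
    return torch.sqrt(ab_prev) * x0_pred + torch.sqrt(1 - ab_prev) * eps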

Classifier-Free Guidance

python
# Classifier-free guidance: combine conditional and unconditional predictions
def cfg_sample(model, x_t, t, cond, guidance_scale=7.5):
    # Conditional prediction
    eps_cond = model(x_t, t, cond)

    # Unconditional prediction (cond=None)
    eps_uncond = model(x_t, t, None)

    # Guidance: push the prediction toward the conditional direction
    eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond)

    return eps
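
This only works if the model can make both predictions, so training randomly drops the condition; a sketch (the 10% drop rate is a common choice, an assumption here):

python
# During training, occasionally replace the condition with None
if torch.rand(1).item() < 0.1:
    cond = None
pred_noise = model(x_t, t, cond)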

Mainstream Diffusion Models

Model | Feature | Use
DDPM | Base diffusion model | Theoretical foundation
DDIM | Accelerated sampling | Same training, faster sampling
Stable Diffusion | Latent diffusion | Open-source text-to-image
DALL-E 2 | CLIP guidance | Closed-source text-to-image
Imagen | Cascaded super-resolution | Google
SDXL | Larger model | Higher quality
Sora | Video diffusion | Video generation

Latent Diffusion (Stable Diffusion)

python
# Why not diffuse directly in pixel space?
# 512 x 512 x 3 = 786,432 dims
# After VAE compression: 64 x 64 x 4 = 16,384 dims (48x compression!)

class LatentDiffusion:
    def __init__(self):
        self.vae = load_vae()            # image ↔ latent
        self.text_encoder = load_clip()  # text → embeddings
        self.unet = UNet()               # diffusion in latent space
        self.scheduler = DDIMScheduler()

    @torch.no_grad()
    def generate(self, prompt, num_steps=50):
        # Encode the text prompt
        text_emb = self.text_encoder(prompt)

        # Random initial latent
        latents = torch.randn(1, 4, 64, 64)

        # Iterative denoising
        for t in reversed(range(num_steps)):
            # Predict the noise
            noise_pred = self.unet(latents, t, text_emb)

            # DDIM sampling step
            latents = self.scheduler.step(noise_pred, t, latents)

        # Decode back to an image
        images = self.vae.decode(latents)
        return images

Generative Model Comparison

Property | VAE | GAN | Diffusion
Training stability | High | Low | High
Mode coverage | Good | Poor (mode collapse) | Good
Sample quality | Blurry | Sharp | Sharp
Sampling speed | Fast | Fast | Slow (but DDIM accelerates)
Theoretical basis | Likelihood (ELBO) | Adversarial game | Score matching
Latent space | Smooth, interpretable | Yes, less structured | Not inherent (added in latent diffusion)
Typical use | Interpolation, interpretability | High-quality generation | Text-to-image / video

Evaluation Metrics

python
# Image generation: FID (the imports below are illustrative helper modules;
# e.g., the pytorch-fid package provides an equivalent computation)
from fid_score import calculate_fid
fid = calculate_fid(real_images, generated_images)

# Inception Score (illustrative import)
from inception_score import inception_score
is_score, _ = inception_score(generated_images)

# Others:
# - LPIPS (perceptual similarity)
# - CLIP Score (text-image consistency)
# - Precision/Recall (mode coverage)
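
FID compares Gaussian fits of Inception features of real and generated images; a self-contained sketch given precomputed feature matrices (how the features were extracted is assumed, not from the source):

python
import numpy as np
from scipy import linalg

def fid_from_features(feat_real, feat_gen):
    """FID = ||mu_r - mu_g||^2 + Tr(C_r + C_g - 2 (C_r C_g)^{1/2})."""
    mu_r, mu_g = feat_real.mean(axis=0), feat_gen.mean(axis=0)
    cov_r = np.cov(feat_real, rowvar=False)
    cov_g = np.cov(feat_gen, rowvar=False)
    covmean = linalg.sqrtm(cov_r @ cov_g)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # drop numerical imaginary parts
    diff = mu_r - mu_g
    return float(diff @ diff + np.trace(cov_r + cov_g - 2 * covmean))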

Interview Highlights

1. GAN vs Diffusion
   - GAN: single-step generation; fast, but training is unstable
   - Diffusion: multi-step generation; slow, but training is stable and quality is high

2. Why diffusion models are popular now
   - Stable training (gradual noising/denoising)
   - Good mode coverage (no mode collapse)
   - High quality (especially combined with CLIP text conditioning)

3. DDPM training objective
   - Predict the noise ε_θ(xₜ, t)
   - Simple MSE loss

4. How classifier-free guidance works
   - Combines conditional and unconditional predictions
   - Amplifies the conditioning signal

5. Why VAE outputs are blurry
   - Per-pixel MSE loss averages over plausible outputs
   - Imperfect latent space
   - Fixes: adversarial training, perceptual losses

6. GAN training tricks
   - Label smoothing
   - TTUR (different learning rates for G and D)
   - Spectral normalization
   - Feature matching
