Skip to content

PyTorch 变分自编码器 (VAE) 实现详细指南

在本教程中,我们将详细介绍如何在 PyTorch 中实现变分自编码器(VAE)。内容将涵盖以下三个部分:

  1. VAE 网络层搭建方法
  2. 前向传播定义方法
  3. 模型训练方法

每个部分后都会有相应的**习题**,帮助你通过练习加深理解。


1. VAE 网络层搭建方法

变分自编码器 (VAE) 由编码器(Encoder)和解码器(Decoder)组成。编码器将输入数据映射到潜在空间的分布参数(均值和对数方差),然后通过重参数化技巧采样潜在变量。解码器根据这些潜在变量重构输入数据。

1.1 VAE 的基本结构

核心组件:

  • 编码器(Encoder):将输入数据\(x\)映射到潜在空间分布的参数\(\mu\)\(\log \sigma^2\)
  • 重参数化层(Reparameterization Trick):从编码器输出的分布中采样潜在变量\(z\)
  • 解码器(Decoder):根据潜在变量\(z\)生成重构数据\(\hat{x}\)
  • 损失函数:包括重构损失和KL散度损失。

1.2 编码器和解码器的实现

首先,我们需要定义编码器和解码器的网络结构。假设输入是28x28的MNIST图像,展平后为784维向量,潜在空间维度为32。

import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    def __init__(self, input_dim=784, h1_dim=512, h2_dim=256, latent_dim=32):
        super(Encoder, self).__init__()
        self.fc1 = nn.Linear(input_dim, h1_dim)
        self.fc2 = nn.Linear(h1_dim, h2_dim)
        self.fc_mu = nn.Linear(h2_dim, latent_dim)
        self.fc_logvar = nn.Linear(h2_dim, latent_dim)

    def forward(self, x):
        h1 = F.relu(self.fc1(x))
        h2 = F.relu(self.fc2(h1))
        mu = self.fc_mu(h2)
        logvar = self.fc_logvar(h2)
        return mu, logvar

class Decoder(nn.Module):
    def __init__(self, latent_dim=32, h2_dim=256, h1_dim=512, output_dim=784):
        super(Decoder, self).__init__()
        self.fc3 = nn.Linear(latent_dim, h2_dim)
        self.fc4 = nn.Linear(h2_dim, h1_dim)
        self.fc5 = nn.Linear(h1_dim, output_dim)

    def forward(self, z):
        D1 = F.relu(self.fc3(z))
        D2 = F.relu(self.fc4(D1))
        x_hat = torch.sigmoid(self.fc5(D2))
        return x_hat

1.3 VAE 整体结构

将编码器和解码器组合成一个完整的VAE模型,并实现重参数化技巧。

class VAE(nn.Module):
    def __init__(self, input_dim=784, h1_dim=512, h2_dim=256, latent_dim=32):
        super(VAE, self).__init__()
        self.encoder = Encoder(input_dim, h1_dim, h2_dim, latent_dim)
        self.decoder = Decoder(latent_dim, h2_dim, h1_dim, input_dim)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)  # 计算标准差
        eps = torch.randn_like(std)    # 采样ε
        return mu + std * eps            # 重参数化

    def forward(self, x):
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        x_hat = self.decoder(z)
        return x_hat, mu, logvar

2. 前向传播定义方法

在前向传播过程中,VAE 将输入数据通过编码器得到潜在分布参数,然后通过重参数化采样潜在变量,最后通过解码器生成重构数据。

import torch
import torch.nn as nn
import torch.nn.functional as F

class VAE(nn.Module):
    def __init__(self, input_dim=784, h1_dim=512, h2_dim=256, latent_dim=32):
        super(VAE, self).__init__()
        self.encoder = Encoder(input_dim, h1_dim, h2_dim, latent_dim)
        self.decoder = Decoder(latent_dim, h2_dim, h1_dim, input_dim)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)  # 标准差
        eps = torch.randn_like(std)    # ε ~ N(0,1)
        return mu + std * eps            # z = μ + σ * ε

    def forward(self, x):
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        x_hat = self.decoder(z)
        return x_hat, mu, logvar

说明:

  • 输入形状:假设输入x的形状为(batch_size, input_dim)
  • 编码器输出mulogvar分别表示潜在分布的均值和对数方差。
  • 重参数化:通过z = mu + sigma * eps实现可微分的采样过程。
  • 解码器输出x_hat是重构后的数据,使用sigmoid激活确保输出在[0,1]范围内(适用于图像数据)。

3. 模型训练方法

训练VAE的流程包括定义损失函数(重构损失和KL散度损失)、选择优化器、执行前向传播、计算损失、反向传播和参数更新。

3.1 损失函数

VAE的总损失函数由两部分组成:

  1. 重构损失 (\(\mathcal{L}_{recon}\)): 衡量重构数据与原始数据的相似度。对于二值数据(如MNIST),常使用二元交叉熵损失;对于连续数据,可以使用均方误差损失。

  2. KL散度损失 (\(\mathcal{L}_{KL}\)): 衡量编码器输出的潜在分布与先验分布(标准正态分布)的差异。

公式:

\[ \mathcal{L}_{total} = \mathcal{L}_{recon} + \beta \mathcal{L}_{KL} \]

其中,\(\beta\) 是权重系数,用于平衡两部分损失。

实现:

def loss_function(x, x_hat, mu, logvar, beta=1.0):
    # 重构损失 - 二元交叉熵
    BCE = F.binary_cross_entropy(x_hat, x, reduction='sum')

    # KL散度损失
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return BCE + beta * KLD

说明:

  • reduction='sum' 将损失求和,有助于梯度的稳定性。
  • KL散度的计算基于多维情形,假设潜在维度独立。

3.2 训练循环示例

以下是一个完整的训练循环示例,包括数据加载、模型定义、损失计算和优化步骤。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets, transforms

# 超参数定义
input_dim = 784        # 28x28图像
h1_dim = 512
h2_dim = 256
latent_dim = 32
batch_size = 128
num_epochs = 20
learning_rate = 1e-3
beta = 1.0             # KL散度权重

# 数据准备 - 使用MNIST数据集
transform = transforms.ToTensor()

train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# 模型、优化器定义
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VAE(input_dim, h1_dim, h2_dim, latent_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# 训练过程
model.train()
for epoch in range(num_epochs):
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.view(-1, input_dim).to(device)  # 展平
        optimizer.zero_grad()
        x_hat, mu, logvar = model(data)
        loss = loss_function(data, x_hat, mu, logvar, beta)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    avg_loss = train_loss / len(train_loader.dataset)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')

# 测试过程 - 生成样本
model.eval()
with torch.no_grad():
    z = torch.randn(64, latent_dim).to(device)
    sample = model.decoder(z).cpu()
    sample = sample.view(64, 1, 28, 28)
    # 这里可以使用torchvision.utils.save_image保存生成的样本

说明:

  • 数据准备:使用torchvision加载MNIST数据集,并进行标准的张量转换。
  • 模型训练
  • 前向传播:通过VAE生成重构数据。
  • 损失计算:计算总损失(重构 + KL散度)。
  • 反向传播:计算梯度并更新模型参数。
  • 测试过程:通过随机采样潜在变量z,使用解码器生成新的样本。

3.3 防止梯度消失和爆炸

在训练VAE时,梯度消失和爆炸可能导致训练不稳定。常用的方法包括:

  • 梯度裁剪:限制梯度的最大范数,防止梯度爆炸。
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

loss.backward()之后、optimizer.step()之前添加梯度裁剪。

  • 使用改进的优化器:如Adam,这些优化器在处理梯度时更加稳定。

  • 权重初始化:合理初始化模型权重,以避免初始阶段的梯度问题。

3.4 正则化方法

正则化有助于提高模型的泛化能力,防止过拟合。

  • Dropout:在编码器和解码器的隐藏层之间应用Dropout。
self.fc1 = nn.Linear(input_dim, h1_dim)
self.dropout1 = nn.Dropout(p=0.5)

forward方法中:

h1 = F.relu(self.dropout1(self.fc1(x)))
  • L2 正则化:通过在优化器中添加权重衰减参数实现。
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)

习题与解答

习题 1:构建一个带有Batch Normalization的VAE

  1. 在编码器和解码器的隐藏层之后添加Batch Normalization层。

    提示: 使用nn.BatchNorm1d在每个隐藏层之后添加批归一化。

    查看答案

    参考答案:

    import torch.nn as nn
    import torch.nn.functional as F
    
    class EncoderBN(nn.Module):
        def __init__(self, input_dim=784, h1_dim=512, h2_dim=256, latent_dim=32):
            super(EncoderBN, self).__init__()
            self.fc1 = nn.Linear(input_dim, h1_dim)
            self.bn1 = nn.BatchNorm1d(h1_dim)
            self.fc2 = nn.Linear(h1_dim, h2_dim)
            self.bn2 = nn.BatchNorm1d(h2_dim)
            self.fc_mu = nn.Linear(h2_dim, latent_dim)
            self.fc_logvar = nn.Linear(h2_dim, latent_dim)
    
        def forward(self, x):
            h1 = F.relu(self.bn1(self.fc1(x)))
            h2 = F.relu(self.bn2(self.fc2(h1)))
            mu = self.fc_mu(h2)
            logvar = self.fc_logvar(h2)
            return mu, logvar
    
    class DecoderBN(nn.Module):
        def __init__(self, latent_dim=32, h2_dim=256, h1_dim=512, output_dim=784):
            super(DecoderBN, self).__init__()
            self.fc3 = nn.Linear(latent_dim, h2_dim)
            self.bn3 = nn.BatchNorm1d(h2_dim)
            self.fc4 = nn.Linear(h2_dim, h1_dim)
            self.bn4 = nn.BatchNorm1d(h1_dim)
            self.fc5 = nn.Linear(h1_dim, output_dim)
    
        def forward(self, z):
            D1 = F.relu(self.bn3(self.fc3(z)))
            D2 = F.relu(self.bn4(self.fc4(D1)))
            x_hat = torch.sigmoid(self.fc5(D2))
            return x_hat
    
    class VAEBN(nn.Module):
        def __init__(self, input_dim=784, h1_dim=512, h2_dim=256, latent_dim=32):
            super(VAEBN, self).__init__()
            self.encoder = EncoderBN(input_dim, h1_dim, h2_dim, latent_dim)
            self.decoder = DecoderBN(latent_dim, h2_dim, h1_dim, input_dim)
    
        def reparameterize(self, mu, logvar):
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mu + std * eps
    
        def forward(self, x):
            mu, logvar = self.encoder(x)
            z = self.reparameterize(mu, logvar)
            x_hat = self.decoder(z)
            return x_hat, mu, logvar
    
  2. 修改上述VAE模型,在重参数化层中加入Dropout层,Dropout概率为0.2。

    提示: 在重参数化过程中加入Dropout。

    查看答案

    参考答案:

    import torch.nn as nn
    import torch.nn.functional as F
    
    class VAEWithDropout(nn.Module):
        def __init__(self, input_dim=784, h1_dim=512, h2_dim=256, latent_dim=32, dropout_p=0.2):
            super(VAEWithDropout, self).__init__()
            self.encoder = Encoder(input_dim, h1_dim, h2_dim, latent_dim)
            self.decoder = Decoder(latent_dim, h2_dim, h1_dim, input_dim)
            self.dropout = nn.Dropout(p=dropout_p)
    
        def reparameterize(self, mu, logvar):
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mu + std * eps
    
        def forward(self, x):
            mu, logvar = self.encoder(x)
            z = self.dropout(self.reparameterize(mu, logvar))  # 在采样后应用Dropout
            x_hat = self.decoder(z)
            return x_hat, mu, logvar
    
  3. 实现一个条件VAE(CVAE),即在编码器和解码器中加入类别标签作为输入。假设有10个类别。

    提示: 将类别标签嵌入为向量,并与输入数据拼接后输入编码器和解码器。

    查看答案

    参考答案:

    import torch.nn as nn
    import torch.nn.functional as F
    
    class EncoderCVAE(nn.Module):
        def __init__(self, input_dim=784, label_dim=10, h1_dim=512, h2_dim=256, latent_dim=32):
            super(EncoderCVAE, self).__init__()
            self.fc1 = nn.Linear(input_dim + label_dim, h1_dim)
            self.fc2 = nn.Linear(h1_dim, h2_dim)
            self.fc_mu = nn.Linear(h2_dim, latent_dim)
            self.fc_logvar = nn.Linear(h2_dim, latent_dim)
    
        def forward(self, x, labels):
            # One-hot编码标签
            labels_onehot = F.one_hot(labels, num_classes=10).float()
            x = torch.cat([x, labels_onehot], dim=1)
            h1 = F.relu(self.fc1(x))
            h2 = F.relu(self.fc2(h1))
            mu = self.fc_mu(h2)
            logvar = self.fc_logvar(h2)
            return mu, logvar
    
    class DecoderCVAE(nn.Module):
        def __init__(self, latent_dim=32, label_dim=10, h2_dim=256, h1_dim=512, output_dim=784):
            super(DecoderCVAE, self).__init__()
            self.fc3 = nn.Linear(latent_dim + label_dim, h2_dim)
            self.fc4 = nn.Linear(h2_dim, h1_dim)
            self.fc5 = nn.Linear(h1_dim, output_dim)
    
        def forward(self, z, labels):
            labels_onehot = F.one_hot(labels, num_classes=10).float()
            z = torch.cat([z, labels_onehot], dim=1)
            D1 = F.relu(self.fc3(z))
            D2 = F.relu(self.fc4(D1))
            x_hat = torch.sigmoid(self.fc5(D2))
            return x_hat
    
    class CVAE(nn.Module):
        def __init__(self, input_dim=784, label_dim=10, h1_dim=512, h2_dim=256, latent_dim=32):
            super(CVAE, self).__init__()
            self.encoder = EncoderCVAE(input_dim, label_dim, h1_dim, h2_dim, latent_dim)
            self.decoder = DecoderCVAE(latent_dim, label_dim, h2_dim, h1_dim, input_dim)
    
        def reparameterize(self, mu, logvar):
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mu + std * eps
    
        def forward(self, x, labels):
            mu, logvar = self.encoder(x, labels)
            z = self.reparameterize(mu, logvar)
            x_hat = self.decoder(z, labels)
            return x_hat, mu, logvar
    

习题 2:实现一个不同潜在空间维度的VAE并比较其结果

  1. 构建两个VAE模型,一个潜在空间维度为16,另一个为64。训练它们并比较重构效果和生成样本的质量。

    提示: 调整latent_dim参数,训练后观察生成图像的差异。

    查看答案

    参考答案:

    # 定义两个VAE模型
    latent_dim1 = 16
    latent_dim2 = 64
    
    model1 = VAE(input_dim, h1_dim, h2_dim, latent_dim1).to(device)
    model2 = VAE(input_dim, h1_dim, h2_dim, latent_dim2).to(device)
    
    optimizer1 = optim.Adam(model1.parameters(), lr=learning_rate)
    optimizer2 = optim.Adam(model2.parameters(), lr=learning_rate)
    
    # 定义相同的训练过程
    def train_vae(model, optimizer, train_loader, num_epochs, beta):
        model.train()
        for epoch in range(num_epochs):
            train_loss = 0
            for batch_idx, (data, _) in enumerate(train_loader):
                data = data.view(-1, input_dim).to(device)
                optimizer.zero_grad()
                x_hat, mu, logvar = model(data)
                loss = loss_function(data, x_hat, mu, logvar, beta)
                loss.backward()
                optimizer.step()
                train_loss += loss.item()
    
            avg_loss = train_loss / len(train_loader.dataset)
            print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')
    
    print("Training VAE with latent_dim=16")
    train_vae(model1, optimizer1, train_loader, num_epochs, beta)
    
    print("\nTraining VAE with latent_dim=64")
    train_vae(model2, optimizer2, train_loader, num_epochs, beta)
    
    # 比较生成样本的质量
    model1.eval()
    model2.eval()
    with torch.no_grad():
        z1 = torch.randn(64, latent_dim1).to(device)
        z2 = torch.randn(64, latent_dim2).to(device)
        sample1 = model1.decoder(z1).cpu().view(64, 1, 28, 28)
        sample2 = model2.decoder(z2).cpu().view(64, 1, 28, 28)
        # 使用torchvision.utils.save_image保存或可视化sample1和sample2
    

    比较结果:

    • 潜在空间维度为16:生成的图像可能较为模糊,细节较少,但捕捉了主要的结构特征。
    • 潜在空间维度为64:生成的图像细节更丰富,质量更高,但潜在空间维度过大可能导致过拟合或潜在空间不明显。

    通过不同潜在空间维度的比较,可以观察到维度对模型能力和生成质量的影响。

  2. 修改VAE的编码器,使其使用LeakyReLU激活函数代替ReLU,并观察训练过程中的变化。

    提示:F.relu替换为F.leaky_relu,并设置适当的负斜率参数。

    查看答案

    参考答案:

    class EncoderLeakyReLU(nn.Module):
        def __init__(self, input_dim=784, h1_dim=512, h2_dim=256, latent_dim=32, negative_slope=0.01):
            super(EncoderLeakyReLU, self).__init__()
            self.fc1 = nn.Linear(input_dim, h1_dim)
            self.fc2 = nn.Linear(h1_dim, h2_dim)
            self.fc_mu = nn.Linear(h2_dim, latent_dim)
            self.fc_logvar = nn.Linear(h2_dim, latent_dim)
            self.negative_slope = negative_slope
    
        def forward(self, x):
            h1 = F.leaky_relu(self.fc1(x), negative_slope=self.negative_slope)
            h2 = F.leaky_relu(self.fc2(h1), negative_slope=self.negative_slope)
            mu = self.fc_mu(h2)
            logvar = self.fc_logvar(h2)
            return mu, logvar
    
    class DecoderLeakyReLU(nn.Module):
        def __init__(self, latent_dim=32, h2_dim=256, h1_dim=512, output_dim=784, negative_slope=0.01):
            super(DecoderLeakyReLU, self).__init__()
            self.fc3 = nn.Linear(latent_dim, h2_dim)
            self.fc4 = nn.Linear(h2_dim, h1_dim)
            self.fc5 = nn.Linear(h1_dim, output_dim)
            self.negative_slope = negative_slope
    
        def forward(self, z):
            D1 = F.leaky_relu(self.fc3(z), negative_slope=self.negative_slope)
            D2 = F.leaky_relu(self.fc4(D1), negative_slope=self.negative_slope)
            x_hat = torch.sigmoid(self.fc5(D2))
            return x_hat
    
    class VAELeakyReLU(nn.Module):
        def __init__(self, input_dim=784, h1_dim=512, h2_dim=256, latent_dim=32, negative_slope=0.01):
            super(VAELeakyReLU, self).__init__()
            self.encoder = EncoderLeakyReLU(input_dim, h1_dim, h2_dim, latent_dim, negative_slope)
            self.decoder = DecoderLeakyReLU(latent_dim, h2_dim, h1_dim, input_dim, negative_slope)
    
        def reparameterize(self, mu, logvar):
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mu + std * eps
    
        def forward(self, x):
            mu, logvar = self.encoder(x)
            z = self.reparameterize(mu, logvar)
            x_hat = self.decoder(z)
            return x_hat, mu, logvar
    

    观察训练过程的变化:

    • 梯度流动:LeakyReLU允许小的负梯度,通过避免“死亡ReLU”问题,提高梯度流动性。
    • 训练稳定性:可能观察到更稳定的训练过程,减少早期梯度消失的问题。
    • 生成质量:有时生成的图像质量可能有所提升,尤其是在深层网络中。
  3. 实现一个带有变种的损失函数,其中重构损失使用均方误差(MSE)代替二元交叉熵,并比较训练结果。

    提示: 修改损失函数中的重构部分,使用F.mse_loss

    查看答案

    参考答案:

    def loss_function_mse(x, x_hat, mu, logvar, beta=1.0):
        # 重构损失 - 均方误差
        MSE = F.mse_loss(x_hat, x, reduction='sum')
    
        # KL散度损失
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    
        return MSE + beta * KLD
    
    # 在训练过程中使用新的损失函数
    def train_vae_mse(model, optimizer, train_loader, num_epochs, beta):
        model.train()
        for epoch in range(num_epochs):
            train_loss = 0
            for batch_idx, (data, _) in enumerate(train_loader):
                data = data.view(-1, input_dim).to(device)
                optimizer.zero_grad()
                x_hat, mu, logvar = model(data)
                loss = loss_function_mse(data, x_hat, mu, logvar, beta)
                loss.backward()
                optimizer.step()
                train_loss += loss.item()
    
            avg_loss = train_loss / len(train_loader.dataset)
            print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')
    
    # 使用新的损失函数训练VAE
    model_mse = VAE(input_dim, h1_dim, h2_dim, latent_dim).to(device)
    optimizer_mse = optim.Adam(model_mse.parameters(), lr=learning_rate)
    train_vae_mse(model_mse, optimizer_mse, train_loader, num_epochs, beta)
    

    比较结果:

    • 重构质量:MSE损失倾向于生成更平滑的图像,而二元交叉熵可能生成更锐利的图像。
    • 训练稳定性:两种损失函数在训练稳定性方面可能没有显著差异,但具体表现取决于数据和模型架构。
    • 应用场景:MSE适用于连续值数据,尤其在图像重构任务中常用;二元交叉熵适用于二值化的数据。