Skip to content

变分自编码器 VAE

VAE 网络结构

VAE由编码器(Encoder)和解码器(Decoder)组成:

  • 编码器:将输入数据\(x\)映射到潜在空间的分布参数,通常输出均值\(\mu\)和对数方差\(\log \sigma^2\)

$$ \mu, \log \sigma^2 = \text{Encoder}(x) $$

  • 潜在变量:从编码器输出的分布中采样得到潜在变量\(z\)

$$ z \sim \mathcal{N}(\mu, \sigma^2) $$

  • 解码器:根据潜在变量\(z\)重构输入数据\(\hat{x}\)

$$ \hat{x} = \text{Decoder}(z) $$

注意,在这里我们发现,\(\text{Encoder}(x)\)层输出的两个值\(\mu\)\(\log \sigma^2\),并不先验地代表着输入数据分布的均值和方差,那么究竟是什么因素逼迫他们具有了这样的性质呢

输入 操作 输出
输入层 原始数据 \(X\) 展平操作(如果是图像) \(X_{flat} \in \mathbb{R}^{batch \times input\_dim}\)
编码器部分
编码器隐藏层 1 \(X_{flat}\) \(H_1 = \text{ReLU}(X_{flat}W_1 + b_1)\)
其中 \(W_1 \in \mathbb{R}^{input\_dim \times h1}\)
\(H_1 \in \mathbb{R}^{batch \times h1}\)
编码器隐藏层 2 \(H_1\) \(H_2 = \text{ReLU}(H_1W_2 + b_2)\)
其中 \(W_2 \in \mathbb{R}^{h1 \times h2}\)
\(H_2 \in \mathbb{R}^{batch \times h2}\)
瓶颈层(Bottleneck) \(H_2 \in \mathbb{R}^{batch \times h2}\) 均值分支:
\(\mu = H_2W_\mu + b_\mu\)
其中 \(W_\mu \in \mathbb{R}^{h2 \times latent\_dim}\)

标准差分支:
\(\log \sigma^2 = H_2W_\sigma + b_\sigma\)
其中 \(W_\sigma \in \mathbb{R}^{h2 \times latent\_dim}\)

重参数化采样:
\(\epsilon \sim \mathcal{N}(0,I)\)
\(z = \mu + \sigma \odot \epsilon\)
\(\mu \in \mathbb{R}^{batch \times latent\_dim}\)

\(\sigma \in \mathbb{R}^{batch \times latent\_dim}\)

\(z \in \mathbb{R}^{batch \times latent\_dim}\)
解码器部分
解码器隐藏层 1 \(z\) \(D_1 = \text{ReLU}(zW_3 + b_3)\)
其中 \(W_3 \in \mathbb{R}^{latent\_dim \times h2}\)
\(D_1 \in \mathbb{R}^{batch \times h2}\)
解码器隐藏层 2 \(D_1\) \(D_2 = \text{ReLU}(D_1W_4 + b_4)\)
其中 \(W_4 \in \mathbb{R}^{h2 \times h1}\)
\(D_2 \in \mathbb{R}^{batch \times h1}\)
输出层 \(D_2\) \(\hat{X} = \text{sigmoid}(D_2W_5 + b_5)\)
其中 \(W_5 \in \mathbb{R}^{h1 \times input\_dim}\)
\(\hat{X} \in \mathbb{R}^{batch \times input\_dim}\)
损失计算
重构损失 \(X, \hat{X}\) \(\mathcal{L}_{recon} = \sum(X\log(\hat{X}) + (1-X)\log(1-\hat{X}))\) 标量
KL散度损失 \(\mu, \sigma\) \(\mathcal{L}_{KL} = -\frac{1}{2}\sum(1 + \log(\sigma^2) - \mu^2 - \sigma^2)\) 标量
总损失 \(\mathcal{L}_{recon}, \mathcal{L}_{KL}\) \(\mathcal{L}_{total} = \mathcal{L}_{recon} + \beta\mathcal{L}_{KL}\)
其中 \(\beta\) 是权重系数
标量

典型维度示例: - \(input\_dim = 784\) (28×28 MNIST图像) - \(h1 = 512\) - \(h2 = 256\) - \(latent\_dim = 32\)

VAE的总损失函数

\(L_{VAE} = L_{recon} + L_{KL}\)

  1. 重构损失 \(L_{recon}\)

\(L_{recon} = ||x - \hat{x}||^2\)

其中: - \(x\) 是输入数据 - \(\hat{x}\) 是重构的数据

  1. KL散度损失 \(L_{KL}\)

\(L_{KL} = \frac{1}{2}\sum_{i=1}^n(\sigma_i^2 + \mu_i^2 - 1 - \log(\sigma_i^2))\)

其中: - \(\mu_i\) 是编码器输出的均值向量的第i个元素 - \(\sigma_i\) 是编码器输出的标准差向量的第i个元素 - \(n\) 是潜在空间的维度

KL散度项的推导来源

KL散度的定义

对于两个概率分布\(P\)\(Q\),KL散度(相对熵)定义为:

\[ KL(P \parallel Q) = \int_{-\infty}^{+\infty} P(x) \log \frac{P(x)}{Q(x)} dx \]

在变分自编码器(VAE)的情境下,我们有:

  • \(P = \mathcal{N}(\mu, \sigma^2)\):编码器输出的高斯分布
  • \(Q = \mathcal{N}(0, 1)\):标准正态分布

因此,我们需要计算:

\[ KL(\mathcal{N}(\mu, \sigma^2) \parallel \mathcal{N}(0,1)) = \int_{-\infty}^{+\infty} \mathcal{N}(x; \mu, \sigma^2) \log \frac{\mathcal{N}(x; \mu, \sigma^2)}{\mathcal{N}(x; 0,1)} dx \]

高斯分布的表达式

首先,写出两个高斯分布的概率密度函数:

\[ \mathcal{N}(x; \mu, \sigma^2) = \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left( -\frac{(x - \mu)^2}{2\sigma^2} \right) \]
\[ \mathcal{N}(x; 0,1) = \frac{1}{\sqrt{2\pi}} \exp\left( -\frac{x^2}{2} \right) \]

计算\(\log \frac{P(x)}{Q(x)}\)

\(P(x)\)\(Q(x)\)代入KL散度的定义:

\[ \log \frac{P(x)}{Q(x)} = \log \left( \frac{\frac{1}{\sqrt{2\pi\sigma^2}} \exp\left( -\frac{(x - \mu)^2}{2\sigma^2} \right)}{\frac{1}{\sqrt{2\pi}} \exp\left( -\frac{x^2}{2} \right)} \right) \]

简化上式:

\[ \log \frac{P(x)}{Q(x)} = \log \left( \frac{1}{\sqrt{\sigma^2}} \right) + \left( -\frac{(x - \mu)^2}{2\sigma^2} + \frac{x^2}{2} \right) \]

进一步化简:

\[ \log \frac{P(x)}{Q(x)} = -\frac{1}{2} \log \sigma^2 - \frac{(x - \mu)^2}{2\sigma^2} + \frac{x^2}{2} \]

展开并整理

\(-\frac{(x - \mu)^2}{2\sigma^2}\)展开:

\[ -\frac{(x - \mu)^2}{2\sigma^2} = -\frac{x^2 - 2\mu x + \mu^2}{2\sigma^2} = -\frac{x^2}{2\sigma^2} + \frac{\mu x}{\sigma^2} - \frac{\mu^2}{2\sigma^2} \]

因此:

\[ \log \frac{P(x)}{Q(x)} = -\frac{1}{2} \log \sigma^2 - \frac{x^2}{2\sigma^2} + \frac{\mu x}{\sigma^2} - \frac{\mu^2}{2\sigma^2} + \frac{x^2}{2} \]

合并同类项:

\[ \log \frac{P(x)}{Q(x)} = -\frac{1}{2} \log \sigma^2 + \frac{\mu x}{\sigma^2} - \frac{\mu^2}{2\sigma^2} + \left( \frac{x^2}{2} - \frac{x^2}{2\sigma^2} \right) \]

\(\frac{x^2}{2} - \frac{x^2}{2\sigma^2}\)合并:

\[ \frac{x^2}{2} \left(1 - \frac{1}{\sigma^2}\right) = \frac{x^2}{2} \left( \frac{\sigma^2 - 1}{\sigma^2} \right ) = \frac{\sigma^2 - 1}{2\sigma^2} x^2 \]

因此:

\[ \log \frac{P(x)}{Q(x)} = -\frac{1}{2} \log \sigma^2 + \frac{\mu x}{\sigma^2} - \frac{\mu^2}{2\sigma^2} + \frac{\sigma^2 - 1}{2\sigma^2} x^2 \]

计算期望

KL散度涉及对\(P(x)\)的期望,因此我们需要计算:

\[ KL(\mathcal{N}(\mu, \sigma^2) \parallel \mathcal{N}(0,1)) = \mathbb{E}_{P(x)} \left[ \log \frac{P(x)}{Q(x)} \right ] \]

将上式代入:

\[ KL = \mathbb{E}_{P(x)} \left[ -\frac{1}{2} \log \sigma^2 + \frac{\mu x}{\sigma^2} - \frac{\mu^2}{2\sigma^2} + \frac{\sigma^2 - 1}{2\sigma^2} x^2 \right ] \]

由于期望是线性的,可以分开计算各项的期望:

\[ KL = -\frac{1}{2} \log \sigma^2 + \frac{\mu}{\sigma^2} \mathbb{E}_{P(x)}[x] - \frac{\mu^2}{2\sigma^2} + \frac{\sigma^2 - 1}{2\sigma^2} \mathbb{E}_{P(x)}[x^2] \]

已知对于\(P(x) = \mathcal{N}(\mu, \sigma^2)\)

\[ \mathbb{E}_{P(x)}[x] = \mu \]
\[ \mathbb{E}_{P(x)}[x^2] = \mu^2 + \sigma^2 \]

将这些代入:

\[ KL = -\frac{1}{2} \log \sigma^2 + \frac{\mu^2}{\sigma^2} - \frac{\mu^2}{2\sigma^2} + \frac{\sigma^2 - 1}{2\sigma^2} (\mu^2 + \sigma^2) \]

简化各项:

  1. \(\frac{\mu^2}{\sigma^2} - \frac{\mu^2}{2\sigma^2} = \frac{\mu^2}{2\sigma^2}\)
  2. 展开最后一项:
\[ \frac{\sigma^2 - 1}{2\sigma^2} (\mu^2 + \sigma^2) = \frac{(\sigma^2 - 1)\mu^2}{2\sigma^2} + \frac{(\sigma^2 - 1)\sigma^2}{2\sigma^2} = \frac{(\sigma^2 - 1)\mu^2}{2\sigma^2} + \frac{\sigma^2 - 1}{2} \]

将所有项合并:

\[ KL = -\frac{1}{2} \log \sigma^2 + \frac{\mu^2}{2\sigma^2} + \frac{(\sigma^2 - 1)\mu^2}{2\sigma^2} + \frac{\sigma^2 - 1}{2} \]

合并\(\mu^2\)项:

\[ \frac{\mu^2}{2\sigma^2} + \frac{(\sigma^2 - 1)\mu^2}{2\sigma^2} = \frac{\mu^2}{2\sigma^2} (1 + \sigma^2 -1) = \frac{\mu^2}{2} \]

因此:

\[ KL = -\frac{1}{2} \log \sigma^2 + \frac{\mu^2}{2} + \frac{\sigma^2 - 1}{2} \]

进一步整理:

\[ KL = \frac{1}{2} \left( \mu^2 + \sigma^2 - 1 - \log \sigma^2 \right ) \]

多维情形

在VAE中,潜在空间通常是多维的,假设潜在空间的维度为\(n\),且各维度之间相互独立。此时,总的KL散度是各维度KL散度的和:

\[ KL = \frac{1}{2} \sum_{i=1}^{n} \left( \mu_i^2 + \sigma_i^2 - 1 - \log \sigma_i^2 \right ) \]

这就是VAE中KL散度项的具体推导过程。

KL散度在VAE中的作用

我们之前推导过,对于高斯分布\(q(z|x) = \mathcal{N}(\mu, \sigma^2)\)\(p(z) = \mathcal{N}(0,1)\),KL散度为:

\[ KL(q(z|x) \parallel p(z)) = \frac{1}{2} \left( \mu^2 + \sigma^2 - 1 - \log \sigma^2 \right ) \]

这个KL散度项在损失函数中起到了以下作用:

  1. 正则化潜在空间:通过最小化KL散度,VAE被迫使得编码器输出的分布\(q(z|x)\)接近先验分布\(p(z)\)。这是为了确保潜在空间在不同样本之间具有良好的结构性和连续性,从而使得生成的数据具有多样性和连贯性。

  2. 引导隐藏层输出:编码器的隐藏层输出被设计为均值\(\mu\)和对数方差\(\log \sigma^2\),通过最小化KL散度,损失函数会对\(\mu\)\(\sigma^2\)施加约束:

  3. 均值\(\mu\):KL散度中的\(\mu^2\)项鼓励均值接近0,因为最小化\(\mu^2\)会使得\(\mu\)尽可能小。

    $$ \frac{1}{2} \mu^2 \quad \text{最小化时,} \quad \mu \rightarrow 0 $$

  4. 方差\(\sigma^2\):KL散度中的\(\sigma^2 - \log \sigma^2 - 1\)项在\(\sigma^2 = 1\)时达到最小值。这意味着损失函数会鼓励方差接近1。

    $$ \frac{1}{2} (\sigma^2 - \log \sigma^2 - 1) \quad \text{最小化时,} \quad \sigma^2 \rightarrow 1 $$

具体的梯度驱动

在VAE的训练过程中,损失函数对\(\mu\)\(\sigma^2\)的梯度驱动体现了KL散度和重构误差的双重作用: 1. KL散度损失**的梯度: - 对\(\mu\)的梯度:\(\frac{\partial KL}{\partial \mu} = \frac{\mu}{1}\) 这个梯度促使\(\mu\)减小,趋向于0。 - 对\(\sigma^2\)的梯度:\(\frac{\partial KL}{\partial \sigma^2} = \frac{1}{2}(1 - \frac{1}{\sigma^2})\) 这个梯度促使\(\sigma^2\)趋向于1。 通过反向传播,这些梯度引导编码器的参数调整,使得输出的\(\mu\)\(\sigma^2\)逐渐逼近先验分布的参数(均值为0,方差为1)。 2. **重构误差损失\(\mathcal{L}_{\text{recon}}\)):

这部分损失迫使模型生成的重构数据\(\hat{x}\)尽可能接近原始输入数据\(x\)。它驱动编码器学习到能够有效表示数据特征的潜在变量\(z\),从而赋予\(\mu\)\(\sigma^2\)实际的统计意义。

双重作用的平衡

  • **KL散度损失**赋予潜在变量先验分布,确保潜在空间的结构性和可采样性。
  • **重构误差损失**赋予\(\mu\)\(\sigma^2\)实际的统计意义,使它们能够捕捉输入数据的分布特征。 这种平衡设计使得VAE能够在保持潜在空间结构的同时,有效地进行数据表示和生成。\(\mu\)\(\sigma^2\)不仅被逼近标准正态分布,还反映了输入数据在潜在空间中的分布特性。

双输出的必要性

编码器输出双参数(\(\mu\)\(\sigma^2\))的设计有以下重要原因: 1. 表达不确定性\(\mu\)表示潜在变量的中心位置,\(\sigma^2\)表示分布的扩散程度,共同描述了潜在变量的不确定性。 2. 重参数化技巧:通过\(z = \mu + \sigma \cdot \epsilon, \epsilon \sim \mathcal{N}(0,1)\),VAE能够在反向传播中有效地传递梯度到\(\mu\)\(\sigma^2\)

\(\mu\)\(\log \sigma^2\)性质从何而来

回到我们最初的问题:是什么因素逼迫编码器输出的\(\mu\)\(\log \sigma^2\)具有了表示输入数据分布均值和方差的性质? 答案在于VAE损失函数的巧妙设计和训练过程的动态平衡: 1. KL散度损失 - 为\(\mu\)\(\sigma^2\)提供了一个先验约束,使它们倾向于标准正态分布(均值为0,方差为1)。这为潜在空间提供了一个统一的结构。 2. 重构误差损失 - 则要求通过这些参数采样得到的潜在变量\(z\)必须包含足够的信息来重构原始输入。这迫使\(\mu\)\(\sigma^2\)捕获输入数据的实际分布特征。 3. 训练过程中的动态平衡: - 模型在最小化这两个损失项的过程中,找到了一个平衡点。在这个平衡点上,\(\mu\)\(\sigma^2\)既满足了先验分布的约束,又能够有效地编码输入数据的分布特征。 因此,\(\mu\)\(\sigma^2\)最终具有了表示输入数据分布特征的性质,是损失函数的双重约束和训练过程中的动态平衡共同作用的结果。它们不是简单地表示整个数据集的均值和方差,而是对每个输入样本在潜在空间中的分布进行参数化表示。

流形解释

从流形的角度来理解,VAE的工作原理可以解释如下:

概念 解释
数据流形 高维输入数据通常位于一个低维流形上。例如,人脸图像虽然是高维的,但实际上可能只需要少数几个参数就能描述其主要特征。
潜在空间 VAE的潜在空间可以看作是对这个低维流形的近似。\(\mu\)\(\log \sigma^2\)共同定义了这个流形上的一个局部坐标系。
概率分布映射 对于每个输入\(x\),编码器输出的\(\mu(x)\)\(\sigma^2(x)\)定义了潜在空间中的一个高斯分布。这可以理解为将输入数据点映射到潜在空间流形上的一个概率分布。
流形学习 通过训练,VAE学习到了一个从高维输入空间到低维潜在空间的平滑映射,这个映射捕获了数据的内在结构。
连续性和平滑性 KL散度损失确保了潜在空间中相邻点的分布是平滑变化的,这有助于保持流形的连续性。
生成过程 当我们从潜在空间采样并通过解码器生成数据时,实际上是在这个学习到的低维流形上进行插值或外推。