变分自编码器 VAE
VAE 网络结构¶
VAE由编码器(Encoder)和解码器(Decoder)组成:
- 编码器:将输入数据\(x\)映射到潜在空间的分布参数,通常输出均值\(\mu\)和对数方差\(\log \sigma^2\)。
$$ \mu, \log \sigma^2 = \text{Encoder}(x) $$
- 潜在变量:从编码器输出的分布中采样得到潜在变量\(z\)。
$$ z \sim \mathcal{N}(\mu, \sigma^2) $$
- 解码器:根据潜在变量\(z\)重构输入数据\(\hat{x}\)。
$$ \hat{x} = \text{Decoder}(z) $$
注意,在这里我们发现,\(\text{Encoder}(x)\)层输出的两个值\(\mu\)和\(\log \sigma^2\),并不先验地代表着输入数据分布的均值和方差,那么究竟是什么因素逼迫他们具有了这样的性质呢
层 | 输入 | 操作 | 输出 |
---|---|---|---|
输入层 | 原始数据 \(X\) | 展平操作(如果是图像) | \(X_{flat} \in \mathbb{R}^{batch \times input\_dim}\) |
编码器部分 | |||
编码器隐藏层 1 | \(X_{flat}\) | \(H_1 = \text{ReLU}(X_{flat}W_1 + b_1)\) 其中 \(W_1 \in \mathbb{R}^{input\_dim \times h1}\) |
\(H_1 \in \mathbb{R}^{batch \times h1}\) |
编码器隐藏层 2 | \(H_1\) | \(H_2 = \text{ReLU}(H_1W_2 + b_2)\) 其中 \(W_2 \in \mathbb{R}^{h1 \times h2}\) |
\(H_2 \in \mathbb{R}^{batch \times h2}\) |
瓶颈层(Bottleneck) | \(H_2 \in \mathbb{R}^{batch \times h2}\) | 均值分支: \(\mu = H_2W_\mu + b_\mu\) 其中 \(W_\mu \in \mathbb{R}^{h2 \times latent\_dim}\) 标准差分支: \(\log \sigma^2 = H_2W_\sigma + b_\sigma\) 其中 \(W_\sigma \in \mathbb{R}^{h2 \times latent\_dim}\) 重参数化采样: \(\epsilon \sim \mathcal{N}(0,I)\) \(z = \mu + \sigma \odot \epsilon\) |
\(\mu \in \mathbb{R}^{batch \times latent\_dim}\) \(\sigma \in \mathbb{R}^{batch \times latent\_dim}\) \(z \in \mathbb{R}^{batch \times latent\_dim}\) |
解码器部分 | |||
解码器隐藏层 1 | \(z\) | \(D_1 = \text{ReLU}(zW_3 + b_3)\) 其中 \(W_3 \in \mathbb{R}^{latent\_dim \times h2}\) |
\(D_1 \in \mathbb{R}^{batch \times h2}\) |
解码器隐藏层 2 | \(D_1\) | \(D_2 = \text{ReLU}(D_1W_4 + b_4)\) 其中 \(W_4 \in \mathbb{R}^{h2 \times h1}\) |
\(D_2 \in \mathbb{R}^{batch \times h1}\) |
输出层 | \(D_2\) | \(\hat{X} = \text{sigmoid}(D_2W_5 + b_5)\) 其中 \(W_5 \in \mathbb{R}^{h1 \times input\_dim}\) |
\(\hat{X} \in \mathbb{R}^{batch \times input\_dim}\) |
损失计算 | |||
重构损失 | \(X, \hat{X}\) | \(\mathcal{L}_{recon} = \sum(X\log(\hat{X}) + (1-X)\log(1-\hat{X}))\) | 标量 |
KL散度损失 | \(\mu, \sigma\) | \(\mathcal{L}_{KL} = -\frac{1}{2}\sum(1 + \log(\sigma^2) - \mu^2 - \sigma^2)\) | 标量 |
总损失 | \(\mathcal{L}_{recon}, \mathcal{L}_{KL}\) | \(\mathcal{L}_{total} = \mathcal{L}_{recon} + \beta\mathcal{L}_{KL}\) 其中 \(\beta\) 是权重系数 |
标量 |
典型维度示例: - \(input\_dim = 784\) (28×28 MNIST图像) - \(h1 = 512\) - \(h2 = 256\) - \(latent\_dim = 32\)
VAE的总损失函数¶
\(L_{VAE} = L_{recon} + L_{KL}\)
- 重构损失 \(L_{recon}\)
\(L_{recon} = ||x - \hat{x}||^2\)
其中: - \(x\) 是输入数据 - \(\hat{x}\) 是重构的数据
- KL散度损失 \(L_{KL}\)
\(L_{KL} = \frac{1}{2}\sum_{i=1}^n(\sigma_i^2 + \mu_i^2 - 1 - \log(\sigma_i^2))\)
其中: - \(\mu_i\) 是编码器输出的均值向量的第i个元素 - \(\sigma_i\) 是编码器输出的标准差向量的第i个元素 - \(n\) 是潜在空间的维度
KL散度项的推导来源¶
KL散度的定义¶
对于两个概率分布\(P\)和\(Q\),KL散度(相对熵)定义为:
在变分自编码器(VAE)的情境下,我们有:
- \(P = \mathcal{N}(\mu, \sigma^2)\):编码器输出的高斯分布
- \(Q = \mathcal{N}(0, 1)\):标准正态分布
因此,我们需要计算:
高斯分布的表达式¶
首先,写出两个高斯分布的概率密度函数:
计算\(\log \frac{P(x)}{Q(x)}\)¶
将\(P(x)\)和\(Q(x)\)代入KL散度的定义:
简化上式:
进一步化简:
展开并整理¶
将\(-\frac{(x - \mu)^2}{2\sigma^2}\)展开:
因此:
合并同类项:
将\(\frac{x^2}{2} - \frac{x^2}{2\sigma^2}\)合并:
因此:
计算期望¶
KL散度涉及对\(P(x)\)的期望,因此我们需要计算:
将上式代入:
由于期望是线性的,可以分开计算各项的期望:
已知对于\(P(x) = \mathcal{N}(\mu, \sigma^2)\):
将这些代入:
简化各项:
- \(\frac{\mu^2}{\sigma^2} - \frac{\mu^2}{2\sigma^2} = \frac{\mu^2}{2\sigma^2}\)
- 展开最后一项:
将所有项合并:
合并\(\mu^2\)项:
因此:
进一步整理:
多维情形¶
在VAE中,潜在空间通常是多维的,假设潜在空间的维度为\(n\),且各维度之间相互独立。此时,总的KL散度是各维度KL散度的和:
这就是VAE中KL散度项的具体推导过程。
KL散度在VAE中的作用¶
我们之前推导过,对于高斯分布\(q(z|x) = \mathcal{N}(\mu, \sigma^2)\)和\(p(z) = \mathcal{N}(0,1)\),KL散度为:
这个KL散度项在损失函数中起到了以下作用:
-
正则化潜在空间:通过最小化KL散度,VAE被迫使得编码器输出的分布\(q(z|x)\)接近先验分布\(p(z)\)。这是为了确保潜在空间在不同样本之间具有良好的结构性和连续性,从而使得生成的数据具有多样性和连贯性。
-
引导隐藏层输出:编码器的隐藏层输出被设计为均值\(\mu\)和对数方差\(\log \sigma^2\),通过最小化KL散度,损失函数会对\(\mu\)和\(\sigma^2\)施加约束:
-
均值\(\mu\):KL散度中的\(\mu^2\)项鼓励均值接近0,因为最小化\(\mu^2\)会使得\(\mu\)尽可能小。
$$ \frac{1}{2} \mu^2 \quad \text{最小化时,} \quad \mu \rightarrow 0 $$
-
方差\(\sigma^2\):KL散度中的\(\sigma^2 - \log \sigma^2 - 1\)项在\(\sigma^2 = 1\)时达到最小值。这意味着损失函数会鼓励方差接近1。
$$ \frac{1}{2} (\sigma^2 - \log \sigma^2 - 1) \quad \text{最小化时,} \quad \sigma^2 \rightarrow 1 $$
具体的梯度驱动¶
在VAE的训练过程中,损失函数对\(\mu\)和\(\sigma^2\)的梯度驱动体现了KL散度和重构误差的双重作用: 1. KL散度损失**的梯度: - 对\(\mu\)的梯度:\(\frac{\partial KL}{\partial \mu} = \frac{\mu}{1}\) 这个梯度促使\(\mu\)减小,趋向于0。 - 对\(\sigma^2\)的梯度:\(\frac{\partial KL}{\partial \sigma^2} = \frac{1}{2}(1 - \frac{1}{\sigma^2})\) 这个梯度促使\(\sigma^2\)趋向于1。 通过反向传播,这些梯度引导编码器的参数调整,使得输出的\(\mu\)和\(\sigma^2\)逐渐逼近先验分布的参数(均值为0,方差为1)。 2. **重构误差损失(\(\mathcal{L}_{\text{recon}}\)):
这部分损失迫使模型生成的重构数据\(\hat{x}\)尽可能接近原始输入数据\(x\)。它驱动编码器学习到能够有效表示数据特征的潜在变量\(z\),从而赋予\(\mu\)和\(\sigma^2\)实际的统计意义。
双重作用的平衡¶
- **KL散度损失**赋予潜在变量先验分布,确保潜在空间的结构性和可采样性。
- **重构误差损失**赋予\(\mu\)和\(\sigma^2\)实际的统计意义,使它们能够捕捉输入数据的分布特征。 这种平衡设计使得VAE能够在保持潜在空间结构的同时,有效地进行数据表示和生成。\(\mu\)和\(\sigma^2\)不仅被逼近标准正态分布,还反映了输入数据在潜在空间中的分布特性。
双输出的必要性¶
编码器输出双参数(\(\mu\)和\(\sigma^2\))的设计有以下重要原因: 1. 表达不确定性:\(\mu\)表示潜在变量的中心位置,\(\sigma^2\)表示分布的扩散程度,共同描述了潜在变量的不确定性。 2. 重参数化技巧:通过\(z = \mu + \sigma \cdot \epsilon, \epsilon \sim \mathcal{N}(0,1)\),VAE能够在反向传播中有效地传递梯度到\(\mu\)和\(\sigma^2\)。
\(\mu\)和\(\log \sigma^2\)性质从何而来¶
回到我们最初的问题:是什么因素逼迫编码器输出的\(\mu\)和\(\log \sigma^2\)具有了表示输入数据分布均值和方差的性质? 答案在于VAE损失函数的巧妙设计和训练过程的动态平衡: 1. KL散度损失 - 为\(\mu\)和\(\sigma^2\)提供了一个先验约束,使它们倾向于标准正态分布(均值为0,方差为1)。这为潜在空间提供了一个统一的结构。 2. 重构误差损失 - 则要求通过这些参数采样得到的潜在变量\(z\)必须包含足够的信息来重构原始输入。这迫使\(\mu\)和\(\sigma^2\)捕获输入数据的实际分布特征。 3. 训练过程中的动态平衡: - 模型在最小化这两个损失项的过程中,找到了一个平衡点。在这个平衡点上,\(\mu\)和\(\sigma^2\)既满足了先验分布的约束,又能够有效地编码输入数据的分布特征。 因此,\(\mu\)和\(\sigma^2\)最终具有了表示输入数据分布特征的性质,是损失函数的双重约束和训练过程中的动态平衡共同作用的结果。它们不是简单地表示整个数据集的均值和方差,而是对每个输入样本在潜在空间中的分布进行参数化表示。
流形解释¶
从流形的角度来理解,VAE的工作原理可以解释如下:
概念 | 解释 |
---|---|
数据流形 | 高维输入数据通常位于一个低维流形上。例如,人脸图像虽然是高维的,但实际上可能只需要少数几个参数就能描述其主要特征。 |
潜在空间 | VAE的潜在空间可以看作是对这个低维流形的近似。\(\mu\)和\(\log \sigma^2\)共同定义了这个流形上的一个局部坐标系。 |
概率分布映射 | 对于每个输入\(x\),编码器输出的\(\mu(x)\)和\(\sigma^2(x)\)定义了潜在空间中的一个高斯分布。这可以理解为将输入数据点映射到潜在空间流形上的一个概率分布。 |
流形学习 | 通过训练,VAE学习到了一个从高维输入空间到低维潜在空间的平滑映射,这个映射捕获了数据的内在结构。 |
连续性和平滑性 | KL散度损失确保了潜在空间中相邻点的分布是平滑变化的,这有助于保持流形的连续性。 |
生成过程 | 当我们从潜在空间采样并通过解码器生成数据时,实际上是在这个学习到的低维流形上进行插值或外推。 |