ICML 15 | Deep Unsupervised Learning using Nonequilibrium Thermodynamics#
论文链接:http://proceedings.mlr.press/v37/sohl-dickstein15.html
代码链接:https://github.com/Sohl-Dickstein/Diffusion-Probabilistic-Models
作者单位:斯坦福大学
概率模型(probabilistic models)一直在易处理与灵活性之间平衡
我们定义一个逐渐将一种分布转换为另一分布的马尔科夫链,生成式马尔科夫链即使用 diffusion 过程将一个简单的已知分布转换为目标数据分布
在此框架下模型用于估计单一 diffusion 过程中的微小扰动,从而降低处理难度;此外,任一平滑的目标分布均存在 diffusion 过程,因此此方法灵活度较高
前向过程(diffusio 过程):将目标数据分布转换为简单已知分布
逆向过程:在有限的时间步下,从简单分布中生成目标数据分布
前向过程 $q(x^{(0 \cdots T)})$#
目标数据分布 $q(x^{(0)})$,简单已知分布 $\pi(y)$,马尔可夫 diffusion 核 $T_\pi(y \mid y';\beta)$,$\beta$ 为扩散速率,前向过程可以表示为:
$$
\begin{aligned}
\pi(y) &= \int \mathrm{d}y' T_\pi(y \mid y';\beta) \pi(y') \\
q(x^{(t)} \mid x^{(t-1)}) &= T_\pi(x^{(t)} \mid x^{(t-1)};\beta_t) \\
q(x^{(0 \cdots T)}) &= q(x^{(0)}) \prod_{t=1}^{T} q(x^{(t)} \mid x^{(t-1)})
\end{aligned}
$$
后向过程 $p(x^{(0 \cdots T)})$#
- 条件概率公式:$p(x \mid y) = \frac{p(x,y)}{p(y)}$
链式展开:$p(x_1, \cdots ,x_n) = p(x_1)\prod_{i=2}^{n}p(x_i \mid x_1, \cdots ,x_{i-1})$ - 边缘概率公式:$p(x) = \int p(x,y) \mathrm{d}y$
- 贝叶斯公式:$p(x \mid y) = \frac{p(y \mid x) p(x)}{p(y)}$
对于高斯扩散,在时间步数 $T$ 足够大、每一步扩散速率 $\beta$ 足够小时,前向扩散过程接近连续扩散过程。此时其反向转移可以近似采用与前向转移相同的分布族进行建模。因此,若前向扩散变换 $q(x^{(t)} \mid x^{(t-1)})$ 是高斯形式,则反向生成变换 $p(x^{(t-1)} \mid x^{(t)})$ 也可被视为高斯分布。后向过程可表示为:
$$
\begin{aligned}
p(x^{(T)}) &= \pi(x^{(T)}) \\
p(x^{(0 \cdots T)}) &= p(x^{(T)}) \prod_{t=1}^{T} p(x^{(t-1)} \mid x^{(t)})
\end{aligned}
$$
对于高斯扩散,仅需学习反向马尔可夫变换中高斯分布的均值 $f_\mu(x^{(t)},t)$ 与方差 $f_\Sigma(x^{(t)},t)$ 即可,本文使用 MLP 对以上两函数进行学习
对于目标数据分布 $p(x^{(0)})$ 有:
$$
\begin{aligned}
p(x^{(0)}) &= \int \mathrm{d}x^{(1 \cdots T)} p(x^{(0 \cdots T)}) && 边缘概率公式\\
&= \int \mathrm{d}x^{(1 \cdots T)} p(x^{(0 \cdots T)}) \frac{q(x^{(1 \cdots T)} \mid x^{(0)})}{q(x^{(1 \cdots T)} \mid x^{(0)})}\\
&= \int \mathrm{d}x^{(1 \cdots T)} q(x^{(1 \cdots T)} \mid x^{(0)}) \frac{p(x^{(0 \cdots T)})}{q(x^{(1 \cdots T)} \mid x^{(0)})}\\
&= \int \mathrm{d}x^{(1 \cdots T)} q(x^{(1 \cdots T)} \mid x^{(0)}) \frac{p(x^{(T)}) \prod_{t=1}^{T} p(x^{(t-1)} \mid x^{(t)})}{\prod_{t=1}^{T} q(x^{(t)} \mid x^{(t-1)})} && 分子分母条件概率链式展开,结合马尔可夫性化简\\
&= \int \mathrm{d}x^{(1 \cdots T)} q(x^{(1 \cdots T)} \mid x^{(0)}) p(x^{(T)}) \prod_{t=1}^{T} \frac{p(x^{(t-1)} \mid x^{(t)})}{q(x^{(t)} \mid x^{(t-1)})}\\
\end{aligned}
$$
训练目标#
- 似然:概率是根据规则预测结果,似然是根据结果推测规则
$P(x \mid \theta)$ 中固定 $\theta$,变化 $x$ 则为概率函数;固定 $x$(已知观测数据),$\theta$ 未知则为似然函数 $L(\theta \mid x)$ - Jensen 不等式:
若 $f$ 为凸函数,则 $f(\mathbb{E}[Z]) \le \mathbb{E}[f(Z)]$
若 $f$ 为凹函数,则 $f(\mathbb{E}[Z]) \ge \mathbb{E}[f(Z)]$ - 连续变量的熵(微分熵):$H_p(X) = -\int p(x) \log p(x) \mathrm{d}x$
条件熵:$H_p(Y|X) = - \int p(x,y) \log p(y|x) \mathrm{d}x\mathrm{d}y$
两个分布的交叉熵:$H(q,p) = -\int q(x) \log p(x) \mathrm{d}x$ - KL散度:$D_{KL}(P || Q) = \int p(x) \ln \frac{p(x)}{q(x)} \,dx$
最大化对数似然:
$$
\begin{aligned}
L &= \int \mathrm{d}x^{(0)} q(x^{(0)}) \log p(x^{(0)}) && 按照真实数据分布 q(x^{(0)}) 加权平均\\
&= \int \mathrm{d}x^{(0)} q(x^{(0)}) \log [\int \mathrm{d}x^{(1 \cdots T)} q(x^{(1 \cdots T)} \mid x^{(0)}) p(x^{(T)}) \prod_{t=1}^{T} \frac{p(x^{(t-1)} \mid x^{(t)})}{q(x^{(t)} \mid x^{(t-1)})}]\\
&\ge K = \int \mathrm{d}x^{(0 \cdots T)} q(x^{(0 \cdots T)}) \log [p(x^{(T)}) \prod_{t=1}^{T} \frac{p(x^{(t-1)} \mid x^{(t)})}{q(x^{(t)} \mid x^{(t-1)})}] && Jensen 不等式\\
& 处理t=T\\
&= \int \mathrm{d}x^{(0 \cdots T)} q(x^{(0 \cdots T)}) \sum_{t=1}^{T} \log [\frac{p(x^{(t-1)} \mid x^{(t)})}{q(x^{(t)} \mid x^{(t-1)})}] + \int \mathrm{d}x^{(T)} q(x^{(T)}) \log p(x^{(T)})\\
&= \int \mathrm{d}x^{(0 \cdots T)} q(x^{(0 \cdots T)}) \sum_{t=1}^{T} \log [\frac{p(x^{(t-1)} \mid x^{(t)})}{q(x^{(t)} \mid x^{(t-1)})}] + \int \mathrm{d}x^{(T)} q(x^{(T)}) \log \pi(x^{(T)})\\
&= \sum_{t=1}^{T} \int \mathrm{d}x^{(0 \cdots T)} q(x^{(0 \cdots T)}) \log [\frac{p(x^{(t-1)} \mid x^{(t)})}{q(x^{(t)} \mid x^{(t-1)})}] - H_p(x^{(T)}) && q 与 \pi 的交叉熵为常数且等于 p 的熵\\
\end{aligned}
$$
为避免边缘效应,设定 $p(x^{(0)} \mid x^{(1)}) = q(x^{(1)} \mid x^{(0)}) \frac{\pi(x^{(0)})}{\pi(x^{(1)})} = T_\pi(x^{(0)} \mid x^{(1)}; \beta_1)$,则有:
$$
\begin{aligned}
K &= \sum_{t=1}^{T} \int \mathrm{d}x^{(0 \cdots T)} q(x^{(0 \cdots T)}) \log [\frac{p(x^{(t-1)} \mid x^{(t)})}{q(x^{(t)} \mid x^{(t-1)})}] - H_p(x^{(T)})\\
& 处理 t = 1;简化书写,令 A = \sum_{t=2}^{T} \int \mathrm{d}x^{(0 \cdots T)} q(x^{(0 \cdots T)}) \log [\frac{p(x^{(t-1)} \mid x^{(t)})}{q(x^{(t)} \mid x^{(t-1)})}]\\
&= \sum_{t=2}^{T} \int \mathrm{d}x^{(0 \cdots T)} q(x^{(0 \cdots T)}) \log [\frac{p(x^{(t-1)} \mid x^{(t)})}{q(x^{(t)} \mid x^{(t-1)})}] + \int \mathrm{d}x^{(0)} \mathrm{d}x^{(1)} q(x^{(0)},x^{(1)}) \log [\frac{p(x^{(0)} \mid x^{(1)})}{q(x^{(1)} \mid x^{(0)})}] - H_p(x^{(T)})\\
&= A + \int \mathrm{d}x^{(0)} \mathrm{d}x^{(1)} q(x^{(0)},x^{(1)}) \log [\frac{q(x^{(1)} \mid x^{(0)}) \pi(x^{(0)})}{q(x^{(1)} \mid x^{(0)}) \pi(x^{(1)})}] - H_p(x^{(T)}) && 带入设定\\
&= A + \int \mathrm{d}x^{(0)} \mathrm{d}x^{(1)} q(x^{(0)},x^{(1)}) [\log \pi(x^{(0)}) - \log \pi(x^{(1)})] - H_p(x^{(T)})\\
&= A + \int \mathrm{d}x^{(0)} [\int \mathrm{d}x^{(1)} q(x^{(0)},x^{(1)})] \log \pi(x^{(0)}) - \int \mathrm{d}x^{(1)} [\int \mathrm{d}x^{(0)} q(x^{(0)},x^{(1)})] \log \pi(x^{(1)}) - H_p(x^{(T)})\\
&= A + \int \mathrm{d}x^{(0)} q(x^{(0)}) \log \pi(x^{(0)}) - \int \mathrm{d}x^{(1)} q(x^{(0)}) \log \pi(x^{(1)}) - H_p(x^{(T)}) && 逆向边缘概率公式\\
&= \sum_{t=2}^{T} \int \mathrm{d}x^{(0 \cdots T)} q(x^{(0 \cdots T)}) \log [\frac{p(x^{(t-1)} \mid x^{(t)})}{q(x^{(t)} \mid x^{(t-1)})}] - H_p(x^{(T)}) && 均为H_p(x^{(T)}),两项相消\\
& 写为KL散度形式;简化书写,令 B = \sum_{t=2}^{T} \int \mathrm{d}x^{(0 \cdots T)} q(x^{(0 \cdots T)}) \log [\frac{p(x^{(t-1)} \mid x^{(t)})}{q(x^{(t-1)} \mid x^{(t)},x^{(0)})}]\\
&= \sum_{t=2}^{T} \int \mathrm{d}x^{(0 \cdots T)} q(x^{(0 \cdots T)}) \log [\frac{p(x^{(t-1)} \mid x^{(t)})}{q(x^{(t)} \mid x^{(t-1)},x^{(0)})}] - H_p(x^{(T)}) && 马尔可夫性质,加一项条件不影响概率\\
&= \sum_{t=2}^{T} \int \mathrm{d}x^{(0 \cdots T)} q(x^{(0 \cdots T)}) \log [\frac{p(x^{(t-1)} \mid x^{(t)})}{q(x^{(t-1)} \mid x^{(t)},x^{(0)})} \frac{q(x^{(t-1)} \mid x^{(0)})}{q(x^{(t)} \mid x^{(0)})}] - H_p(x^{(T)}) && 贝叶斯公式\\
&= \sum_{t=2}^{T} \int \mathrm{d}x^{(0 \cdots T)} q(x^{(0 \cdots T)}) \log [\frac{p(x^{(t-1)} \mid x^{(t)})}{q(x^{(t-1)} \mid x^{(t)},x^{(0)})}] + \sum_{t=2}^{T} [\int \mathrm{d}x^{(0 \cdots T)} q(x^{(0 \cdots T)}) \log q(x^{(t-1)} \mid x^{(0)}) - \int \mathrm{d}x^{(0 \cdots T)} q(x^{(0 \cdots T)}) \log q(x^{(t)} \mid x^{(0)})] - H_p(x^{(T)})\\
&= B + \sum_{t=2}^{T} [\int \mathrm{d}x^{(0)}\mathrm{d}x^{(t-1)} q(x^{(0)},x^{(t-1)}) \log q(x^{(t-1)} \mid x^{(0)}) - \int \mathrm{d}x^{(0)}\mathrm{d}x^{(t)} q(x^{(0)},x^{(t)}) \log q(x^{(t)} \mid x^{(0)})] - H_p(x^{(T)}) && 逆向边缘概率公式\\
&= B + \sum_{t=2}^{T} [H_q(x^{(t)} \mid x^{(0)}) - H_q(x^{(t-1)} \mid x^{(0)})] - H_p(x^{(T)}) && 条件熵\\
&= \sum_{t=2}^{T} \int \mathrm{d}x^{(0 \cdots T)} q(x^{(0 \cdots T)}) \log [\frac{p(x^{(t-1)} \mid x^{(t)})}{q(x^{(t-1)} \mid x^{(t)},x^{(0)})}] + H_q(x^{(T)} \mid x^{(0)}) - H_q(x^{(1)} \mid x^{(0)}) - H_p(x^{(T)})\\
&= \sum_{t=2}^{T} \int \mathrm{d}x^{(0)}\mathrm{d}x^{(t-1)}\mathrm{d}x^{(t)} q(x^{(0)},x^{(t-1)},x^{(t)}) \log [\frac{p(x^{(t-1)} \mid x^{(t)})}{q(x^{(t-1)} \mid x^{(t)},x^{(0)})}] + H_q(x^{(T)} \mid x^{(0)}) - H_q(x^{(1)} \mid x^{(0)}) - H_p(x^{(T)}) && 逆向边缘概率公式\\
&= \sum_{t=2}^{T} \int \mathrm{d}x^{(0)}\mathrm{d}x^{(t-1)}\mathrm{d}x^{(t)} q(x^{(0)},x^{(t)})q(x^{(t-1)} \mid x^{(0)},x^{(t)}) \log [\frac{p(x^{(t-1)} \mid x^{(t)})}{q(x^{(t-1)} \mid x^{(t)},x^{(0)})}] + H_q(x^{(T)} \mid x^{(0)}) - H_q(x^{(1)} \mid x^{(0)}) - H_p(x^{(T)}) && 逆向条件概率公式\\
&= \sum_{t=2}^{T} \int \mathrm{d}x^{(0)}\mathrm{d}x^{(t)} q(x^{(0)},x^{(t)}) \int \mathrm{d}x^{(t-1)} q(x^{(t-1)} \mid x^{(0)},x^{(t)}) \log [\frac{p(x^{(t-1)} \mid x^{(t)})}{q(x^{(t-1)} \mid x^{(t)},x^{(0)})}] + H_q(x^{(T)} \mid x^{(0)}) - H_q(x^{(1)} \mid x^{(0)}) - H_p(x^{(T)})\\
&= -\sum_{t=2}^{T} \int \mathrm{d}x^{(0)}\mathrm{d}x^{(t)} q(x^{(0)},x^{(t)}) D_{KL}(q(x^{(t-1)} \mid x^{(t)},x^{(0)}) || q(x^{(t-1)} \mid x^{(t)},x^{(0)})) + H_q(x^{(T)} \mid x^{(0)}) - H_q(x^{(1)} \mid x^{(0)}) - H_p(x^{(T)})\\
\end{aligned}
$$
以上,训练目标为寻找能够最大化此对数似然的下界的逆马尔可夫变换,即:
$$
\hat{p}(x^{(t-1)}|x^{(t)})=\argmax_{p(x^{(t-1)}|x^{(t)})}-\sum_{t=2}^{T} \int \mathrm{d}x^{(0)}\mathrm{d}x^{(t)} q(x^{(0)},x^{(t)}) D_{KL}(q(x^{(t-1)} \mid x^{(t)},x^{(0)}) || q(x^{(t-1)} \mid x^{(t)},x^{(0)})) + H_q(x^{(T)} \mid x^{(0)}) - H_q(x^{(1)} \mid x^{(0)}) - H_p(x^{(T)})
$$
$\beta$ 设置#
在高斯分布情况下,设定 $\beta_1$ 为较小初始值,$\beta_{2 \cdots T}$ 为可学习参数,同时保持噪声不变,从而实现可微
条件生成#
普通生成乘以条件约束即为条件生成:
$$
\tilde{p}(x^{(0)}) \propto p(x^{(0)})r(x^{(0)})
$$
要想条件化首先要将各时间步的边缘分布条件化:
$$
\tilde{p}(x^{(t)}) = \frac{1}{\tilde{z}_t} p(x^{(t)})r(x^{(t)})
$$
其中,$\tilde{z}_t$ 为归一化参数
随后再对后向扩散核条件化,原始后向扩散核为:
$$
p(x^{(t)}) = \int \mathrm{d}x^{(t+1)} p(x^{(t)} \mid x^{(t+1)})p(x^{(t+1)})
$$
条件化后为:
$$
\begin{aligned}
\tilde{p}(x^{(t)}) &= \int \mathrm{d}x^{(t+1)} \tilde{p}(x^{(t)} \mid x^{(t+1)})\tilde{p}(x^{(t+1)})\\
\frac{p(x^{(t)})r(x^{(t)})}{\tilde{z}_t} &= \int \mathrm{d}x^{(t+1)} \tilde{p}(x^{(t)} \mid x^{(t+1)}) \frac{p(x^{(t+1)})r(x^{(t+1)})}{\tilde{z}_{t+1}} && 带入条件化边缘概率\\
p(x^{(t)})\frac{}{} &= \int \mathrm{d}x^{(t+1)} \tilde{p}(x^{(t)} \mid x^{(t+1)}) \frac{\tilde{z}_t r(x^{(t+1)})}{\tilde{z}_{t+1}r(x^{(t)})}p(x^{(t+1)})\\
\end{aligned}
$$
与原始后向扩散核对比有,积分函数相等:
$$
\begin{aligned}
\tilde{p}(x^{(t)} \mid x^{(t+1)}) \frac{\tilde{z}_t r(x^{(t+1)})}{\tilde{z}_{t+1}r(x^{(t)})}p(x^{(t+1)}) &= p(x^{(t)} \mid x^{(t+1)})p(x^{(t+1)})\\
\tilde{p}(x^{(t)} \mid x^{(t+1)}) &= \frac{\tilde{z}_{t+1}r(x^{(t)})}{\tilde{z}_t r(x^{(t+1)})}p(x^{(t)} \mid x^{(t+1)})\\
\end{aligned}
$$
对于固定的 $x^{(t+1)}$ 上式为关于 $x^{(t)}$ 的条件分布,此时 $\frac{\tilde{z}_{t+1}}{\tilde{z}_t r(x^{(t+1)})}$ 为常数,故上式可写作:
$$
\tilde{p}(x^{(t)} \mid x^{(t+1)}) = \frac{1}{\tilde{z}_t(x^{(t+1)})}p(x^{(t)} \mid x^{(t+1)})r(x^{(t)})
$$
其中,$\tilde{z}_t(x^{(t+1)})$ 为依赖于 $x^{(t+1)}$ 的归一化参数
条件化后,后向扩散核仍为高斯分布,证明暂略
若 $r(x^{(t)})$ 平滑,则条件化可近似为改变后向扩散核的均值;若 $r(x^{(t)})$ 可以精确相乘,就直接做闭式乘法,这适用于图像修复
在选择 $r(x^{(t)})$ 时需要其在过程中缓慢变化,可从以下两种形式中选其一:
$$
\begin{aligned}
r(x^{(t)}) &= r(x^{(0)})\\
r(x^{(t)}) &= r(x^{(0)})^{\frac{T-t}{T}}\\
\end{aligned}
$$
后向过程的熵#
后向过程的熵 $H_q(x^{(t-1)} \mid x^{(t)})$ 有以下上下界,证明暂略:
$$
H_q(x^{(t)} \mid x^{(t-1)}) + H_q(x^{(t-1)} \mid x^{(0)}) - H_q(x^{(t)} \mid x^{(0)}) \le H_q(x^{(t-1)} \mid x^{(t)}) \le H_q(x^{(t)} \mid x^{(t-1)})
$$