基于分数的的扩散模型,
原文章:Score-Based Generative Modeling through Stochastic Differential Equations
一、直观理解
将反向降噪过程中的降噪建模成一个粒子随机运动过程,然后用langevin方程描述这个粒子随机运动,并求出一个稳定解,得出降噪过程噪点运动的物理规律,从而引导噪声的分布向原数据分布的特征进行运动,从而达到生成数据的目的。
二、正向扩散过程
随机微分方程建模(从离散到连续)
思路依然是基于DDPM的分解方法,将噪音添加的过程分为 $T$ 步。只是DDPM中这个步数 $T$ 是人为设计的、离散的,按照直观的理解,这个噪音添加包括降噪的过程应该是连续的,为了消除这个人为的影响。
这里就用一个随机微分方程(SDE)进行建模,使得diffusion的正向、反向过程在时间维度上达到连续。
每一次对数据进行噪声的添加,如此一次性的、离散的行为,用随机微分方程来表达可以写作:
其中:
- $\boldsymbol{f}_t\left(\boldsymbol{x}_t\right) \Delta t$ 表示的是确定项;
- $g_t \sqrt{\Delta t} \boldsymbol{\varepsilon}$ 表示的是随机项;
由于 $\boldsymbol{\varepsilon} \sim \mathcal{N}(\mathbf{0}, \boldsymbol{I})$ ,噪声一直服从一个标准正态分布。与此同时需要保证随机效应在这个过程之中一直存在,因此关于随机项的系数是个 $\sqrt{\Delta t}$
为了使其具备连续性,对上式取极限,使得 $\Delta t \rightarrow 0$ ,可得其连续的扩散过程的微分:
在这种情况下,扩散过程不必告知其扩散步数,只需要看关于时间 $t$ 的微分能取多小。
好处
- 在分析这个扩散过程的时候,可以利用连续的随机微分方程对其进行建模分析,使其理论上更可信。
- 在代码实现的时候,可以参照第一个式子,选取合适的离散化方案,就可以对其进行数值计算。
整体上实现了将理论分析与程序实现分离的功能。
三、反向降噪过程
由于没有办法直接写出反向过程的表达式,所以这里倾向于先写出正向的概率公式,再利用贝叶斯公式得到降噪过程的表达式。
对于一个极短的时间间隔 $\Delta t$ 内的正向过程,可以用条件概率公式表达为:
后面只写相关不写等式是为了略去一些常数项参数的干扰。这里利用贝叶斯定理,反向过程可以被表征为:
带入正向过程得到的相关系数,可以得到反向过程相关于:
此时需要注意到:
为了让这个过程连续,$\Delta t$ 需要足够小。而当 $\Delta t$ 足够小时,为了使得概率模型 $p\left(\boldsymbol{x}_{t+\Delta t} \mid \boldsymbol{x}_t\right)\neq 0$ ,也即使得其概率明显大于0,成为值得被考量的显著事件。
为了满足上述情况下的各种需求,只有令 $\boldsymbol{x}_t$ 和 $\boldsymbol{x}_{t+\Delta t}$ 足够接近时,$\frac{\left|\boldsymbol{x}_{t+\Delta t}-\boldsymbol{x}_t-\boldsymbol{f}_t\left(\boldsymbol{x}_t\right) \Delta t\right|^2}{2 g_t^2 \Delta t}$ 才会趋于0,从而使得 $\exp \left(-\frac{\left|\boldsymbol{x}_{t+\Delta t}-\boldsymbol{x}_t-\boldsymbol{f}_t\left(\boldsymbol{x}_t\right) \Delta t\right|^2}{2 g_t^2 \Delta t}\right)$ 趋于1,使得该概率模型的值明显不等于0。
因此针对 $\boldsymbol{x}_t$ 和 $\boldsymbol{x}_{t+\Delta t}$ 的关联进行数学描述,此处使用对数函数的一阶泰勒展开:
将这个式子作为一个结论,回代到反向过程的数学描述中:
$\text { 当 } \Delta t \rightarrow 0 \text { 时, } \mathcal{O}(\Delta t) \rightarrow 0 \text { 不起作用, 因此: }$
此处因为 $\Delta t \rightarrow 0$ ,$f_t(\cdot) \sim f_{t+\Delta t}(\cdot)$,同理 $g^2_t(\cdot) \sim g^2_{t+\Delta t}(\cdot)$ ,其余都是这么近似将下标从 $t$ 替换成 $t+\Delta t$ 。
将上述表达式凑成一个高斯分布的话,我们就可以得知,反向降噪概率模型 $p\left(\boldsymbol{x}_t \mid \boldsymbol{x}_{t+\Delta t}\right)$ 可以近似成一个高斯分布,参数为:
- 均值:
- 协方差: $g_{t+\Delta t}^2 \Delta t \boldsymbol{I}$
再度取 $\Delta t \rightarrow 0$ ,利用SDE对其建模,可以得到:
这就是降噪过程的SDE。
分数匹配(Score Matching)
1)从连续出发,再次的离散化
既然已经得到了逆向的SDE,只要再知道 $\nabla_{\boldsymbol{x}} \log p_t(\boldsymbol{x})$(分数(score)) 就可以将SDE再度离散化,实现一步一步离散化的去噪:
描述的是 $\boldsymbol{x}_t$ 和 $\boldsymbol{x}_{t+\Delta t}$ 之间的差距,对比一下正向过程的:
$\nabla_{\boldsymbol{x}} \log p_t(\boldsymbol{x})$ 中,$p_t(\boldsymbol{x})$ 等价于前面的 $p\left(\boldsymbol{x}_t\right)$ ,表征扩散到t时刻时的边缘分布。为了得知 $\nabla_{\boldsymbol{x}} \log p_t(\boldsymbol{x})$ ,先需要得知 $p\left(\boldsymbol{x}_t\right)$ 。
欲知 $p\left(\boldsymbol{x}_t\right)$ 可以构建一个条件概率 $p\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)$ ,使得 $p\left(\boldsymbol{x}_t\right)$ 可以通过 $p\left(\boldsymbol{x}_t\right)=\int p\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right) \tilde{p}\left(\boldsymbol{x}_0\right) d \boldsymbol{x}_0$ 而被求得。当离散SDE中的 $\boldsymbol{f}_t(\boldsymbol{x})$ 是关于 $x$ 的线性函数的话它就可以被求出解析解:
如此就可以写出:
带入 $\nabla_{\boldsymbol{x}} \log p_t(\boldsymbol{x})$ 中可得:
这个数学解析角度的好处:
- 因为 $p\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)$ 的解析解可以求得(基于上述线性假设)
- 因为其形式形似加权平均
- 所以可以进行直接计算
缺点: - 计算量太大
需要对全体训练样本计算加权平均 - 泛化能力不够
只用到了训练样本
为了改善这些缺点,也即加快计算速度、提高泛化能力,这里构建一个神经网络(分数网络)$s_\theta\left(x_t, t\right)$ 对 $\nabla_{\boldsymbol{x}_t} \log p\left(\boldsymbol{x}_t\right)$ 进行估计。
2)分数匹配
需要让 $s_\theta\left(x_t, t\right)$ 对 $\nabla_{\boldsymbol{x}_t} \log p\left(\boldsymbol{x}_t\right)$ 的估计越准确越好,我们需要设计一个优化目标。
这个优化目标灵感来自对某一个样本数据的均值进行估计的目标:
很容易可以知道,在最小化了 $|\boldsymbol{\mu}-\boldsymbol{x}|^2$ 的均值之后,此时的 $\mu$ 就无限接近于 $x$ 的均值。
同样的,对 $\nabla_{\boldsymbol{x}_t} \log p\left(\boldsymbol{x}_t\right)$ 进行估计,根据上面的描述,等价于对 $\nabla_{\boldsymbol{x}_t} \log p\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)$ 的加权平均的估计,即估计:
分母部分的 $\mathbb{E}_{\boldsymbol{x}_0}\left[p\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)\right]$ 是一个常量不含参,起到的作用是调节loss的权重,为了简化计算将其省略。如此最终的损失函数便是:
3) 求解SDE
求解思路:对正向过程
为使得噪音越加越多,得到从 $t=0$ 到 $t=1$ 过程的边界条件为:
再对每一个被离散化的正向过程进行分析,得到每一次离散化的加噪过程可以被概率模型描述为:
其中:
- $q\left(\boldsymbol{x}_{t+\Delta t} \mid \boldsymbol{x}_0\right)$
- 概率模型表达式: $\mathcal{N}\left(\boldsymbol{x}_t ; \bar{\alpha}_{t+\Delta t} \boldsymbol{x}_0, \bar{\beta}_{t+\Delta t}^2 \boldsymbol{I}\right)$
- 采样方式: $\boldsymbol{x}_{t+\Delta t}=\bar{\alpha}_{t+\Delta t} \boldsymbol{x}_0+\bar{\beta}_{t+\Delta t} \boldsymbol{\varepsilon}$
- $q\left(x_t \mid x_0\right)$
- 概率模型表达式: $\mathcal{N}\left(\boldsymbol{x}_t ; \bar{\alpha}_t \boldsymbol{x}_0, \bar{\beta}_t^2 \boldsymbol{I}\right)$
- 采样方式: $\boldsymbol{x}_t=\bar{\alpha}_t \boldsymbol{x}_0+\bar{\beta}_t \boldsymbol{\varepsilon}_1$
- $q\left(\boldsymbol{x}_{t+\Delta t} \mid \boldsymbol{x}_t\right)$
- 概率模型表达式: $\mathcal{N}\left(\boldsymbol{x}_{t+\Delta t} ;\left(1+f_t \Delta t\right) \boldsymbol{x}_t, g_t^2 \Delta t \boldsymbol{I}\right)$
- 采样方式: $\boldsymbol{x}_{t+\Delta t}=\left(1+f_t \Delta t\right) \boldsymbol{x}_t+g_t \Delta t \boldsymbol{\varepsilon}_2$
- $\int q\left(\boldsymbol{x}_{t+\Delta t} \mid \boldsymbol{x}_t\right) q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right) d \boldsymbol{x}_t$
- 采样方式: $\begin{aligned} & \boldsymbol{x}_{t+\Delta t} \\=&\left(1+f_t \Delta t\right) \boldsymbol{x}_t+g_t \sqrt{\Delta t} \boldsymbol{\varepsilon}_2 \\=&\left(1+f_t \Delta t\right)\left(\bar{\alpha}_t \boldsymbol{x}_0+\bar{\beta}_t \boldsymbol{\varepsilon}_1\right)+g_t \sqrt{\Delta t} \boldsymbol{\varepsilon}_2 \\=&\left(1+f_t \Delta t\right) \bar{\alpha}_t \boldsymbol{x}_0+\left(\left(1+f_t \Delta t\right) \bar{\beta}_t \boldsymbol{\varepsilon}_1+g_t \sqrt{\Delta t} \boldsymbol{\varepsilon}_2\right) \end{aligned}$
根据 $d x=f_t x d t+g_t d w$ 求解得到:
得知了每一次正向过程噪声添加的噪声系数在微分尺度下的表达,这时候再让 $\Delta t \rightarrow 0$
又因为:$\bar{\alpha}_t^2+\bar{\beta}_t^2=1$ ,以及 $x_t=\bar{\alpha}_t x_0+\bar{\beta}_t \varepsilon$ ,得到score的表达式:
由于分数网络是对上式进行的匹配,上式表明score正比于负的噪声,所以分数网络也可以被写作:
最后优化目标在SDE得解之后可以被写作: