概率扩散模型
GAN的缺点
Diffusion 模型
这里先看前向的过程,其实就是不断输入数据并加入噪声,最后很快就变成了一个纯噪声
每个时刻都要添加高斯噪声,后一时刻都是由前一时刻的噪声得到
其实这个过程可以看做一个不断构建标签(噪声)的过程
如何得到 xt 时刻的分布?
xt 时刻的噪声依赖于 αt 这个变量, αt=1−βt,论文中范围是 0.0001 到 0.002, 也就是说 α 需要越来越小才行。 xt 的计算公式如下:
xt=αtxt−1+1−αtz1 z1 是一个高斯噪声, xt−1 是上一时刻的噪声。 论文中对于 αt 是从概率论的角度去解释的,但是其实可以把 αt 看做权重,这样会更好理解一些。
一开始加一点噪声就有效果,越往后加越多的噪声,最后就是纯噪声。但是现在我们只知道最后一个时刻的分布是由前一个时刻的分布得到的,但是整个序列要如何计算呢? 如果一个一个计算,速度就会非常慢,能不能从 x0 直接算出来 xt 呢?
我们把 xt−1 的公式带入到 xt 的公式中,得到:
xt−1=αt−1xt−2+1−αt−1z1xt=αtαt−1xt−2+αt1−αt−1z1+1−αtz1 其中每一步中加入的噪声都是服从高斯分布 z1,z2,…∼N(0,I) ,把上面的式子化简一下,得到:
xt=αtαt−1xt−2+(αt(1−αt−1)z2+1−αtz1)=αtαt−1xt−2+1−αtαt−1z2 z1 和 z2 都服从高斯分布,分别是N(0,1−αt) 和 N(0,at(1−αt−1)), 由于高斯分布满足下面的性质所以可以进行上面化简:
N(0,σ12I)+N(0,σ22I)∼N(0,(σ12+σ22)I) 不断的往里面套就能发现规律了,其实就是一个累乘的过程:
xt=αˉtx0+1−αˉtzt 这个公式告诉我们,任意时刻的分布都可以通过 x0 状态算出来,一步到位。这是我们的第一个核心公式。
现在我们可以加噪声了,现在我们需要求反向的过程
如何求解反向的过程
逆向过程我们需要使用贝叶斯公式:
q(xt−1∣xt,x0)=q(xt∣xt−1,x0)q(xt∣x0)q(xt−1∣x0) 公式中的三项都可以通过前面的公式得到:
q(xt−1∣x0)=αt−1x0+1−αt−1z∼N(αt−1x0,1−αt−1)q(xt∣x0)=αˉtx0+1−αˉtz∼N(αˉtx0,1−αˉt)q(xt∣xt−1,x0)=αtxt−1+1−αtz∼N(αtxt−1,1−αt) 根据标准正态分布的性质,我们可以得到,q(xt−1∣xt,x0) 也就是 exp(−21(βt(xt−αtxt−1)2+1−αˉt−1(xt−1−αˉt−1x0)2−1−αˉt(xt−αˉtx0)2))
接下来继续进行化简:
=exp(−21(βt(xt−αtxt−1)2+1−αˉt−1(xt−1−αˉt−1x0)2−1−αˉt(xt−αˉtx0)2))=exp(−21(βtxt2−2αtxtxt−1+αtxt−12+1−αˉt−1xt−12−2αˉt−1x0xt−1+αˉt−1x02−1−αˉt(xt−αˉtx0)2))=exp(−21((βtαt+1−αˉt−11)xt−12−(βt2αtxt+1−αˉt−12αˉt−1x0)xt−1+C(xt,x0))) C是一个常数,我们可以忽略它,它不会影响我们的结果。我们的核心是求和 xt−1有关的,其他的现在都不需要关心。上面的步骤其实就是在做一个配方的操作,对标标准正态分布公式:
exp(−2σ2(x−μ)2)=exp(−21(σ21x2−σ22μx+σ2μ2)) 对比上下的式子我们就可以得到:
μ~t(xt,x0)=1−αˉtαt(1−αˉt−1)xt+1−αˉtαˉt−1βtx0 将 x0 和 xt 的转化式子 x0=αˉt1(xt−1−αˉtzt) 带入到上面的式子中,我们就可以得到:
μ~t=at1(xt−1−aˉtβtzt) 但是好像还有一个问题就是 zt 要怎么求解呢?其实我们在正向的过程中会知道 zt−1 和 zt 之间的关系,我们可以用一个神经网络来预测 zt,然后我们就可以得到 xt 的分布了。
zt 其实就是我们要估计的每个时刻的噪声,这家伙看起来没法直接求解,所以我们用一个神经网络来预测。还有一个比较神奇的事情是相关的论文里面居然都是用 Unet 来做这个事情的,我也不知道为什么,但是我觉得这个网络结构是不是有点太简单了,可能编码和解码的结果看起来比较舒服?
终极流程图
下面是前向和后向的过程,分别对应我们前面所说的: