利用 去噪分数匹配目标的加权和 进行训练,去噪指的是使用sde的方法就不需要自行补充噪声了。
本文的目的是解释如何对原始数据进行扰动。 from https://yang-song.github.io/blog/2021/score/
首先,挑选一个随机过程(SDE)对原始数据分布 p 0 p_0 p0进行扰动得到扰动后数据的概率密度分布 p t p_t pt。
本文选择的随机过程为:
d x = σ t d w , t ∈ [ 0 , 1 ] d{\bf x} = \sigma^td{\bf w}, \ t\in[0,1] dx=σtdw, t∈[0,1]
在这种情况下,扰动后数据的概率密度分布 p t p_t pt,在原始数据下的条件概率分布为:
p 0 t ( x ( t ) ∣ x ( 0 ) ) = N ( x ( t ) ; x ( 0 ) , 1 2 log σ ( σ 2 t − 1 ) I ) p_{0t}(\mathbf{x}(t) \mid \mathbf{x}(0)) = \mathcal{N}\bigg(\mathbf{x}(t); \mathbf{x}(0), \frac{1}{2\log \sigma}(\sigma^{2t} - 1) \mathbf{I}\bigg) p0t(x(t)∣x(0))=N(x(t);x(0),2logσ1(σ2t−1)I)
关于这个函数的解释是,使用参数$ \frac{1}{2\log \sigma}(\sigma^{2t} - 1) 作 为 我 们 的 权 重 函 数 , 即 作为我们的权重函数,即 作为我们的权重函数,即\lambda(t) = \frac{1}{2 \log \sigma}(\sigma^{2t} - 1)$.
当参数 σ \sigma σ变得非常大的时候,其中的先验分布 p t = 1 p_{t=1} pt=1,也就是最终扰动后的数据分布就可以变成一个正太分布:
∫ p 0 ( y ) N ( x ; y , 1 2 log σ ( σ 2 − 1 ) I ) d y ≈ N ( x ; 0 , 1 2 log σ ( σ 2 − 1 ) I ) , \int p_0(\mathbf{y})\mathcal{N}\bigg(\mathbf{x}; \mathbf{y}, \frac{1}{2 \log \sigma}(\sigma^2 - 1)\mathbf{I}\bigg) d \mathbf{y} \approx \mathbf{N}\bigg(\mathbf{x}; \mathbf{0}, \frac{1}{2 \log \sigma}(\sigma^2 - 1)\mathbf{I}\bigg), ∫p0(y)N(x;y,2logσ1(σ2−1)I)dy≈N(x;0,2logσ1(σ2−1)I),
直观地说,这个SDE通过一个变种函数 1 2 l o g σ ( σ 2 t − 1 ) \frac1{2\ log\ \sigma}(\sigma^{2t}-1) 2 log σ1(σ2t−1)帮助我们捕获了高斯扰动的数据变量集合(连续统continuum),即 x ( t ) x(t) x(t)。这个数据变量集合可以帮助我们逐渐将原始数据分布 p 0 p_0 p0变成了一个简单的高斯分布 p 1 p_1 p1,也就是t=1时候的分布。
1) 对t进行连续采样
# 对时间特征t进行均匀采样
random_t = torch.rand(x.x.shape[0]//30, device=device) * (1. - eps) + eps # 防止采样到0
2)定义权重函数
可以看到,这里定义的权重函数就是作者在上面提到的 λ ( t ) \lambda(t) λ(t)函数。
def marginal_prob_std(t, sigma):
# t = torch.tensor(t, device=device)
return torch.sqrt((sigma ** (2 * t) - 1.) / 2. / np.log(sigma))
3)对数据进行扰动
# 表征时间的特征t, 从0到1上进行均匀采样
random_t = torch.rand(batchsize, device=device) * (1. - eps) + eps # 这里的eps是为了防止采样到t=0
# 构造一个与原始数据结构一样的向量,并在[0,1)上进行均匀采样。
z = torch.randn_like(x.x)
# 利用前面均匀采样的时间特征t,求得权重函数的值,这个权重函数的目的就是为了使得t=1时的扰动数据达到一个正太分布的结果。重复30遍的目的是因为一轮训练中设置的batch_size = 30
std = marginal_prob_std_func(random_t).repeat(1, 30).view(-1, 1)
# 这里将噪声与标准差相乘,
perturbed_x = copy.deepcopy(x)
perturbed_x.x += z * std
4)利用扰动的数据进行训练
需要补充一下,为了训练积分函数模型,目前的目标函数变成了下面这个样子:
E t ∈ u ( 0 , T ) E p t ( x ) [ λ ( t ) ∣ ∣ ∇ x l o g p t ( x ) − s θ ( x , t ) ∣ ∣ 2 2 ] \mathbb{E}_{t\in u(0,T)}\mathbb{E}_{p_t(x)}[\lambda(t)||\nabla_xlog\ p_t(x)-s_\theta(x,t)||_2^2] Et∈u(0,T)Ept(x)[λ(t)∣∣∇xlog pt(x)−sθ(x,t)∣∣22]
这里是最基本的目标函数的样子:
E p ( x ) [ ∣ ∣ ∇ x l o g p ( x ) − s θ ( x ) ∣ ∣ 2 2 ] = ∫ p ( x ) ∣ ∣ ∇ x l o g p ( x ) − s θ ( x ) ∣ ∣ 2 2 d x . \mathbb{E}_{p(x)}[{||\nabla_xlog\ p(x)\ -\ s_\theta(x)||}_2^2]\ =\ \int\ p(x)||\nabla_x\ log\ p(x)\ -\ s_\theta(x)||_2^2dx. Ep(x)[∣∣∇xlog p(x) − sθ(x)∣∣22] = ∫ p(x)∣∣∇x log p(x) − sθ(x)∣∣22dx.
为了估计这个目标函数,需要如下估计,即使用Score Matching的方法进行估计(Hyvärinen 2005):
可以看到,去估计如下的目标函数是可以达到的。
E p d a t a ( x ) [ 1 2 ∣ ∣ s θ ( x ) ∣ ∣ 2 2 + t r a c e ( ∇ x s θ ( x ) ) ] \mathbb{E}_{p_{data}(x)}[\frac12||s_\theta(x)||_2^2+trace(\nabla_xs_\theta(x))] Epdata(x)[21∣∣sθ(x)∣∣22+trace(∇xsθ(x))]
具体上,体现在代码上,用的是如下的公式:
1 N ∑ i = 1 N [ 1 2 ∣ ∣ s θ ( x i ) ∣ ∣ 2 2 + t r a c e ( ∇ x s θ ( x i ) ) ] ≈ 1 N ∑ i = 1 N [ 1 2 ∣ ∣ s θ ( x i ) ∣ ∣ 2 2 + t r a c e ( ∇ x s θ ( x i ) ) \frac1N\sum^N_{i=1}[\frac12||s_\theta(x_i)||_2^2+trace(\nabla_xs_\theta(x_i))] \\ \approx \frac1N \sum_{i=1}^N [\frac12||s_\theta(x_i)||_2^2+trace(\nabla_xs_\theta(x_i)) N1i=1∑N[21∣∣sθ(xi)∣∣22+trace(∇xsθ(xi))]≈N1i=1∑N[21∣∣sθ(xi)∣∣22+trace(∇xsθ(xi))
# 计算积分函数的值
output = model(perturbed_x, random_t)
# score matching的损失函数,与上式不一致的原因在于,本文的目标函数中还有一个参数\lambda(t),所以表现为如下的形式。
loss_ = torch.mean(torch.sum(((output * std + z)**2).view(batch_size, -1)), dim=-1)
# 一轮训练之后,将score matching的目标函数的结果返回
return loss_
♂️ 我有一个问题,这个目标函数是怎么推理得到的呀?