训练及推理步骤如下
训练:
① 准备真实图片样本数据
② 根据递推公式直接加噪到第t步,得到第t步数据xt,并记录真实加噪噪声noise
③ 根据第t步数据xt 和t,喂给神经网络得到预测的 噪声predict_noise
④ 真实噪声noise 和预测噪声predict_noise进行损失函数计算
采样:
① 随机生成第噪声数据,作为第t步的数据 xt
② 根据公式q(x_{t-1} | x_t, x_0),得到x{t-1}步数据
③ 直到②计算到x0时,即得到最终的x0数据
重点公式代码实现如下
beta=torch.linspace(beta_start, beta_end, timesteps)
alphas=1-betas
sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
1 正向传播时,随机给noise,通过下式计算真实xt
x t = α ˉ t ∗ x 0 + 1 − α ˉ t z t x_t=\sqrt{\bar \alpha_t}*x_{0}+\sqrt{1-\bar \alpha_t}z_t xt=αˉt∗x0+1−αˉtzt
sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
2 根据上面真实x_t,和时间步t,计算模型预测的噪声,(模型选择unet网络)
predicted_noise = denoise_model(x_noisy, t)
3 损失函数
最小化noise与predicted_noise距离
loss = F.mse_loss(noise, predicted_noise)
随机x_t ——> 根据 q(x_{t-1} | x_t, x_0) 公式,采样x_{t-1}步数——>递推最终得到第x0步数据。
x_t = torch.randn(shape, device=device)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
均值mu为
μ ˉ t = 1 α t ( x t − 1 − α t 1 − α ˉ t z t ) \bar \mu_t=\frac{1 } {\sqrt \alpha_{t}} (x_t -\frac{1-\alpha_t } {\sqrt{1- \bar \alpha_{t}}} z_t) μˉt=αt1(xt−1−αˉt1−αtzt)
model_mean = sqrt_recip_alphas_t * (
x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
)
第t-1步数据
根据上面均值和方差进行采样得到第t-1步数据
noise = torch.randn_like(x)
x_(t-1)=model_mean + torch.sqrt(posterior_variance_t) * noise