DDPM中框架代码

这里是DDPM框架的代码,等我成功把他用在超分辨率中后我会将项目代码和视频发出来。想起来SRGANext的视频还一直没做完。。。

原始论文中直接使用β作为方差,我这里使用了真实方差

 

import torch
import torch.nn.functional as F

class DenoiseDiffusion:
    def __init__(self,eps_module,n_steps,device):
        self.eps_module = eps_module
        self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device)
        self.alpha = 1. - self.beta
        self.alpha_bar = torch.cumprod(self.alpha, dim=0) # 维度与alpha相同
        self.n_steps = n_steps

    def q_xt_x0(self,x0,t):

        '''
        均值:根号下alpha_bar * x0
        方差:1-alpha_bar
        得到扩散第t步的均值和方差
        '''
        mean = (self.alpha_bar[t-1]**0.5)*x0
        var = 1-self.alpha_bar[t-1]
        return mean,var

    def q_sample(self,x0,t,eps=None):
        '''
        根据参数重整化返回采样
        '''
        if not eps:
            eps = torch.randn_like(x0)
            mean,var = self.q_xt_x0(x0,t)
            return mean+(var**0.5)*eps

    def p_sample(self,xt,t):
        '''
        xt-1的分布
        '''
        eps_theta = self.eps_module(xt,t) # 得到预测的分布
        alpha_bar = self.alpha_bar[t-1] # alpha_bar[t]
        alpha_bar_pre = self.alpha_bar[t-2] # alpha_bar[t-1]
        alpha = self.alpha[t-1]
        beta = 1-alpha
        eps_coef = beta/(1-alpha_bar)**0.5 # 预测的分布的系数
        mean = 1/(alpha**0.5)*(xt-eps_coef*eps_theta)
        var = beta*(1.0-alpha_bar_pre)/(1-alpha_bar)
        eps = torch.randn(xt.shape, device=xt.device)
        return mean + (var ** .5) * eps

    def loss(self,x0,noise):
        batch_size = x0.shape[0]
        t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)
        if not noise:
            noise = torch.randn_like(x0)
        xt = self.q_sample(x0, t, eps=noise)
        eps_theta = self.eps_module(xt,t)
        return F.mse_loss(noise,eps_theta)

你可能感兴趣的:(python,数学建模,开发语言)