手撕Diffusion系列 - 第八期(end) - Diffusion推理

手撕Diffusion系列 - 第八期(end) - Diffusion推理

目录

  • 手撕Diffusion系列 - 第八期(end) - Diffusion推理
  • DDPM 原理图
  • DDPM 反向去噪
  • DIffusion 推理代码
    • Part1 引入相关库函数
    • Part2 定义去噪函数
    • Part3 测试
  • 参考

DDPM 原理图

​ DDPM包括两个过程:前向过程(forward process)反向过程(reverse process),其中前向过程又称为扩散过程(diffusion process),如下图所示。无论是前向过程还是反向过程都是一个参数化的马尔可夫链(Markov chain),其中反向过程可以用来生成图片。

手撕Diffusion系列 - 第八期(end) - Diffusion推理_第1张图片

DDPM 整体大概流程

​ 图中,由高斯随机噪声 x T x_T xT 生成原始图片 x 0 x_0 x0 为反向过程,反之为前向过程(噪音扩散)。

DDPM 反向去噪

前向过程是将原始图片变成随机噪声,而反向过程就是通过预测噪声 ϵ ϵ ϵ ,将随机噪声 x T x_T xT 逐步还原为原始图片 x 0 x_0 x0 ,如下图所示。

手撕Diffusion系列 - 第八期(end) - Diffusion推理_第2张图片

反向扩散的过程(去噪)

反向过程公式表示如下:
x t − 1 = 1 α t ( x t − 1 − α t 1 − α ˉ t ϵ θ ( x t , t ) ) + σ t z x_{t-1}=\frac{1}{\sqrt\alpha_t}(x_t-{\frac{1-\alpha_t}{\sqrt{1-\bar\alpha_t}}} \epsilon_θ(x_t,t))+\sigma_tz xt1=α t1(xt1αˉt 1αtϵθ(xt,t))+σtz
其中 ϵ θ ϵ_θ ϵθ 是噪声估计函数(一般使用 Uet 模型),输入时间步数t和t时刻的噪声图( x t x_t xt)估计真实噪声 ϵ , θ θ θ 是模型训练的参数, z ∼ N ( 0 , 1 ) z∼N(0,1) zN(0,1) σ t z σ_tz σtz 表示的是预测噪声和真实噪声的误差(也就是为什么,去噪之后还得加上一个小噪声,好像是为了泛化性考虑,从而最终还原出来的了原图)。DDPM的关键就是训练噪声估计模型 ϵ θ ( x t , t ) ϵθ(xt,t) ϵθ(xt,t) ,用于估计真实的噪声 ϵ ϵ ϵ

需要注意的一个点:

  • α t \alpha_t αt 是随着t的增加不断的减小的一个过程,从加噪的过程来看,每次加噪的程度是越来越大的,也就模糊的越来越快。

  • σ t z \sigma_tz σtz 这个z就是标准正态分布,但是这个 σ t \sigma_t σt指的是标准差,并且计算公式如下:

  • σ t ( t ) = ( 1 − α t ) ( 1 − α ˉ t − 1 ) 1 − α ˉ t \sigma_t(t)=\sqrt\frac{(1-\alpha_t)(1-\bar\alpha_{t-1})}{1-\bar\alpha_t} σt(t)=1αˉt(1αt)(1αˉt1)

DIffusion 推理代码

Part1 引入相关库函数

# 该模块主要实现的是对图像进行去噪的测试。
'''
# 首先第一步,引入相关的库函数
'''

import torch
from torch import nn
from config import *
from diffusion import alpha_t, alpha_bar
from dataset import *
import matplotlib.pyplot as plt
from diffusion import forward_diffusion

Part2 定义去噪函数

'''
# 第二步定义一个去噪的函数
'''


def backward_denoise(net,batch_x_t):
    # 首先计算所需要的数据,方差variance,也就公式里面的beta_t
    alpha_bar_late = torch.cat((torch.tensor([1.0]) , alpha_bar[:-1]) ,dim=0)
    variance=(1-alpha_t)*(1-alpha_bar_late)/(1-alpha_bar)
    # 得到方差后,开始去噪
    net.eval() # 开启测试模式
    # 记录每次得到的图像
    steps=[batch_x_t]
    for t in range(T-1,-1,-1):
        # 初始化当前每张图像对应的时间状态
        batch_t=torch.full(size=(batch_x_t.size()[0],),fill_value=t) # 表示此时的时间状态 (batch,)
        # 预测噪声
        batch_noise_pre=net(batch_x_t,batch_t) # (batch,channel,iamg,imag)

        # 开始去噪(需要注意一个点,就是去噪的公式,在t不等于0和等于0是不一样的,先进行都需要处理部分也就是添加噪声前面的均值部分)
        # 同时记得要统一维度,便于广播
        reshape_size=(batch_t.size()[0],1,1,1)
        # 先取出对应的数值
        alpha_t_batch=alpha_t[batch_t]
        alpha_bar_batch=alpha_bar[batch_t]
        variance_batch=variance[batch_t]
        # 计算前面的均值
        batch_mean_t=1/torch.sqrt(alpha_t_batch).reshape(*reshape_size)\
                     *(batch_x_t-(1-alpha_t_batch.reshape(*reshape_size))*batch_noise_pre/torch.sqrt(1-alpha_bar_batch.reshape(*reshape_size)))

        # 分类,看t的值,判断是否添加噪声
        if t!=0:
            batch_x_t=batch_mean_t\
                      +torch.sqrt(variance_batch.reshape(*reshape_size))\
                      *torch.randn_like(batch_x_t)
        else:
            batch_x_t=batch_mean_t

        # 对每次得到的结果进行上下限的限制
        batch_x_t=torch.clamp(batch_x_t,min=-1,max=1)
        # 添加每步的去噪结果
        steps.append(batch_x_t)
    return steps

Part3 测试

# 开始测试
if __name__=='__main__':
    # 加载模型
    model = torch.load('model.pt')

    # 生成噪音图
    batch_size = 2
    batch_x_t = torch.randn(size=(batch_size, 1, IMAGE_SIZE, IMAGE_SIZE))  # (2,1,48,48)
    # 逐步去噪得到原图
    steps = backward_denoise(model, batch_x_t)
    # 绘制数量
    num_imgs = 20
    # 绘制还原过程
    plt.figure(figsize=(15, 15))
    for b in range(batch_size):
        for i in range(0, num_imgs):
            idx = int(T / num_imgs) * (i + 1)
            # 像素值还原到[0,1]
            final_img = (steps[idx][b] + 1) / 2
            # tensor转回PIL图
            final_img = TenosrtoPil_action(final_img)
            plt.subplot(batch_size, num_imgs, b * num_imgs + i + 1)
            plt.imshow(final_img)
    plt.show()

参考

视频讲解:diffusion推理_哔哩哔哩_bilibili

原理博客:手撕Diffusion系列 - 第一期 - DDPM原理-CSDN博客

你可能感兴趣的:(手撕Diffusion,深度学习,人工智能,算法,面试,学习)