UNet是一种经典的卷积神经网络架构,解决了传统方法在数据量不足时面临的挑战。最初由医学图像分割任务提出,后被广泛应用于扩散模型(如DDPM、DDIM、Stable Diffusion)中作为噪声预测的核心网络。
核心结构包括一个收缩路径(downsampling path)和一个对称的扩展路径(upsampling path)。收缩路径通过多次下采样操作捕获上下文信息,而扩展路径则通过上采样操作结合底层特征和高层特征,实现精确的像素级分割。这种U形结构设计使其能够高效地利用有限的标注样本,并在现代GPU上快速执行。
以下是UNet的详细构成及其在扩散模型中的改进设计:
UNet呈对称的U型结构,包含:
编码器(下采样路径):逐步提取高层语义特征,降低分辨率。
解码器(上采样路径):逐步恢复空间分辨率,结合编码器特征进行精确定位。
跳跃连接(Skip Connections):连接编码器与解码器对应层,保留细节信息。
层级 | 操作 | 作用 |
---|---|---|
编码器层 | 卷积 → 激活(如ReLU) → 池化/步长卷积 | 提取特征,压缩空间维度 |
解码器层 | 反卷积/插值上采样 → 跳跃连接 → 卷积 | 恢复分辨率,融合低级与高级特征 |
瓶颈层 | 深层卷积操作 | 捕获全局上下文信息 |
在扩散模型(如DDPM、DDIM)中,UNet经过以下关键改进:
作用:将扩散过程的时间步信息注入网络,指导噪声预测。
实现:
时间步编码为向量(通过正弦位置编码或MLP)。
通过相加或拼接融入各层特征图。
结构:每个块包含多个卷积层 + 归一化层(如GroupNorm) + 激活函数。
改进:
引入残差连接,缓解梯度消失。
集成时间步嵌入(通过相加或自适应归一化)。
位置:通常在瓶颈层或解码器中插入。
类型:
自注意力(Self-Attention):捕捉长程依赖。
交叉注意力(Cross-Attention):用于多模态模型(如Stable Diffusion中结合文本提示)。
替代方案:相比BatchNorm,更适合小批量训练,提升稳定性。
跳跃连接增强:通过通道注意力(如Squeeze-and-Excitation)动态加权特征。
以Stable Diffusion为例,其UNet结构参数如下:
- 输入:带噪声的潜空间特征图(64×64×4)
- 编码器:4个下采样层,每层包含2个残差块
- 瓶颈层:2个残差块 + 自注意力层
- 解码器:4个上采样层,每层包含2个残差块
- 跳跃连接:逐层传递编码器特征至解码器
- 总参数量:约860M(Stable Diffusion 1.4版本)
分辨率保持:通过跳跃连接保留低级细节,避免上采样时的模糊问题。
动态条件注入:时间步嵌入和文本条件(如CLIP embedding)通过自适应归一化(AdaGN)融入网络:
# 示例:自适应归一化(AdaGN)
def adaptive_group_norm(x, timestep_emb, scale_shift=True):
scale, shift = timestep_emb.chunk(2, dim=1)
x = GroupNorm(x)
if scale_shift:
x = x * (1 + scale) + shift
return x
轻量化设计
使用深度可分离卷积(Depthwise Separable Conv)减少计算量。
通道数动态调整(如Stable Diffusion中通道数从128到1024递增)。
特性 | 传统UNet(医学分割) | 扩散模型UNet |
---|---|---|
输入/输出 | 原始图像 → 分割掩码 | 噪声图像 + 时间步 → 噪声残差 |
归一化方式 | BatchNorm | GroupNorm |
条件注入 | 无 | 时间步嵌入 + 文本/图像条件 |
注意力机制 | 无 | 自注意力/交叉注意力 |
参数量级 | 较小(几M~几十M) | 较大(几百M~上B) |
class DiffusionUNet(nn.Module):
def __init__(self):
super().__init__()
# 编码器
self.encoder = nn.ModuleList([
DownBlock(3, 64), # 下采样块
DownBlock(64, 128),
DownBlock(128, 256)
])
# 瓶颈层(含注意力)
self.bottleneck = nn.Sequential(
ResBlock(256, 512),
SelfAttention(512),
ResBlock(512, 512)
)
# 解码器
self.decoder = nn.ModuleList([
UpBlock(512, 256), # 上采样块(含跳跃连接)
UpBlock(256, 128),
UpBlock(128, 64)
])
# 时间步嵌入
self.time_embed = nn.Sequential(
nn.Linear(128, 256),
nn.SiLU(),
nn.Linear(256, 256)
)
def forward(self, x, t):
t_emb = self.time_embed(t) # 时间步嵌入
skips = []
for down in self.encoder:
x = down(x, t_emb) # 注入时间条件
skips.append(x)
x = self.bottleneck(x)
for up in self.decoder:
x = up(x, skips.pop(), t_emb)
return x
在扩散模型(如DDPM)中,改进的UNet结构通过以下方式整合时间嵌入(time embedding),实现对噪声的精准预测:
编码器-解码器架构:UNet由对称的下采样(编码器)和上采样(解码器)路径组成,通过跳跃连接保留多尺度特征。
残差块(ResBlock):每个下采样和上采样阶段包含多个残差块,用于特征提取。
时间步编码:扩散过程中的时间步 t 被编码为高维向量(如通过正弦函数或MLP),表示当前加噪阶段。
动态调节网络:Time embedding作为条件信号,影响每一层的计算,使网络适应不同时间步的噪声分布。
特征图加法:Time embedding通过线性层投影后,直接添加到残差块的特征图中
自适应归一化(AdaGN):
输入:加噪图像在时间步t时由原始图像逐步添加高斯噪声得到。
输出:预测的噪声ϵθ (xt,t),目标是最小化与真实噪声的误差
条件化每一层:每个残差块均接收相同的time embedding,确保网络在不同时间步采用不同的特征变换策略。
时间感知的噪声预测:通过时间嵌入,模型能区分早期(大尺度噪声)和晚期(细节噪声)的去噪需求,提升生成质量。
前向过程:逐步为图像添加噪声
反向过程:UNet预测噪声,通过迭代去噪重建