ByteDance/DreamO:字节开源定制图像生成模型框架

1 引言

本文介绍了 DreamO,这是一个统一的图像定制框架,旨在支持多种任务并促进不同类型条件的无缝集成。DreamO 基于扩散变换器(DiT)框架,通过统一处理不同类型输入来实现图像定制。在训练过程中,构建了大规模训练数据集,涵盖各种定制任务,并引入特征路由约束以精准查询参考图像中的相关信息。此外,还设计了占位符策略,用于控制生成结果中条件的放置位置。通过三个阶段的渐进式训练策略(初始阶段、全面训练阶段和质量对齐阶段),使模型逐步获得鲁棒和泛化的图像定制能力。

2 相关工作

2.1 扩散模型

扩散模型作为一种强大的生成范式,因其高质量的生成和稳定性能而被广泛应用于图像生成领域。早期工作多使用 UNet 架构作为扩散模型的骨干网络,近年来,DiT 架构凭借其可扩展性成为更优选择。

2.2 可控图像生成

可控图像生成领域中,文本是最基本的条件输入。为实现精确的空间控制,一些方法在预训练扩散模型的基础上添加控制模块。近期基于 DiT 的扩散模型进展推动了可控图像生成的发展,例如通过连接所有输入令牌并使用任务特定数据集训练 LoRA 模块来实现各种应用。

2.3 扩散模型中的交叉注意力路由

现有研究表明,文本 - 视觉交叉注意力图 inherent 建立了语言令牌和视觉生成之间的空间 - 语义对应关系。基于此观察,本文在 DiT 架构中设计了路由约束,以实现一般图像定制任务。

3 方法

3.1 基础知识

DiT 模型使用变换器作为去噪网络来优化扩散潜在变量。输入的 2D 图像潜在变量被 patchify 成 1D 令牌序列,图像和文本令牌被拼接并统一处理。

3.2 概述

DreamO 以 Flux-1.0-dev 作为基础模型,构建统一的图像定制框架。给定 n 个条件图像,首先使用 Flux 的 VAE 编码条件图像到同一潜在空间,然后将 2D 潜在表示转换为 1D 令牌序列,并馈送到 Flux 中。引入专用映射层编码条件图像令牌,并添加可学习的条件嵌入和索引嵌入以区分条件潜在变量和噪声图像潜在变量。此外,集成 LoRA 模块作为可训练参数,使模型更好地适应条件任务。

3.3 路由约束

受 UniPortrait 和 AnyStory 启发,本文在 DiT 架构中设计路由约束。在条件引导框架中,条件图像和生成结果之间存在交叉注意力。通过平均密集相似性矩阵获得条件图像在生成输出不同位置的全局响应。使用 MSE 损失优化 DiT 中条件图像和生成结果之间的注意力。此外,还设计占位符 - 图像路由约束,以建立文本描述和条件输入之间的对应关系。

3.4 训练数据构建

为实现通用图像定制,收集涵盖广泛任务的训练数据集,包括身份配对数据、主题驱动数据、试穿数据和风格驱动数据等。

3.5 渐进式训练过程

直接在所有数据上训练会使收敛变得困难。因此,设计渐进式训练策略,首先在主题驱动训练数据上优化模型,然后纳入所有训练数据进行全面调整,最后进行图像质量细化训练阶段,使生成质量与 Flux 的生成先验对齐。

4 实验

4.1 实现细节

采用 Flux-1.0-dev 作为基础模型,LoRA 的秩设置为 128,总参数增加 707M。使用 Adam 优化器,学习率为 4e-5,在 8 个 NVIDIA A100 80G GPU 上训练。训练分为三个阶段,分别进行 20K、90K 和 3K 次迭代。

4.2 模型能力

DreamO 能处理多种图像定制任务,包括身份驱动、主题驱动、试穿和风格驱动等。实验结果表明,DreamO 能够在生成图像中保留特定身份、融合多个人特征,并通过文本输入精确控制其他属性和场景细节。

4.3 消融研究

路由约束能提高生成保真度并促进多条件解耦。渐进式训练策略有助于模型在复杂数据分布下更好地收敛,并纠正训练数据分布对生成质量的影响。

5 结论

DreamO 是一个统一的图像定制框架,能够在单个预训练 DiT 框架内处理各种条件类型。通过构建大规模训练数据集,并引入特征路由约束,DreamO 确保高保真一致性并有效解耦异构控制信号。渐进式训练策略使模型逐渐获得多样化控制能力,同时保持基础模型的图像质量。DreamO 的轻量级 LoRA 设计允许以低计算成本进行高效部署。
ByteDance/DreamO:字节开源定制图像生成模型框架_第1张图片

你可能感兴趣的:(ByteDance/DreamO:字节开源定制图像生成模型框架)