论文链接:Visual Autoregressive Modeling: Scalable Image Generation via Next-Scale Prediction
本文提出的视觉自回归建模/VAR这种新范式,其将图像的自回归学习重新定义为从粗到细的“下一个尺度预测”或“下一个分辨率预测”,与常规的LLM预测下一个token的范式不同。VAR首次基于GPT架构的AR模型在图片生成方面超过了扩散模型,在 ImageNet 256×256 基准测试中,FID、IS分数均大幅提高,推理速度也快了将近20倍。实验证实,VAR在图像质量、推理速度、数据效率和可扩展性等多个维度由于DiT,且其具有明显的Scaling规律,在图像修复、外绘、编辑等下游任务中具有较好的泛化能力。
图1中对比展示了三种自回归生成模型,(a)中是常规的基于预测下一个token范式的自回归模型/AR,GPT、LLaMa等LLM模型均是此范式;(b)中是与(a)一样,以光栅扫描将连续图片转换为离散tokens,通过预测下一个token进行图片生成的自回归模型;©中是文本提出的VAR,即从粗到细的多尺度自回归预测模型,可以看出,随着预测,图片尺度变大,内容由粗糙到精细。
VAR借鉴了人类通常以分层方式感知或创建图像的本能,即先捕捉全局信息,再处理局部细节;这种多尺度、从粗到细的顺序过程正好与自回归建模需要定义数据顺序相对应,故启发了开发人员即将图像的自回归学习定义为图1©中的“预测下一个尺度”,而不是传统的“预测下一个token”的范式。VAR中,先将图像编码为多尺度token maps,从1×1的token map开始自回归过程,逐渐扩展分辨率,即每一步Transformer基于先前所有的token maps预测下一个更高分辨率的token map。本文贡献如下:
此范式需要进行类似的分词操作,即先将连续的2D图片数据分割为离散的tokens,然后将其构建为1D的tokens序列。离散这一步往往是通过训练一个VQVAE模型实现,由编码器 $ E \mathcal{E} E、量化器 Q \mathcal{Q} Q、解码器 D \mathcal{D} D和码本器 Z ∈ R V × C \mathcal{Z} \in \mathbb{R}^{V \times C} Z∈RV×C组成。编码器将图片转换为固定尺寸的特征图 f ∈ R h × w × C f \in \mathbb{R}^{h \times w \times C} f∈Rh×w×C,量化器将 f f f与码本中的向量进行相似度对比装起转换为离散的tokens q ∈ [ V ] h × w q \in [V]^{h \times w} q∈[V]h×w。此时 q q q仍是二维网格排列,还不是一维序列。图像tokens顺序必须为单向自回归学习显示定义,常规的自回归方法使用光栅扫描、螺旋扫描或Z曲线排序等策略将 q q q展平为一维序列 x = ( x 1 , ⋅ ⋅ ⋅ , x h × w ) x = (x_1,\cdot\cdot\cdot,x_{h \times w}) x=(x1,⋅⋅⋅,xh×w),然后通过预测写一个token进行自回归训练。此种图片生成模型有以下不足:
VAR架构下,自回归单元是token map。VAR先将图像特征图 f ∈ R h × w × C f \in \mathbb{R}^{h \times w \times C} f∈Rh×w×C量化为K个多尺度token映射 ( r 1 , r 2 , ⋅ ⋅ ⋅ , r K ) (r_1,r_2,\cdot\cdot\cdot,r_K) (r1,r2,⋅⋅⋅,rK),每个映射的分辨率 h k × w k h_k \times w_k hk×wk递增,最终 r K r_K rK的分辨率与原始特征图 h × w h \times w h×w匹配。自回归似然公式为:
p ( r 1 , r 2 , ⋅ ⋅ ⋅ , r K ) = ∏ k = 1 K p ( r k ∣ r 1 , r 2 , ⋅ ⋅ ⋅ , r k − 1 ) (1) p(r_1,r_2,\cdot\cdot\cdot,r_K)= \prod_{k=1}^{K} p(r_k|r1,r_2,\cdot\cdot\cdot,r_{k-1}) \tag1 p(r1,r2,⋅⋅⋅,rK)=k=1∏Kp(rk∣r1,r2,⋅⋅⋅,rk−1)(1)
其中每个自回归单元 r k ∈ [ V ] h k × w k r_k \in [V]^{h_k \times w_k} rk∈[V]hk×wk是尺度 k k k的token map,其包含 h k × w k h_k \times w_k hk×wk个tokens。序列 ( r 1 , r 2 , ⋅ ⋅ ⋅ , r k − 1 ) (r_1,r_2,\cdot\cdot\cdot,r_{k-1}) (r1,r2,⋅⋅⋅,rk−1)称为 r k r_k rk的前缀,在第 k k k次自回归步骤中, r k r_k rk中所有的 h k × w k h_k \times w_k hk×wk个tokens基于 r k r_k rk的前缀和对应的第 k k k个位置嵌入map并行生成。
图2展示VAR的两阶段训练过程。第一阶段在图片上训练多尺度VQVAE,其可以将图片转换为 k k k个token maps,为第二阶段训练提供Ground Truth。第二阶段通过“下一尺寸预测”训练Transformer,其以 ( [ s ] , r 1 , r 2 , ⋅ ⋅ ⋅ , r K − 1 ) ([s],r_1,r_2,\cdot\cdot\cdot,r_{K-1}) ([s],r1,r2,⋅⋅⋅,rK−1)作为输入预测 ( r 1 , r 2 , r 3 ⋅ ⋅ ⋅ , r K ) (r_1,r_2,r_3\cdot\cdot\cdot,r_K) (r1,r2,r3⋅⋅⋅,rK)。训练过程会使用块级因果注意力掩码确保 r k r_k rk只与 r ≤ k r_{\leq k} r≤k进行注意力计算,使用标准交叉熵损失训练。推理阶段可使用KV-Caching,无需使用掩码。VAR解决了上述讨论的传统AR模型的三个问题:
采用与VQGAN相同架构,修改了多尺度量化层。对特征图 f f f或重建图 f ^ \hat{f} f^采用残差设计的编码和解码流程详见算法1和算法2。实验发现,这种残差式设计比独立插值方法表现更优。算法 1 表明,每个 r k r_k rk仅依赖于其前缀序列 ( r 1 , r 2 , ⋅ ⋅ ⋅ , r k − 1 ) (r_1,r_2,\cdot\cdot\cdot,r_{k-1}) (r1,r2,⋅⋅⋅,rk−1)。请注意,所有尺度共享一个码本 $Z
$,以确保每个 r k r_k rk的token属于同一词表 [ V ] [V] [V]。为解决将 z k z_k zk上采样至 h K × w K h_K \times w_K hK×wK时的信息损失问题,使用了 k k k个额外的卷积层 { ϕ k } k = 1 K \{\phi_k\}_{k=1}^K {ϕk}k=1K。将 f f f下采样至 h k × w k h_k \times w_k hk×wk后不使用卷积操作。
VAR中的分词器采用原始VQVAE架构,附加 k k k个额外的卷积层的量化方案,额外增加0.03M参数。所有尺度共享一个码本,尺寸大小 V = 4096 V=4096 V=4096。使用常规的VQVAE复合损失在OpenImages上训练,空间下采样率为16。
transformer部分采用类似于GPT-2和VQGAN标准的仅解码器Transformer架构,集成自适应归一化层/AdaLN。对于类别条件生成任务,将类别嵌入作为起始token [ s ] [s] [s],同时作为AdaLN的条件输入。实现发现,在注意力计算前将查询/queries和键/keys归一化为单位向量可稳定训练过程。模型未采用RoPE、SwiGLU或RMS Norm等LLM中常用技术。模型宽度 w w w,注意力头数 h h h,丢弃率 d r dr dr随深度 d d d线性缩放,具体如下所示:
w = 64 d , h = d , d r = 0.1 ⋅ d 24 (2) w=64d,\quad h=d,\quad dr=0.1\cdot\frac{d}{24} \tag2 w=64d,h=d,dr=0.1⋅24d(2)
因此,深度为 d d d的VAR transformer模块的主要参数数量 N N N如下计算:
N ( d ) = d ⋅ 4 w 2 ⏟ self-attention + d ⋅ 8 w 2 ⏟ feed-forward + d ⋅ 6 w 2 ⏟ adaptive layernorm = 18 d w 2 = 73728 d 3 (3) N(d) = \underbrace{d \cdot 4w^2}_{\text{self-attention}} + \underbrace{d \cdot 8w^2}_{\text{feed-forward}} + \underbrace{d \cdot 6w^2}_{\text{adaptive layernorm}} = 18dw^2 = 73728d^3 \tag3 N(d)=self-attention d⋅4w2+feed-forward d⋅8w2+adaptive layernorm d⋅6w2=18dw2=73728d3(3)
所有模型均采用相似的设置进行训练:每个batch size为 256 的批次的基础学习率为 10 − 4 10^{-4} 10−4;使用 AdamW 优化器,参数 β 1 = 0.9 , β 2 = 0.95 , decay = 0.05 \beta_1=0.9,\beta_2=0.95,\text{decay}=0.05 β1=0.9,β2=0.95,decay=0.05;取决于模型规模,批次大小为 768 至 1024,训练轮次为 200 至 350。评估结果表明,这种简单的模型设计具备良好的可扩展性和泛化能力。
LLM模型中的幂律缩放定律已得到证实并广泛应用,该定律可以通过小模型的性能指导训练大模型,便于合理分配计算资源,VAR也具有明显的幂律缩放定律。本文在ImageNet上训练了12种不同规模的模型,基于公式(3)可知模型参数量主要与深度 d d d相关,将深度从6递增至30,参数量从1800万到20万亿不等;每轮训练包含128万张或VQVAE转化后对应8700亿个图像tokens。对于不同规模的模型,训练轮次为200~350轮不等。基于上述实验,可针对在tokens数量 T T T充足的情况下,模型参数 N N N与最优训练计算量 C m i n C_{min} Cmin的缩放定律。
在包含5万张图像的ImageNet验证机上进行评估,主要计算交叉熵损失 L L L,token预测错误率 E r r Err Err,结果如图4所示。图4表示了VAR的缩放定律,其中包含幂律拟合曲线(虚线)和图例中的公式。接近零的小指数 α \alpha α表明,当扩大 VAR Transformer 规模时,测试损失 L L L和token错误率 E r r Err Err均呈现平滑下降趋势。坐标轴均采用对数刻度。皮尔逊相关系数接近−0.998,表明 log ( N ) \log(N) log(N)与 log ( L ) \log(L) log(L)或 log ( N ) \log(N) log(N)与 log ( E r r ) \log(Err) log(Err)之间存在强线性关系。
从图4中可以观察到测试损失 L L L随参数量 N N N变化呈现明显的幂律缩放趋势,该关系可表示为:
L l a s t = ( 2.0 ⋅ N ) − 0.23 and L a v g = ( 2.5 ⋅ N ) − 0.20 (4) L_{last}=(2.0 \cdot N)^{-0.23} \quad \text{and} \quad L_{avg}=(2.5 \cdot N)^{-0.20} \tag4 Llast=(2.0⋅N)−0.23andLavg=(2.5⋅N)−0.20(4)
尽管缩放定律主要围绕测试损失展开研究,但本文通过实验也观察到token错误率 E r r Err Err呈现类似的幂律趋势:
E r r l a s t = ( 4.9 ⋅ 10 2 N ) − 0.016 and E r r a v g = ( 6.5 ⋅ 10 2 N ) − 0.010 (5) Err_{last}=(4.9 \cdot 10^2N)^{-0.016} \quad \text{and} \quad Err_{avg}=(6.5 \cdot 10^2N)^{-0.010} \tag5 Errlast=(4.9⋅102N)−0.016andErravg=(6.5⋅102N)−0.010(5)
这些结果验证了 VAR 强大的可扩展性,通过扩大 VAR Transformer 的规模能够持续提升模型的测试性能。
对于12个模型中的每一个,还记录了其测试损失 L L L和token错误率 E r r Err Err与计算量 C C C的关系,计算量的单位为PFlops,即每秒 10 15 10^{15} 1015次浮点运算,结果如图5所示。图5中绘制了 L L L和 E r r Err Err的帕累托前沿,以突出达到特定损失或错误率所需的最优训练计算量 C m i n C_{min} Cmin。经拟合得到的 L L L 和 E r r Err Err与 C m i n C_{min} Cmin的幂律缩放定律如下:
L last = ( 2.2 ⋅ 10 − 5 C min ) − 0.13 (6) L_{\text{last}} = \left(2.2 \cdot 10^{-5} C_{\text{min}}\right)^{-0.13} \quad \tag6 Llast=(2.2⋅10−5Cmin)−0.13(6)
L avg = ( 1.5 ⋅ 10 − 5 C min ) − 0.16 (7) L_{\text{avg}} = \left(1.5 \cdot 10^{-5} C_{\text{min}}\right)^{-0.16} \tag7 Lavg=(1.5⋅10−5Cmin)−0.16(7)
E r r last = ( 8.1 ⋅ 10 − 2 C min ) − 0.0067 (8) Err_{\text{last}} = \left(8.1 \cdot 10^{-2} C_{\text{min}}\right)^{-0.0067} \tag8 Errlast=(8.1⋅10−2Cmin)−0.0067(8)
E r r avg = ( 4.4 ⋅ 10 − 2 C min ) − 0.011 (9) Err_{\text{avg}} = \left(4.4 \cdot 10^{-2} C_{\text{min}}\right)^{-0.011} \quad \tag9 Erravg=(4.4⋅10−2Cmin)−0.011(9)
上述公式在最优计算量 C m i n C_{min} Cmin的6个数量级范围内均成立,实验表明当使用足够数据训练时,更大的 VAR transformer 模型计算效率更高,因为它们可以用更少的计算量达到相同的性能水平。
图5中线条颜色表示不同模型规模,红色虚线为幂律拟合曲线(图例中包含公式)。坐标轴采用对数刻度,皮尔逊相关系数接近 0.99,表明 log ( C m i n ) \log(C_{min}) log(Cmin)与 log ( L ) \log(L) log(L)或 log ( C m i n ) \log(C_{min}) log(Cmin)与 log ( E r r ) \log(Err) log(Err)之间存在强线性关系。
在图像修复与扩展任务中,对 VAR-d30 进行测试,通过在掩码外部使用教师强制真实token、仅让模型在掩码内部生成token(且未注入任何类别标签信息),其在不修改网络架构或调整参数的情况下取得了不错效果;同时,遵循 MaskGIT 在类别条件图像编辑任务上的测试显示,模型被强制仅在边界框内基于类别标签生成token时,能产生与周围上下文良好融合的合理内容。这两项任务的结果均证实了 VAR 在不同下游任务中无需特殊调整即可实现有效泛化的能力。图6展示了具体图片结果,表明VAR 无需特殊设计和微调即可泛化到新的下游任务。
视觉自回归建模(VAR)的是一种新型视觉生成框架,其特点在于:1)从理论上解决了标准图像自回归(AR)模型的一些固有问题;2)使基于语言模型的 AR 模型首次在图像质量、多样性、数据效率和推理速度方面超越了强大的扩散模型。当将 VAR 扩展到 20 亿参数时,其测试性能与模型参数或训练计算量之间存在明显的幂律关系,皮尔逊系数接近 - 0.998,这表明该框架可用于稳健的性能预测。这些缩放定律以及零样本任务泛化的可能性 —— 作为大型语言模型(LLMs)的标志 —— 现已在 VAR transformer 模型中得到初步验证。