读论文笔记-Flamingo:少样本视觉语言模型

读论文笔记-Flamingo:少样本视觉语言模型

Plomblems

本文拟解决多模态机器学习中,如何将训练好的模型快速适应到少量标注数据的新任务中的问题。

Motivations

  1. 已有的VLM虽然能在zero-shot的场景下适应于新任务,但他们只解决了有限的使用情况(如CLIP只解决了图片分类),由于主要缺乏生成语言的能力其不能应用于开放性任务。
  2. 其他的一些方法虽然研究了基于视觉的语言生成但在数据量少的场景下并未展现出好的表现。
  3. 现有的受BERT启发的一系列VLM需要在新任务上微调,才能进行快速适应:在微调阶段又会有很大的计算量,并且运行时间慢。

Methods读论文笔记-Flamingo:少样本视觉语言模型_第1张图片

Overview:Flamingo是第一个用LM接收任意数量的图文或视频文本穿插数据作为输入,输出无格式文本的视觉-语言模型。模型主要包含冻结的视觉和语言模型并用Receiver Sampler有效地桥接他们。
Vision Encoder是预训练且冻结的NFNet,在预训练阶段Vision Encoder和BERT模型,使成对的embeddings的相似度最大,反之最小。使用双向对比损失函数作为优化目标得到训练好的vision encoder权重。损失函数见下:
L c o n t r a s t i v e : t x t 2 i m g = − 1 N ∑ i N l o g ( e x p ( L i T V i β ) ∑ j N e x p ( L i T L j β ) ) L_{contrastive:txt2img}=-\frac{1}{N}\sum^N_ilog(\frac{exp(L_i^TV_i\beta)}{\sum^{N}_{j}exp(L_i^TL_j\beta)}) Lcontrastive:txt2img=N1iNlog(jNexp(LiTLjβ)exp(LiTViβ))
L c o n t r a s t i v e : i m g 2 t x t = − 1 N ∑ i N l o g ( e x p ( V i T L i β ) ∑ j N e x p ( V i T L j β ) ) L_{contrastive:img2txt}=-\frac{1}{N}\sum^N_ilog(\frac{exp(V_i^TL_i\beta)}{\sum^{N}_{j}exp(V_i^TL_j\beta)}) Lcontrastive:img2txt=N1iNlog(jNexp(ViTLjβ)exp(ViTLiβ))
V i V_i Vi, T i T_i Ti分别是每个batch中的第i个元素的正则化后的embedding。
而对于LM是Chinchilla模型,实际上就是transformer decoder模型,是直接使用的预训练好的模型权重,没有重新从头训练。

  1. Receiver Resampler:从VE中接收时空特征,输出一个固定长度的视觉token。把任意数量的视觉特征变成固定数量的视觉输出,这个具体数量由queries的数量决定,其也是一个transformer结构,kv是展平(并添加temporal embeddings)的visual features和queries的拼接,queries是可学习的。相关伪码见下:

    def perceiver_resampler( x_f, 
    # The [T, S, d] visual features (T=time, S=space) 
    time_embeddings, # The [T, 1, d] time pos embeddings.
    x, # R learned latents of shape [R, d] 
    num_layers# Number of layers
    ):
    """The Perceiver Resampler model.""" 
    # Add the time position embeddings and flatten. 
    x_f = x_f + time_embeddings 
    x_f = flatten(x_f) # [T, S, d] -> [T * S, d] 
    # Apply the Perceiver Resampler layers. 
    for i in range(num_layers):
    	# Attention. 
    	x = x + attention_i(q=x, kv=concat([x_f, x])) 
     	# Feed forward. 
     	x = x + ffw_i(x) 
    return x
    

    在注意力机制中,计算注意力矩阵 Q ∗ K T Q*K^T QKT得到的形状是[R,T*S+R],最终的输出是[R,D]。结构图见下:读论文笔记-Flamingo:少样本视觉语言模型_第2张图片

  2. 通过在pretrained LM中穿插的新初始化的gated cross-attention将Visued tokens作为一个条件输入:
    1)这个gated也就是门控机制的意思,是参考LSTM的tanh门控机制,gated值是可训练的参数用于控制添加视觉prompt的程度,在一开始是0(为了尽可能贴近预训练LLM的输入)随着训练层数和训练时间的增加数值逐渐加大。但是需要注意的是并不能说随着层数增加,tanh值增加,学到了更多的视觉特征。 这是因为在gated前面由cross-attention,其中的激活函数可能提取了很多或很少的视觉特征,但不能衡量其提取的程度与tanh提取的程度。
    读论文笔记-Flamingo:少样本视觉语言模型_第3张图片 2)cross-attention是具有掩码的注意力机制,对于每个text token只与前面最近的image/video交互,这是通过函数φ映射可用图像实现的,整个结构能够实现LM模型成功输入多张图片。整体结构的伪码见下:
    读论文笔记-Flamingo:少样本视觉语言模型_第4张图片
    详细结构图见下:
    读论文笔记-Flamingo:少样本视觉语言模型_第5张图片
    其预测输出的概率表达式可以这样表示:
    p ( y ∣ x ) = ∏ l = 1 L p ( y l ∣ y < l , x < = l ) p(y|x)=\prod_{l=1}^{L}p(y_l|yp(yx)=l=1Lp(yly<l,x<=l)
    x<=l表示的就是l前面的图像
    具体来说,序列Y输入LM中,原本图像使用 < i m a g e > <image>替代,用 < B O S > <BOS>表示序列开头, < E O C > <EOC>表示一个文本图像对或文本块结束了。在Y的基础上添加φ生成的掩码,再和视觉的KV做attention。φ是提前把L长的序列映射到离散的图像数量0,1,2…上,数学公式为 ϕ : [ 1 , L ] ↦ [ 0 , N ] \phi:[1,L]\mapsto[0,N] ϕ:[1,L][0,N],那么上面的x<=l就可以等价为 x i ∣ i < = ϕ ( l ) {x_i|i<=\phi(l)} xii<=ϕ(l)表示最近图像的下标。

  3. 使用由3种数据集混合而成的数据同时训练:
    1)构造M3W图文交叉数据集,少样本学习能力极度依赖图文交叉数据训练。从网站上选取43million个网页根据对象文本结构DOM确定图文位置,每个文档中随机采样一个256tokens的序列和序列中的前5个图片并用代替图像。图在文之前还是之后,是通过 p n e x t p_{next} pnext概率决定的,可以视作一种数据增强方式。
    2)文本图像对、视频文本对数据集。首先是使用了align的数据集(用替代文本来描述图像),并自行设计vtp(27million)和ltip(312million)数据集作为补充,并把数据集的语法向M3W对齐。

Experiments

模型训练的目标和优化策略:训练目标是每个数据集的负对数可能性的权重和的期望,并在所有数据集上积累梯度而不是采用"round-robin"范式,因为有更好的训练效果。损失函数见下:
∑ m = 1 M λ m ∗ E ( x , y ) ∼ D m [ − ∑ L l = 1 l o g p ( y l ∣ y < l , x = l ) ] \sum^{M}_{m=1}\lambda_m*E_{(x,y)\sim{D_m}}[-\sum^{L}{l=1}logp(y_l|ym=1MλmE(x,y)Dm[Ll=1logp(yly<l,x=l)]
Dev data+Test data:用COCO\OKVQA\VQAv2\MSVDQA和VATEX作为dev set(也就是验证集)用于调整并得到最优参数的模型,为了防止用同样的数据集作为测试集有偏差,所以使用剩下的11个数据集进行性能验证。对于dev基准,使用了validation query、validation support、test qury、test support,而test只用了后面两种。

  1. 验证少样本下的性能表现
    • 验证模型的上下文学习能力:使用图文交叉数据或视频文本交叉数据作为support,添加一个query visual input作为prompt,support和query是默认随机顺序拼接的,但当support数量很多时就时直接检索方式构建prompt
    • 验证模型在开放式问题下的能力:使用束搜索进行解码,束的大小是3。
    • 验证模型在封闭式问题下的能力:用log-likelihood对每个可能选项评分并降序排列。
    • 验证零样本泛化能力:只用2个只有文本的例子作为prompt,例如:“ < B O S > <BOS>Output: This is a cat wearing sunglasses.Output: Three elephants walking in the savanna. < E O C > <EOC> < i m a g e > <image>Output:”没有给模型输入support图片。如果只输入一个text的话,生成的输出会和输入很像,输入2个以上的话也只有非常有限的提升。
      结论:本模型相比与微调的SOTA在7个数据集上表现出了更好的能力,相比于few-shot/zero-shot的SOTA在14个任务上表现出了更好的性能,且展现出了模型的灵活性。
  2. 采用’pretrain-finetuning’范式验证模型的性能:比之前在few-shot场景下展示出的性能还好,且与现有的微调范式的SOTA相比达到了5个任务上的最优。
  3. 消融实验:验证混合训练数据的重要性,除去M3W数据集会导致性能大幅下降,表明数据多样性很重要。验证视觉条件对frozen LM的重要性,不用visual会导致性能和训练稳定性受损;验证了穿插XATTN的数量和使用不同resampler的情况;验证了使用强大vision backbone的重要性;冻结LM能够组织灾难性遗忘。

Limitations

  1. 直接继承了LM的缺点:LM的先验知识虽然通常是有用的,但会导致幻觉和没有根据的猜测。难以泛化到比训练序列更长的序列上。受困于低质量样本的训练效率。
  2. 分类能力是比现在SOTA的对比策略的模型差的(CLIP等),这可能是由于所对比的模型都以图文检索为目标进行训练的,而分类可以看作检索的一种特例。之后的研究可以探索嫁接这两类不同的模型的技术。
  3. 上下文学习的方法虽然比基于梯度积累的少样本学习方法好,但根据实际应用也会有缺陷。如推理计算损失更大、对prompt很敏感、当shots数量大于32时模型的性能不再有明显上升(可能是由于模型仅通过少样本就能定位到要探究的是什么任务),这种方法在样本数量增加时性能提升有限,可以研究不同少样本方法的互补。
  4. NLP接口在某些视觉任务上的局限性,无法预测结构化、密集型任务、连续空间预测,可以考虑增添额外模态。

你可能感兴趣的:(论文阅读,语言模型,人工智能)