温馨提示:
本篇文章已同步至"AI专题精讲" MATC:通过数学推理和图表还原增强视觉语言预训练
视觉语言数据,如图表、图形和信息图,在人类世界中无处不在。然而,现有的最先进的视觉语言模型在这些数据上的表现并不理想。我们提出了MATCHA(数学推理与图表去渲染预训练),旨在增强视觉语言模型在联合建模图表/图形与语言数据方面的能力。具体而言,我们提出了几个预训练任务,涵盖了图形解构和数值推理,这些是视觉语言建模中的关键能力。我们从最近提出的图像到文本视觉语言模型Pix2Struct开始进行MATCHA预训练。在标准基准测试如PlotQA和ChartQA上,MATCHA模型比现有的最先进方法提高了近20%的性能。我们还考察了MATCHA预训练如何在截屏、教科书图表和文档图形等领域迁移,并观察到整体性能提升,验证了MATCHA预训练在更广泛的视觉语言任务中的实用性。
视觉语言是使用紧密集成的文本和视觉元素传达意义的系统(Horn, 1998)。它在人类世界中无处不在,典型的例子包括图表、图形和图解,存在于教科书、科学论文、网页等地方。视觉语言也是高度复杂的——除了文本,它的结构单元还包括线条、形状、颜色、方向、尺度、角度、空间等。需要从这些结构单元中识别模式,并执行空间分组和/或对齐,以提取用于推理的信息。尽管视觉语言非常普遍且重要,但在机器学习社区中,关于视觉语言理解的研究相对较少。预训练于自然图像或从网络爬取的图像-文本对的视觉语言模型在视觉语言任务(如ChartQA (Masry et al., 2022) 和PlotQA (Methani et al., 2020))上的表现较差,这是由于联合建模语言和符号的复杂性(实验中有更多证据)。Pix2Struct (Lee et al., 2023) 是一种最近提出的视觉定位语言预训练策略,显著优于标准的视觉语言模型,并且超越了广泛的基于OCR的管道方法。Pix2Struct设计了一种新颖的网页截图解析任务,并为图像到文本编码解码器Transformer(Vaswani et al., 2017)设计了可变分辨率输入表示。在这项工作中,我们使用Pix2Struct作为基础模型,并进一步通过图表去渲染和数学推理任务进行预训练。
我们认为,视觉语言理解需要两个关键要素:(1) 布局理解(包括数字提取及其组织),(2) 数学推理。(1) 需要发现图像的潜在模式并将图像中的元素以逻辑形式组织。(2) 需要对从(1)中提取的元素进行操作,并推导出任务或查询所要求的有意义信息。基于这些观察,我们提出了两个互补的预训练任务来增强视觉语言理解:图表去渲染和数学推理。在图表去渲染任务中,给定一个图表,图像到文本模型需要生成其底层数据表或用于渲染的代码。第二个任务是数学推理预训练。我们选择了两个数值推理数据集MATH (Saxton et al., 2019) 和DROP (Dua et al., 2019),将输入渲染为图像,图像到文本模型需要解码答案。
我们使用一系列视觉语言任务来测试我们方法的有效性。 最重要的是,我们在ChartQA和PlotQA上进行测试,这些是关于图表和图形的问答数据集。在这两个数据集上,MATCHA甚至超越了假设能够访问图表底层数据表的最先进模型,并且在没有黄金数据表的情况下,比之前的最先进方法高出多达20%。我们还在图表到文本的总结任务上测试MATCHA,观察到它在Pix2Struct上有明显的改进,并在Chart-to-Text(Kantharaj et al., 2022)Pew拆分上达到了最先进的性能。最后,但同样重要的是,为了检验MATCHA预训练是否能迁移到标准图表和图形领域之外的数据集,我们还在四个额外的领域上测试了MATCHA,这些领域是Pix2Struct曾经评估过的:文档、插图、用户界面和自然图像(包括数据集,如教科书问答、控件说明等)。我们展示了与基础模型Pix2Struct相比,大多数附加数据集的持续改进。
总结,我们的贡献包括:(1) 提出了一组有效的预训练任务,用于视觉语言学习;(2) 展示了在所有评估任务中的一致性改进,并在ChartQA、PlotQA和Chart-to-Text总结(Statista集)上取得了最先进的结果,且未访问黄金数据表;(3) 验证了MATCHA预训练能迁移到视觉语言基准测试,超越图表和图形领域,并在图表领域以外的广泛数据集(如教科书VQA和控件说明)上实现了最先进的性能;(4) 进行了全面的消融实验和分析,以理解每个预训练组件的影响及其对下游任务性能的影响。
视觉语言研究与对视觉语言的关注缺乏。视觉与语言的研究主要集中在自然图像上。视觉基础推理数据集如NLVR2(Suhr et al., 2019)和MaRVL(Liu et al., 2021)大多属于自然图像领域。合成数据集如SHAPES(Andreas et al., 2016)、NLVR(Suhr et al., 2017)和CLEVR(Johnson et al., 2017)可以看作是视觉语言领域。然而,它们的视觉语言系统明显比现实世界中的系统要简单,例如图表和图形。因此,从这些合成数据集中提取信息是直截了当的。此外,合成数据集中的查询通常比较简单,不需要复杂的推理(例如,问题通常关于空间关系或物体计数)。因此,当前的视觉语言模型能够很好地处理上述提到的合成视觉推理数据集。然而,它们在真实世界的视觉语言数据集上表现较差,这些数据集的信息提取和推理变得更加复杂(我们将在第4节中展示这一点)。
基于OCR和端到端方法的视觉定位语言。 LayoutLM(Xu et al., 2020;Huang et al., 2022)利用了patch-OCR对齐损失,将外部OCR系统的知识注入Transformer模型中。PresSTU(Kil et al., 2022)和PaLI(Chen et al., 2023)也设计了OCR感知的预训练目标,模型需要预测从现成的OCR系统中获取的文本。ChartBERT(Akhtar et al., 2023)依赖OCR文本和位置来训练Transformer编码器。虽然OCR系统可以帮助准确提取文本,但运行它们的成本不低。此外,OCR系统并不涵盖那些不显式使用文本的视觉语言系统。例如,图表和图形并不总是显式写出数字。在我们并行的工作DEPLOT(Liu et al., 2023)中,我们探索了将图表到文本的翻译模块(没有OCR)与大型语言模型结合的方法。
Donut(Kim et al., 2022)、Dessurt(Davis et al., 2023)和Pix2Struct(Lee et al., 2023) 是用于视觉语言的端到端预训练模型,其中Donut和Dessurt专注于文档理解,Pix2Struct旨在为所有视觉语言任务提供通用的预训练检查点。MATCHA的架构与Pix2Struct相同——我们在Pix2Struct检查点的基础上,继续通过新的目标进行预训练。
通过设计新颖的预训练任务来学习推理。 MATCHA与设计更好的预训练目标以帮助语言模型(LMs)进行更好推理的文献相关,因为通过简单的语言建模目标(例如,掩蔽语言建模和自回归语言建模)很难培养这种能力。Geva et al.(2020);Eisenschlos et al.(2020)通过人工编写的模板生成额外的专注于(数值)推理的预训练数据。Pi et al.(2022)合成数据和程序,然后使用程序执行器模拟答案。语言模型预训练为根据数据和程序预测答案。Wu et al.(2022)探索了一系列合成预训练任务,并发现即使是注入简单的归纳和推理规则,也能教会语言模型进行推理。我们通过将图表映射到数据和代码,并直接学习文本数学推理数据集,来教导图像到文本模型进行推理。
我们认为,布局理解和基本数学运算能力是进行视觉语言理解/推理的关键要素。我们通过提出两个预训练任务——图表还原(§3.1)和数学推理(§3.2),将这些能力注入到模型中,下面将详细介绍这两个任务。
图表通常是由一个底层的数据表和一段代码生成的。代码决定了图形的整体布局(例如,图表类型、方向、颜色/形状方案),而底层的数据表则决定了实际的数字及其分组。数据和代码都会发送给编译器/渲染引擎,以生成最终的图像。为了理解图表,一个人需要在图像中发现视觉模式,有效地解析并将其分组,从而提取关键信息。逆向图表渲染过程需要具备所有这些能力,因此它可以作为一个完美的预训练任务。
实际上,获得图表、其底层数据表和渲染代码是非常具有挑战性的。为了收集足够的预训练数据,我们独立地积累了(图表,代码)对和(图表,表格)对。对于(图表,代码)对,我们爬取了所有符合许可证要求的GitHub IPython笔记本,并提取其中包含图形的代码块。每个图形及其前面的代码块会被保存为一个(图表,代码)对。
对于(图表,表格)对,我们探索了两种来源。第一种是手动编写代码,将Herzig等人(2020)从网络爬取的Wikipedia表格转换为图表。我们随机组合了多种绘图选项。关键的随机变量包括:使用matplotlib或seaborn作为绘图包;使用条形图、折线图或饼图;图表的样式和颜色;是否在图表上显式显示数字;文本的字体和大小。除了我们自己的合成数据外,我们还加入了Methani等人(2020)生成的图表-表格对(来自PlotQA),以丰富预训练语料库。第二个来源是从网站爬取的图表-表格对。像Statista这样的网页提供了这两者。我们直接使用了Masry等人(2022)爬取的图表-表格对(来自ChartQA),这些对包含来自四个网站的约2万对:Statista、Pew、Our World in Data和OECD。
需要注意的是,为了避免泄露测试数据,PlotQA和ChartQA任务使用的相同图表数据作为预训练数据,我们仅使用训练集中的图表-表格对进行预训练,严格排除测试图表/表格。在消融实验(§5.1)中,我们将展示来自两种来源的图表-表格对都是有用的,并且拥有多样化的图表-表格对总是能带来更好的效果。然而,仅使用我们的合成数据已经带来了显著的改进,表明图表还原的概念可以轻松地迁移到其他领域的图表(包括真实世界中的图表)。
进行视觉语言推理需要:
图表还原解决了第一个问题,但当前的预训练框架仍缺乏第二个能力。因此,我们提出通过学习文本数学数据集,显式地将数值推理知识注入到图像到文本模型中。我们使用了两个现有的文本数学推理数据集进行预训练:MATH(Saxton等,2019)和DROP(Dua等,2019)。
MATH是一个合成数据集,包含每个问题模块(类型)有两百万个训练样本(详见附录§A,列出了MATHA预训练中包括的模块)。DROP是一个阅读理解类型的问答数据集,其中输入是段落上下文和问题。DROP包含96k个问题-答案对,涵盖6.7K个段落。要解决DROP中的问题,模型需要阅读段落,提取相关数字并进行数值计算以预测答案。我们发现这两个数据集互补地有帮助。MATH包含大量分类问题,帮助我们识别需要显式注入模型的数学运算,而DROP的阅读理解格式类似于典型的问答任务,模型需要同时执行信息提取和推理。
在实践中,我们将这两个数据集的输入渲染为图像(对于DROP,将上下文和问题合并)。图像到文本模型被训练为解码图像给出的答案。MATH和DROP的示例可以在图1(浅红色)中找到。
除了两个新提出的预训练策略外,为了防止灾难性遗忘,我们还继续应用Pix2Struct(Lee等,2023)中的截图解析预训练。具体来说,给定一个网站的截图(其中网站的部分被遮罩),图像到文本的变压器需要预测能够呈现原始未遮罩网站截图的简化HTML代码。最终的预训练任务是所有前述任务的混合,我们将在§4.1中讨论混合权重。
我们在§4.1中详细介绍实验设置,在§4.2中介绍主要结果,在§4.3中介绍额外的Pix2Struct任务的结果。
预训练数据集/任务
总体而言,我们创建了一个预训练任务的混合,包含40%的数学推理、40%的图表还原和20%的截图解析。每个特定任务/数据集的权重如表1所示。对于图表还原,我们有四个数据来源:
最初,我们为四个任务分配了相等的权重,但我们注意到由于图表到代码任务非常困难(预训练数据噪声较大),训练不稳定。因此,我们将图表到代码的权重降低至4%,并将所有图表到表格的任务增加至12%。对于数学推理,我们为MATH和DROP分配了相等的权重(各为20%)。
有关预训练数据集消融实验的更多信息,请参见§5.1。
评估数据集
我们在多模态英语问答和生成任务上评估了MATCHA,包括ChartQA(Masry等,2022),PlotQA(Methani等,2020),以及图表到文本总结(Kantharaj等,2022)。ChartQA和PlotQA都是图表领域的问答数据集,其中输入是图表的图像和一个查询,目标是一个答案字符串。ChartQA有两个子集:(1)增强集和(2)人工集,其中增强集是机器生成的,因此更具提取性,而人工集是人工编写的,需要更复杂的推理。PlotQA也有两个版本:v1和v2。类似地,v1更侧重于提取性问题,而v2则需要更多的数学推理。然而,v1和v2都是机器生成的。图表到文本总结数据集也有两个子集,它们分别是“Pew”和“Statista”,这些名称描述了图像示例的来源。对于Pew,金标准摘要是自动从图像周围的区域提取的;对于Statista,摘要是人工编写的。各数据集的大小如表2所示。除了图表领域的数据集外,我们还在Pix2Struct(Lee等,2023)中使用的其他数据集上进行了评估。我们按照Pix2Struct的确切设置和协议重新运行Pix2Struct实验,但用MATCHA替换了初始检查点。更多实验细节请参见Lee等(2023)。
评估指标
对于ChartQA和PlotQA,我们遵循之前的工作(Masry等,2022;Methani等,2020;Lee等,2023),使用放宽的正确性(完全匹配,但允许最多5%的数值误差)。对于Chart-to-Text,我们使用BLEU4。对于所有Pix2Struct实验,我们使用Lee等(2023)中引入的相同指标。
训练和推理细节
我们每200步保存一次检查点,并保留产生最高验证分数的检查点。遵循Lee等(2023),我们在ChartQA的增强集和人工集上联合微调模型(即,两个集的一个检查点),并使用在人工验证集上选择的检查点作为测试的最终检查点。对于PlotQA和Chart-to-Text,我们为v1、v2、Pew和Statista集训练独立的模型。对于预训练,我们使用批量大小512和最大序列长度192,预训练100k步,最终的MATCHA检查点在第90k步选择(此时的平均完全匹配验证分数最高)。对于下游任务微调,我们使用批量大小256和最大序列长度128。对于ChartQA和Chart-to-Text,我们微调10k步,PlotQA微调20k步(因为它显著更大)。Pix2Struct任务的设置与原论文相同。对于PaLI基线,我们使用较大的17B变体,微调5000步,并每1000步保存一次检查点。所有MATCHA和Pix2Struct模型都在64个GCP-TPUv3上进行预训练/微调,而PaLI模型在128个GCP-TPUv4上进行微调。
请注意,由于MATCHA是一个图像到文本的模型(没有文本输入分支),每当需要输入文本时,文本会被渲染为图像。例如,对于问答任务,我们将问题作为标题预置在图表上,并将带有问题标题的整个图像输入模型。
我们在表3中总结了主要结果,在三个图表/绘图领域基准任务(ChartQA、PlotQA和Chart-to-Text Summarization)上,将MATCHA与一系列基准模型和SOTA模型进行比较。
在ChartQA任务中,MATCHA比之前的SOTA模型(没有访问底层金标准数据表的Pix2Struct)提高了8.2%。即使我们考虑那些假设存在金标准数据表的模型,它们通常比MATCHA低3-5%。表现最好的基线模型VisionTaPas有一个专门的模块用于建模表格,但仍然落后MATCHA 2.4%。
在PlotQA任务中,MATCHA再次是整体表现最好的模型。在v1集上,能够访问底层数据表的VL-T5比MATCHA好大约4%,这也是直观的,因为PlotQA是一个合成数据集,包含相对简单的查询,v1是提取性的问题集,查询更加直接。
在v2集上,问题涉及到数值推理,MATCHA超越了所有模型,包括那些能够访问底层金标准数据表的模型。
在Chart-to-Text总结任务中,MATCHA在Pew和Statista上都优于Pix2Struct,并且成为Pew任务的新SOTA。然而,MATCHA在Statista任务上表现不如PaLI-17B(res. 588)。
温馨提示:
阅读全文请访问"AI深语解构" MATC:通过数学推理和图表还原增强视觉语言预训练