SATP-GAN:基于自注意力的交通流预测生成对抗网络

998efe5f541331d234fb5e1233b8fd7e.png

文章信息

54046be963153b0b6745a950e06f5ec7.png

《SATP-GAN:self-attention based generative adversarial network for traffic flow prediction》是2021年发表在Transportmetrica B:Transport Dynamics上的一篇文章。

da108be80b36726de712d6c1beacb831.png

摘要

b2002704bb7a8c7fea29a4e760592cb6.png

交通流预测是交通控制和诱导系统中的基本问题之一,近年来随着人工智能的成功,交通流预测的新方法越来越受到重视。本文提出了一种新的时间序列预测模型,即自注意生成对抗网络(SATP-GAN)。SATP-GAN方法基于自我注意和生成对抗网络(GAN)机制,由GAN模块和强化学习(RL)模块组成。在GAN模块中,文章应用自我注意层来捕获时间序列数据的模式,而不是递归神经网络。在RL模块中,文章应用RL算法来调整SATP-GAN模型的参数。文章在真实世界的流量数据集上评估了该框架,并获得了比基线方法6.5%的一致改善。SATP-GAN框架证明了GAN机制在微调参数后也可用于时间序列预测。

7ef255879cb4c441a8202c9c48ff0045.png

介绍

99b9903cddcf3c369b810d4573fbfbf7.png

现有交通流量预测方法的缺点有以下几点:

(1)在现实世界中,预测交通流量的目标常常不能达到缓解交通拥堵的目标。

(2)行业的大规模场景需要多步交通流预测,然而现有的方法在多步交通流预测中表现不佳。

(3)现有方法中的级联依赖结构,特别是递归神经网络(RNNs ),具有大量的能量低效操作,如乘法、存储器访问和非线性变换,不能保证高计算速度和低功耗,这对于边缘和传感器环境是至关重要的。

(4)为了实现最优的交通流预测性能,需要调整超参数优化以适应统计特性。然而,现有的方法通常依赖于手动设置。

受深度学习这些突破的启发,为了解决现有交通流预测方法的局限性,文章提出了一个新的框架,即基于自注意力的GAN时间序列预测模型(SATP-GAN),利用自注意力、GAN和强化学习(RL)机制来预测交通流。以下是文章的贡献:

(1)文章提出了一个新的框架SATP-GAN,它利用自注意力、GAN和RL机制来预测流量。

(2)自注意力机制取代了RNN (LSTM和GRU)的时序性,是并行计算的时间序列预测任务的理想选择。

(3)GAN机制用于生成新的预测数据,采用RL来调整参数。

(4)多步交通流量预测可避免累积误差。时间步长的输出是根据整个历史计算的,而不仅仅是输入,它的当前隐藏状态可以学习长期相关性。

(5)实验表明,与基线相比,性能持续提高6.5-9.1%。

730507c37ed707e7b1783028b865bd2c.png

前期工作和相关工作

98bf17af767d1aa593cab75f09bc5321.png

在这一部分中,文章形式化了交通流预测问题,回顾了交通流预测的相关工作,并介绍了关于自注意力、GAN和RL机制的最新深度学习研究。

交通流预测问题

交通流量预测是在历史观测数据的驱动下,预测一个交叉口未来的车流量。将历史车流量观测值定义为Fhis= (F1,...,Ft),目标是预测多步未来车流量:

56a7e72553381455b7b152ddf6883b60.png

其中预测函数由时间序列预测模型学习,k是多步的数目。

传统的预测方法

传统的交通流量预测方法大多是线性模型。这些方法包括时间序列的自回归统计、支持向量回归(SVR)、浅层人工神经网络(FNNs)等。

ARIMA

ARIMA模型一般表示为ARIMA (p,d,q),其中p,d, q是由非负整数组成的参数,是自回归模型的时滞个数的集合;d表示差值,q是移动平均模型的阶数。ARIMA模型被公式化为:

4fdea4ec882366b00bd59c8dffc2849b.png

SVR

支持向量回归(SVR)是由Drucker等人提出的。该模型应用了与SVM相同的分类原则。在回归问题中,容限(ε)是近似值。主要思想是最小化误差,找到使裕度最大化的超平面,并确保部分误差是容许的。定义一个ε(如图1),虚线定义的区域内数据点的残差为0;虚线区域外的数据点(支持向量)到虚线边界的距离就是残差xi和xi* . SVR就是找到一个最优的条纹区域(2ε宽度),然后在区域外的点上回归。

SATP-GAN:基于自注意力的交通流预测生成对抗网络_第1张图片

SVR的公式为:

SATP-GAN:基于自注意力的交通流预测生成对抗网络_第2张图片

FNN

前馈神经网络(FNN)是一种人工神经网络,其中节点之间的连接不形成循环,这是深度学习的主要基础,并由Rumelhart,Hinton和Williams (1987)提出。它被广泛用于预测。FNN由非线性层组成,以西格玛函数(σ函数)作为激活函数。通常它有三层,包括输入层、隐藏层和输出层。每个内层都与前一层完全连接。

54f5a0d24391645010ee2dbe2ebef990.png

近期的方法

最近的交通流量预测方法是基于机器学习和深度学习的方法,这些方法日益受到重视。研究人员使用XGBoost、深度递归神经网络(RNNs)和卷积神经网络(CNN)研究时间序列预测,这些方法已应用于交通预测。

XGBoost

XGBoost(极限梯度提升)是一种集成的机器学习算法,基于Chen和Guestrin (2016)提出的决策树。XGBoost使用梯度增强作为框架,属于GBDT的一个变体。该算法与传统GBDT的区别在于:GBDT使用泰勒一阶展开式进行变换,而XGBoost使用泰勒二阶展开式。而且XGBoost还引入了一个正则项,更容易过拟合。

定义n个样本的m个特征的数据集D = xi,yi,k集后的预测数据为:

3be4c1aa828e999d3c01120c7e5dd58f.png

在这个公式中,F是返回树空间F = {f(x) = wq(x)}(q : Rm→ T,w∈rτ);w是返回树的叶子的重量;qi是回归树的结构,将每个样本节点映射到相应的叶节点索引;t是叶子的数量。为了学习模型的参数,目标函数需要最小。

SATP-GAN:基于自注意力的交通流预测生成对抗网络_第3张图片

其中L是损失函数;Ω是对模型复杂性的惩罚。XGboost是一种向前逐步算法,在时间t增加新的模型,需要最小化以下目标函数。

9a1602bbf6c07ff9f5dfc7a652ddebdb.png

该方程使用二阶泰勒展开式进行优化:

SATP-GAN:基于自注意力的交通流预测生成对抗网络_第4张图片

Gated RNN

传统的RNN有两个缺点:梯度爆炸和消失梯度。为了克服这些缺点,Hochreiter等人提出了具有称为长短期记忆的门模块的RNN模型(LSTM,图2显示了结构)。钟等简化了LSTM的计算,并提出了GRU模型。GRU的可训练参数要少得多。然而,LSTMs使用得更多。

SATP-GAN:基于自注意力的交通流预测生成对抗网络_第5张图片

一般情况下,LSTM不会单独用于时间序列预测,以提高预测的准确性。受机器翻译成功的启发,自然语言处理(NLP)中的Seq2Seq模型显示出了巨大的潜力。更具体地说,标准Seq2Seq模型由两个关键组件组成,一个是编码器,它将源输入x映射到矢量表示,另一个是解码器,它根据源和编码器生成输出序列。编码器和解码器都是LSTMs,用于捕获不同的组合模式。

GAN

生成对抗网络(GAN)是一个框架,用于以对抗方式估计分布。它同时训练两个模型,即生成模型生成器(G)和判别模型判别器(D)。训练判别器以最大化为来自训练数据和生成数据的样本分配适当标签的概率。用于时间序列预测问题的遗传神经网络应用不多。

(1)生成器是一种神经网络:PG(x;θ).生成器旨在找到PG的参数集θ,使真实数据和生成数据尽可能接近。

(2) 判别器(D)是鉴别哪些数据是PG(x;θ)和来自真实数据的数据。判别器的目标是能够区分真实数据和生成数据,成为一名在工作中不断学习的质量检查员。

8e736738809e2ae7a51422a67cf58ad3.png

在上面的公式中,假设G是固定的。

SATP-GAN:基于自注意力的交通流预测生成对抗网络_第6张图片

要找到最佳判别器D,请尝试最大化Pdata(x)logD(x)+PG(x)log(1D(x))。

然而,GAN通常用于生成图像或句子,很少用于时间序列预测。

a06bdf0c9caeb642c5b39fec45912f77.png

方法:STAP-GAN

6b56a36e55a8d439afe9e676fd603484.png

在本节中,文章将详细介绍基于自注意力的生成式对抗网络交通流预测框架(SATP-GAN ),其总体架构如图3所示。SATP-GAN包括两个主要组件:GAN模块(自注意力层用于生成器,卷积神经网络(CNN)用于判别器),以及RL模块,其被设计用于调整GAN的超参数。

SATP-GAN:基于自注意力的交通流预测生成对抗网络_第7张图片

以下是SATP-GAN工作原理的概述:

(1) 将历史交通流量同时输入GAN模块。

(2) 基于自注意机制和CNN训练GAN模块以获得预测能力。

(3) 通过RL模块调整GAN参数,使GAN模块稳定。

GAN模块

GAN模块由两部分组成:发生器和判别器。在生成器中,文章应用自注意力层来捕获时间序列数据的模式,而不是RNNs。在判别器中,文章使用CNN来检测预测数据(生成的)和真实数据的特征。

生成器

生成器根据输入的历史数据生成预测的时间序列数据。如图4所示:将历史流数据及其位置放入生成器。此外将预测步长k定义为生成器的约束条件。

SATP-GAN:基于自注意力的交通流预测生成对抗网络_第8张图片

一般来说,生成器的输入是:历史交通流量(流量1,流量2,...流量t);预测大小k;正态分布z的数据。

自注意力层捕获输入数据模式,而不是RNNs。此外,它可以通过并行计算来加速。经过自注意力层,前馈层最终生成预测结果。

(1)自我注意层注意机制最近被广泛应用于深度学习的各个领域。从概念上讲,注意力模仿人的感知方式,有选择地过滤掉少量重要信息并聚焦于此,忽略大部分不太重要的信息。注意力信息选择过程的核心在于信息权重系数的计算。自我注意机制是计算自己的注意权重。首先,从交通流数据和数据位置计算输入嵌入,以准备查询、密钥、值嵌入。

SATP-GAN:基于自注意力的交通流预测生成对抗网络_第9张图片

自我注意机制分为三个步骤(如图5所示):

SATP-GAN:基于自注意力的交通流预测生成对抗网络_第10张图片

  • 根据查询和关键字计算两者之间的相似性或相关性;

332709a8b1782ddaf6377b3bf63f9dfe.png

  • 使用softmax对步骤1中计算的结果进行归一化,以获得权重系数;

e362b23a4b10519d6affafbaecb2aa52.png

  • 加权系数是值的加权和。

e5ebdcb6607698f5faea278475b39eb7.png

(2)前馈神经网络

引入前馈神经网络用于从自注意力层产生预测数据(如图6所示)。结果,输入就是自注意力层的输出。输出为k步时间序列预测数据。

SATP-GAN:基于自注意力的交通流预测生成对抗网络_第11张图片

(3)超参数

生成器的超参数由自注意力层nself_attention的数字层组成;前馈神经网络的层数nFFNN

判别器

判别器输入分别是真实数据和预测数据,输出是分类结果,如预测数据是否真实,以及生成器和判别器的置信度值。此外,判别器可以评估生成器和判别器本身。判别器由CNN和神经网络组成,如图7所示。

SATP-GAN:基于自注意力的交通流预测生成对抗网络_第12张图片

以下是判别器的组件:

(1)CNN

CNN可以检测可用于提取时间序列数据中模式信息的特征,并且它们在彼此更接近或彼此更相关的数据点上工作良好。

(2)神经网络

在CNN输出的基础上,文章为不同的目的设计了三种不同的神经网络。第一个神经网络用于检测数据是否来自真实数据;第二个神经网络用于评估判别器;第三个神经网络是评估生成器。

(3)超参数

判别器的超参数是CNN层ncnn的层数和CNN结构nfilter的滤波器数;三个神经网络(ncnn,ndis,ngen)的层数。

RL模块

强化学习是通过构造代理来学习。智能体的主要特征包括与环境交互的能力(根据回合插曲,一个完整的与环境的交互过程就是一个插曲),用于计算分析和决策。代理是一个实体,它可以根据对其环境的观察(感知)采取行动。RL的特点:没有监督数据,只有奖励信号。

在SATP-GAN中,文章应用RL模块来调整生成器和判别器的超参数。如图8所示,RL模块分为两部分,分别使用Q学习算法来调整超参数。Q(s,a)表示行动a中状态s的行动价值函数(在状态s中采取行动a获得的未来回报)。求Q(s,a)函数,寻找某一状态下的最佳行动,使最终获得的累积报酬最大,然后用贝尔曼方程求解。

SATP-GAN:基于自注意力的交通流预测生成对抗网络_第13张图片

Q学习公式如下:

3c4536c54db17c7317a9174a1e87fcc7.png

其中动作a表示上述超参数。奖励r来自判别器。

算法

STAP-GAN算法是:

STAP-GAN算法不仅训练基于自注意的GAN模型,而且使用RL方法来优化生成器(G)和判别器(D)的超参数。生成器(G)和判别器(D)的超参数影响交通流预测的性能,并且需要在面对不同数据集时进行调整。此外,文章使用一个强化学习器来调整超参数。输入交通流数据,初始化生成器(G)和判别器(D)的超参数,同时训练GAN和RL模型。固定生成器(G)的超参数,训练GAN模型,并使用RL模型来控制判别器(D)的超参数,直到RL模型找到最佳超参数。固定判别器(D)的超参数,训练GAN模型,并使用RL模型来控制生成器(G)的超参数,直到RL模型找到最佳超参数。重复这两个步骤,G和D都将在RL模型的帮助下找到它们的最佳超参数。

c4a3863b0a8690382d2202e13c382f1e.png

实验

c570ef57e2762eceb621c9edcd8db5cf.png

在这一部分中,文章报告了通过在交通流数据集上应用SATP-GAN和其他时间序列预测模型进行的实验,以及收集的模型性能的结果。

文章收集了中国某城市一个十字路口的车流量,包含了一个月内每小时的车辆数。交通流分为四个不同的方向:西、东、北、南。4个方向的小时交通流量,共一个月(30天),2880个数据。

7755b4d414986de003a02be5304274f8.png

ATTENTION

a95e10f41d64be2abef8edac5ee2be25.png

如果你和我一样是轨道交通、道路交通、城市规划相关领域的,可以加微信:Dr_JinleiZhang,备注“进群”,加入交通大数据交流群!希望我们共同进步!

你可能感兴趣的:(神经网络,大数据,算法,python,计算机视觉)