【Mo 人工智能技术博客】时间序列预测——DA-RNN模型_da rnn-程序员宅基地

技术标签: python  rnn  机器学习  深度学习  人工智能  

时间序列预测——DA-RNN模型

作者:梅昊铭

1. 背景介绍

传统的用于时间序列预测的非线性自回归模型(NRAX)很难捕捉到一段较长的时间内的数据间的时间相关性并选择相应的驱动数据来进行预测。本文将介绍一种基于 Seq2Seq 模型(Encoder-Decoder 模型)并结合 Attention 机制的时间序列预测方法。作者提出了一种双阶段的注意力机制循环神经网络模型(DA-RNN),能够很好的解决上述两个问题。

模型的第一部分,我们引入输入注意力机制在每个时间步选择相应的输入特征。模型的第二部分,我们使用时间注意力机制在整个时间步长中选择相应的隐藏层状态。通过这种双阶段注意力机制,我们能够有效地解决一些时序预测方面的问题。我们将这两个注意力机制模型集成在基于 LSTM 的循环神经网络中,使用标准反向传播进行联合训练。

2. DA-RNN 模型

2.1 输入与输出

输入:给定 n 个驱动序列(输入特征), X = ( x 1 , x 2 , . . . , x n ) T = ( x 1 , x 2 , . . . , x T ) ∈ R n × T X = (x^1,x^2,...,x^n)^T = (x_1,x_2,...,x_T) \in R^{n \times T} X=x1,x2,...,xnT=(x1,x2,...,xT)Rn×T T T T 表示时间步长, n n n 表示输入特征的维度。

输出: y ^ T = F ( y 1 , . . . , y T − 1 , x 1 , . . . , x T ) \hat{y}_{T}= F(y_1,...,y_{T-1},x_1,...,x_T) y^T=F(y1,...,yT1,x1,...,xT) ( y 1 , . . . , y T − 1 ) (y_1,...,y_{T-1}) (y1,...,yT1)表示预测目标过去的值,其中 y t ∈ R y_t\in R ytR ( x 1 , . . . , x T ) (x_1,...,x_T) (x1,...,xT) 为时间 T T T n n n 维的外源驱动输入序列, x t ∈ R n x_t \in R^n xtRn F ( ⋅ ) F(\cdot) F() 为模型需要学习的非线性映射函数。

2.2 模型结构

DA-RNN 模型是一种基于注意力机制的 Encoder-Decoder 模型。在编码器部分,我们引入了输入注意力机制来选择相应的驱动序列;在解码器部分,我们使用时间注意力机制来选择整个儿时间步长中相应的隐藏层状态。通过这个两种注意力机制,DA-RNN 模型能够选择最相关的输入特征,并且捕捉到较长时间内的时间序列之间的依赖关系,如图1所示。


图 1:DA-RNN 模型结构

读者福利:知道你对人工智能、Python 感兴趣,小Mo 便精心准备了这门适合零基础小白学习的《人工智能导论》9.9元 浙大教授吴超老师带你进入AI大门!

2.3 编码器

编码器本质上是一个 RNN 模型,它能够将输入序列转换为一种特征表示,我们称之为隐藏层状态。对于时间序列预测问题,给定输入 X = ( x 1 , x 2 , . . . , x T ) ∈ R n × T , x t ∈ R n X = (x_1,x_2,...,x_T) \in R^{n \times T},x_t \in R^n X=(x1,x2,...,xT)Rn×T,xtRn,在时刻 t t t ,编码器将 x t x_t xt 映射为 h t h_t ht h t = f 1 ( h t − 1 , x t ) h_t = f_1(h_{t-1},x_t) ht=f1(ht1,xt) h t ∈ R m h_t \in R^m htRm 表示编码器隐藏层在时刻 t t t 的状态, m m m 表示隐藏层的维度,KaTeX parse error: Expected group after '_' at position 2: f_̲ 为非线性激活函数,本文中我们使用 LSTM。

本文中,我们提出了一种输入注意力机制编码器。它能够适当地选择相应的驱动序列,这对时间序列预测是至关重要的。我们通过确定性注意力模型来构建一个输入注意力层。它需要将之前的隐藏层状态 h t − 1 h_{t-1} ht1 和** LSTM** 单元的** cell **状态 s t − 1 s_{t-1} st1 作为该层的输入得到:
e t k = v e T t a n h ( W e [ h t − 1 ; s t − 1 ] + U e x k ) e^k_t = v^T_etanh(W_e[h_{t-1};s_{t-1}]+U_ex^k) etk=veTtanh(We[ht1;st1]+Uexk),其中 v e ∈ R T , W e ∈ R T × 2 m , U e ∈ R T × T v_e \in R^T,W_e \in R^{T \times 2m},U_e \in R^{T \times T} veRT,WeRT×2m,UeRT×T是需要学习的参数。
输入注意力层的输出 ( e t 1 , e t 2 , . . . , e t n ) (e^1_t,e^2_t,...,e^n_t) (et1,et2,...,etn) 输入到 softmax 层得到 α t k \alpha_t^k αtk 以确保所有的注意力权重的和为1, α t k \alpha_t^k αtk 表示在时刻 t t t k k k 个输入特征的重要性。

得到注意权重后,我们可以自适应的提取驱动序列 x ~ t = ( α t 1 x t 1 , α t 2 x t 2 , . . . , α t n x t n ) \tilde x_t = (\alpha^1_tx^1_t,\alpha^2_tx^2_t,...,\alpha^n_tx^n_t) x~t=(αt1xt1,αt2xt2,...,αtnxtn),此时我们更新隐藏层的状态为 h t = f 1 ( h t − 1 , x ~ t ) h_t = f_1(h_{t-1},\tilde x_t) ht=f1(ht1,x~t)

2.4 解码器

为了预测输出 y ^ T \hat y_T y^T,我们使用另外一个 LSTM 网络层来解码编码器的信息,即 隐藏层状态 KaTeX parse error: Expected group after '_' at position 2: h_̲。当输入序列过长时,传统的Encoder-Decoder 模型效果会急速恶化。因此,在解码器部分,我们引入了时间注意力机制来选择相应的隐藏层状态。

与编码器中注意力层类似,解码器的注意力层也需要将之前的隐藏层状态 d t − 1 d_{t-1} dt1LSTM 单元的cell状态 s t − 1 ′ s'_{t-1} st1 作为该层的输入得到该层的输出:
l t i = v d T t a n h ( W d [ d t − 1 ; s t − 1 ′ ] + U d h i ) l^i_t = v^T_dtanh(W_d[d_{t-1};s'_{t-1}]+U_dh_i) lti=vdTtanh(Wd[dt1;st1]+Udhi),其中 v d ∈ R m , W d ∈ R m × 2 p , U e ∈ R m × m v_d \in R^m,W_d \in R^{m \times 2p},U_e \in R^{m \times m} vdRm,WdRm×2p,UeRm×m是需要学习的参数。通过 softmax 层,我们可以得到第 i i i 个编码器隐藏状态 h i h_i hi 对于最终预测的重要性 β t i \beta^i_t βti。解码器将所有的编码器隐藏状态按照权重求和得到文本向量 c t = ∑ i = 1 T β t i h i c_t = \sum_{i=1}^T \beta_t^ih_i ct=i=1Tβtihi,注意 c t c_t ct 在不同的时间步是不同的。

在得到文本向量之后,我们将其和目标序列结合起来得到 y ~ t − 1 = w ~ T [ y t − 1 ; c t − 1 ] + b ~ \tilde y_{t-1} = \tilde w^T[y_{t-1};c_{t-1}]+\tilde b y~t1=w~T[yt1;ct1]+b~。利用新计算得到的 y ~ t − 1 \tilde y_{t-1} y~t1,我们来更新解码器隐藏状态 d t = f 2 ( d t − 1 , y ~ t − 1 ) d_t=f_2(d_{t-1},\tilde y_{t-1}) dt=f2(dt1,y~t1),我们使用 LSTM 来作为激活函数 f 2 f_2 f2
通过 DA-RNN 模型,我们预测 y ^ T = F ( y 1 , . . . , y T − 1 , x 1 , . . . , x T ) = v y T ( W y [ d T ; c T ] + b w ) + b v \hat y_T = F(y_1,...,y_{T-1},x_1,...,x_T) = v_y^T(W_y[d_T;c_T]+b_w)+b_v y^T=F(y1,...,yT1,x1,...,xT)=vyT(Wy[dT;cT]+bw)+bv

2.5 训练过程

在该模型中,作者使用平均方差作为目标函数,利用 Adam 优化器,min-batch 为128来进行参数优化。
目标函数:
O ( y T , y ~ T ) = 1 N ∑ i = 1 N ( y ^ T i − y T i ) 2 O(y_T,\tilde y_T)=\frac{1}{N}\sum_{i=1}^N(\hat y^i_T-y_T^i)^2 OyT,y~T=N1i=1N(y^TiyTi)2

3. 实验

3.1 数据集

本文的作者采用了,两种不同的数据集来测试验证 DA-RNN 模型的效果。这里我们仅对 NASDAQ 100 Stock 数据集进行介绍。作者根据 NASDAQ 100 Stock 收集了 81 家主要公司的股票价格作为驱动时间序列,NASDAQ 100 的股票指数做目标序列。数据收集的频率为一分钟一次。该数据集包含了从2016年7月26日至2016年12月22日总共105天的数据。在本实验中,作者使用 35100 条数据作为训练集,2730条数据作为验证集,以及最后2730条数据作为测试集。

3.2 参数设置和评价指标

时间窗口的大小 T ∈ { 3 , 5 , 10 , 15 , 25 } T \in \{3,5,10,15,25\} T{ 3,5,10,15,25}。实验表明 :T=10 时,模型在验证集上的效果最好。编码器和解码器隐藏层的大小 m , p ∈ { 16 , 32 , 64 , 128 , 256 } m ,p\in\{16,32,64,128,256\} m,p{ 16,32,64,128,256}。当 m = p = 64 , 128 m=p=64,128 m=p=64,128 时,实验效果最好。

为评估模型的效果,我们考虑了三种不同的评价指标:RSME,MAE,MAPE。

3.3 模型预测

为展示 DA-RNN 模型的效果,作者将该模型和其他的模型在两个不同的数据集上的预测效果进行了对比,如表1所示。由表1可以看出,DA-RNN模型相对于其他模型,误差更小一些。DA-RNN模型在时间序列预测方面具有良好的表现。

表 1:SML 2010数据集和纳斯达克100股票数据集的时间序列预测结果

为了更好的视觉比较,我们将Encoder-Decoder 模型,Attention RNN 和 DA-RNN 模型的在纳斯达克100股票数据集上的预测结果在图2中展示出来。我们不难看出DA-RNN模型能更好地反映真实情况。

图 3:三种模型在纳斯达克100股票数据集上的预测结果

4. 总结

在本文中,我们介绍了一种基于注意力机制的双阶段循环神经网络模型。该模型由两部分组成:Encoder 和 Decoder。在编码器部分,我们引入了输入注意力机制来对输入特征进行特征提取,为相关性较高的特征变量赋予更高的权重;在解码器部分,我们通过时间注意力机制为不同时间 t t t 的隐藏状态赋予不同的权重,不断地更新文本向量,来找出时间相关性最大的隐藏层状态。Encoder 和 Decode 中的注意力层分别从空间和时间上来寻找特征表示和目标序列之间的相关性,为不同的特征变量赋予不同的权重,以此来更准确地预测目标序列。
项目源码地址:https://momodel.cn/workspace/5da8cc2ccfbef78329c117ed?type=app

5. 参考资料

  1. 论文:A Dual-Stage Attention-Based Recurrent Neural Network for Time Series Prediction
  2. 注意力机制详解:https://blog.csdn.net/BVL10101111/article/details/78470716
  3. 项目源码:https://github.com/chensvm/A-Dual-Stage-Attention-Based-Recurrent-Neural-Network-for-Time-Series-Prediction
  4. 数据集:https://cseweb.ucsd.edu/~yaq007/NASDAQ100_stock_data.html

欢迎关注我们的微信公众号:MomodelAI

同时,欢迎使用 「Mo AI编程」 微信小程序

以及登录官网,了解更多信息:Mo 平台

Mo,发现意外,创造可能

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/weixin_44015907/article/details/102860962

智能推荐

攻防世界_难度8_happy_puzzle_攻防世界困难模式攻略图文-程序员宅基地

文章浏览阅读645次。这个肯定是末尾的IDAT了,因为IDAT必须要满了才会开始一下个IDAT,这个明显就是末尾的IDAT了。,对应下面的create_head()代码。,对应下面的create_tail()代码。不要考虑爆破,我已经试了一下,太多情况了。题目来源:UNCTF。_攻防世界困难模式攻略图文

达梦数据库的导出(备份)、导入_达梦数据库导入导出-程序员宅基地

文章浏览阅读2.9k次,点赞3次,收藏10次。偶尔会用到,记录、分享。1. 数据库导出1.1 切换到dmdba用户su - dmdba1.2 进入达梦数据库安装路径的bin目录,执行导库操作  导出语句:./dexp cwy_init/[email protected]:5236 file=cwy_init.dmp log=cwy_init_exp.log 注释:   cwy_init/init_123..._达梦数据库导入导出

js引入kindeditor富文本编辑器的使用_kindeditor.js-程序员宅基地

文章浏览阅读1.9k次。1. 在官网上下载KindEditor文件,可以删掉不需要要到的jsp,asp,asp.net和php文件夹。接着把文件夹放到项目文件目录下。2. 修改html文件,在页面引入js文件:<script type="text/javascript" src="./kindeditor/kindeditor-all.js"></script><script type="text/javascript" src="./kindeditor/lang/zh-CN.js"_kindeditor.js

STM32学习过程记录11——基于STM32G431CBU6硬件SPI+DMA的高效WS2812B控制方法-程序员宅基地

文章浏览阅读2.3k次,点赞6次,收藏14次。SPI的详情简介不必赘述。假设我们通过SPI发送0xAA,我们的数据线就会变为10101010,通过修改不同的内容,即可修改SPI中0和1的持续时间。比如0xF0即为前半周期为高电平,后半周期为低电平的状态。在SPI的通信模式中,CPHA配置会影响该实验,下图展示了不同采样位置的SPI时序图[1]。CPOL = 0,CPHA = 1:CLK空闲状态 = 低电平,数据在下降沿采样,并在上升沿移出CPOL = 0,CPHA = 0:CLK空闲状态 = 低电平,数据在上升沿采样,并在下降沿移出。_stm32g431cbu6

计算机网络-数据链路层_接收方收到链路层数据后,使用crc检验后,余数为0,说明链路层的传输时可靠传输-程序员宅基地

文章浏览阅读1.2k次,点赞2次,收藏8次。数据链路层习题自测问题1.数据链路(即逻辑链路)与链路(即物理链路)有何区别?“电路接通了”与”数据链路接通了”的区别何在?2.数据链路层中的链路控制包括哪些功能?试讨论数据链路层做成可靠的链路层有哪些优点和缺点。3.网络适配器的作用是什么?网络适配器工作在哪一层?4.数据链路层的三个基本问题(帧定界、透明传输和差错检测)为什么都必须加以解决?5.如果在数据链路层不进行帧定界,会发生什么问题?6.PPP协议的主要特点是什么?为什么PPP不使用帧的编号?PPP适用于什么情况?为什么PPP协议不_接收方收到链路层数据后,使用crc检验后,余数为0,说明链路层的传输时可靠传输

软件测试工程师移民加拿大_无证移民,未受过软件工程师的教育(第1部分)-程序员宅基地

文章浏览阅读587次。软件测试工程师移民加拿大 无证移民,未受过软件工程师的教育(第1部分) (Undocumented Immigrant With No Education to Software Engineer(Part 1))Before I start, I want you to please bear with me on the way I write, I have very little gen...

随便推点

Thinkpad X250 secure boot failed 启动失败问题解决_安装完系统提示secureboot failure-程序员宅基地

文章浏览阅读304次。Thinkpad X250笔记本电脑,装的是FreeBSD,进入BIOS修改虚拟化配置(其后可能是误设置了安全开机),保存退出后系统无法启动,显示:secure boot failed ,把自己惊出一身冷汗,因为这台笔记本刚好还没开始做备份.....根据错误提示,到bios里面去找相关配置,在Security里面找到了Secure Boot选项,发现果然被设置为Enabled,将其修改为Disabled ,再开机,终于正常启动了。_安装完系统提示secureboot failure

C++如何做字符串分割(5种方法)_c++ 字符串分割-程序员宅基地

文章浏览阅读10w+次,点赞93次,收藏352次。1、用strtok函数进行字符串分割原型: char *strtok(char *str, const char *delim);功能:分解字符串为一组字符串。参数说明:str为要分解的字符串,delim为分隔符字符串。返回值:从str开头开始的一个个被分割的串。当没有被分割的串时则返回NULL。其它:strtok函数线程不安全,可以使用strtok_r替代。示例://借助strtok实现split#include <string.h>#include <stdio.h&_c++ 字符串分割

2013第四届蓝桥杯 C/C++本科A组 真题答案解析_2013年第四届c a组蓝桥杯省赛真题解答-程序员宅基地

文章浏览阅读2.3k次。1 .高斯日记 大数学家高斯有个好习惯:无论如何都要记日记。他的日记有个与众不同的地方,他从不注明年月日,而是用一个整数代替,比如:4210后来人们知道,那个整数就是日期,它表示那一天是高斯出生后的第几天。这或许也是个好习惯,它时时刻刻提醒着主人:日子又过去一天,还有多少时光可以用于浪费呢?高斯出生于:1777年4月30日。在高斯发现的一个重要定理的日记_2013年第四届c a组蓝桥杯省赛真题解答

基于供需算法优化的核极限学习机(KELM)分类算法-程序员宅基地

文章浏览阅读851次,点赞17次,收藏22次。摘要:本文利用供需算法对核极限学习机(KELM)进行优化,并用于分类。

metasploitable2渗透测试_metasploitable2怎么进入-程序员宅基地

文章浏览阅读1.1k次。一、系统弱密码登录1、在kali上执行命令行telnet 192.168.26.1292、Login和password都输入msfadmin3、登录成功,进入系统4、测试如下:二、MySQL弱密码登录:1、在kali上执行mysql –h 192.168.26.129 –u root2、登录成功,进入MySQL系统3、测试效果:三、PostgreSQL弱密码登录1、在Kali上执行psql -h 192.168.26.129 –U post..._metasploitable2怎么进入

Python学习之路:从入门到精通的指南_python人工智能开发从入门到精通pdf-程序员宅基地

文章浏览阅读257次。本文将为初学者提供Python学习的详细指南,从Python的历史、基础语法和数据类型到面向对象编程、模块和库的使用。通过本文,您将能够掌握Python编程的核心概念,为今后的编程学习和实践打下坚实基础。_python人工智能开发从入门到精通pdf

推荐文章

热门文章

相关标签