论文笔记《Self-Attention ConvLSTM for Spatiotemporal Prediction》
发布时间
阅读量:
阅读量
目录
-
1. Abstract
-
2. Introduction
-
3. Method
-
- 3.1 模型整体结构
- 3.2 SAM模块
-
- 3.2.1 Feature Aggregation 特征聚合
- 3.2.2 Memory Updating 记忆更新
- 3.2.3 Output 输出
-
4. Experiments
-
- 4.1 Implementation
- 4.2 Ablation Study
1. Abstract
- 为提取空间特征的全局和局部依赖性,本文向ConvLSTM引入了一个新的自注意力机制(self-attention mechanism)
- 子注意力记忆模块(self-attention memory , SAM) 能在时空域记住那些具有长期依赖性 的特征
2. Introduction
本文的创新点/贡献在于:
- 提出一个新的基于ConvLSTM的变体模型用于时空预测,命名为SA-ConvLSTM,特点是能很好捕获长程空间依赖性 。
- 设计了一个基于记忆的自相关模块(memory-based self-attention module, SAM),该模块用于在预测中记住全局的时空依赖性。
- 为验证模型(1)使用MovingMNIST 和 KTH 数据集进行多框架预测;(2)使用TaxiBJ数。据集预测交通流量。本文模型优势是参数更少、效率更高。
写作思路:
- 首段:(1)交待时空预测研究的重要性、现有研究(很简要),说明值得研究;(2)时空预测具有复杂动态性,时空领域都表现出依赖性。
- 第二段:(1)ConvLSTM 效果不错;(2)存在问题1-长程依赖可以通过堆叠的卷积层捕获,但有效感受野要比理论上的感受野小很多;(3)存在问题2-离特殊位置较远的特征,要体现位置的影响 实现前馈和反向传播,就要经过很多层,这样一来训练时的优化就很困难;(4)现有的解决办法只能提供稀疏的依赖关系,估计的是局部感受野;(5)因此现有问题就是如何让ConvLSTM捕获到长程依赖性。
- 第三段:(1)认为自注意力模块相对于卷积操作,更擅于获得全局空间上下文信息(注意:这里只是说普通的self-attention module),因此本文使用额外的记忆单元 \mathcal{M} ;(2)\mathcal{M} 也能像LSTM 通过门控机制捕获长程的时间依赖性。
3. Method
注:原文比较简略,下文按照自己的理解重新组织了顺序
3.1 模型整体结构
这篇文章创新点就是加了一个基于记忆的自相关模块(memory-based self-attention module, SAM),这个模块是接在ConvLSTM模型的最后的,如图浅绿色部分(如果没有它及其输出,这个图就是ConvLSTM模型图,或者说是LSTM模型图):

3.2 SAM模块

这个模块看上去好复杂,基于文章描述 它可以分为三个小部分,我在图上用不同色块标注出来(强迫症不允许色块对不齐,是不是很整齐hh):
- 黄色区域:特征聚合,文章中的Feature Aggregation 部分
- 蓝色区域:记忆更新,文章中的Memory Updating 部分
- 绿色区域:输出,文章中的Output 部分
3.2.1 Feature Aggregation 特征聚合

整个黄色区域可分为两部分:
- 上半部分(黄色):输入是当前时刻特征 \mathcal{H_t},经历一个普通的self-attention 模块,得到Z_h。
- 下半部分(灰色):输入是上一时刻记忆 \mathcal{M}_{t-1},也是经历一个self-attention 模块。不同的是,此处用的query Q 是当前时刻计算得到的,key K 是上一时刻 \mathcal{M}_{t-1} 计算得到的,通过 \mathbf{e}=\mathbf{Q}_{h}^{T} \mathbf{K}_{h} \in \mathbb{R}^{N \times N} 计算相似性得分,然后再经过 softmax 将得分映射至 (0,1) 区间。最后再将得分与上一时刻记忆 \mathcal{M}_{t-1} 的值相乘,得到Z_m。
- 通过通道相连将这两个输出拼一起,再乘权重,得到Z。
- Z 再与当前时刻特征 \mathcal{H_t} 拼接到一起,作为下一步骤的输入。
3.2.2 Memory Updating 记忆更新

感觉这部分记忆更新操作和GRU操作很像,具体操作如下,分两步走:
- 通过tanh 对输入数据处理,将其映射到[-1,1]。g_{t}^{\prime}=\tanh \left(W_{m ; z g} * \mathbf{Z}+W_{m ; h g} * \mathcal{H}_{t}+b_{m ; g}\right)
- 通过sigmoid 处理数据,将其映射到[0,1]上,形成gate。i_{t}^{\prime}=\sigma\left(W_{m ; z i} * \mathbf{Z}+W_{m ; h i} * \mathcal{H}_{t}+b_{m ; i}\right)
- 最后更新记忆信息,\mathcal{M}_{t}=\left(1-i_{t}^{\prime}\right) \circ \mathcal{M}_{t-1}+i_{t}^{\prime} \circ g_{t}^{\prime}
3.2.3 Output 输出

最后就是输出:
- 先门控处理 o_{t}^{\prime}=\sigma\left(W_{m ; z o} * \mathbf{Z}+W_{m ; h o} * \mathcal{H}_{t}+b_{m ; o}\right)
- 输出 \hat{\mathcal{H}}_{t}=o_{t}^{\prime} \circ \mathcal{M}_{t}
4. Experiments
4.1 Implementation
- 设计为一个4层的网络,每一层有64隐藏层
- ADAM optimizer
- 初始学习率为0.001
- 训练中的mini-batch=8,80000次迭代后收敛
- MovingMNIST 和 TaxiBJ 数据集使用L2 loss
- KTH数据集使用L1+L2 loss
- 指标:SSIM(structural similarity Index Measure);MSE;MAE
4.2 Ablation Study
- 标准的4层ConvLSTM
- 只有self-attention的ConvLSTM模型
- 只有additional memory cell \mathcal{M} 的ConvLSTM模型
- 没有 Z_m 的SA-ConvLSTM 模型
- 完整的SA-ConvLSTM 模型
在结果分析部分:
分析1:Ablation Study

分析2:不同模型之间的比较


分析3:MovingMNIST数据集的定性比较(用过去10帧预测未来10帧)

分析4: TaxiBJ数据集的定性比较(用过去4帧预测未来4帧,即两小时)

颜色越亮,绝对误差越高。虽然感觉不太方便对比,但是看上去很炫酷,有人知道这是用什么画出来的吗 🤔
最后还有两个图像识别任务的可视化结果,然后就是简短的总结,不翻译了~
全部评论 (0)
还没有任何评论哟~
