Advertisement

TFT:Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting

阅读量:

tft

  • 1 模型简述

    • 1.1 输入
    • 1.2 输出
  • 2 损失函数

  • 3 模型结构

    • 3.1 基本结构
    • 3.2 整体结构
  • 4 模型参数说明

    • 4.1 所有参数
    • 4.2 列定义参数
    • 4.3 处理后输入
  • 5 TFT贡献

  • 6 其他总结

  • 参考资料

1 模型简述

tft模型具有下面特征:

  • 能够处理多条时间序列数据
  • 采用注意力机制构建的网络架构
  • 具备可解释性设计
  • 实现特征选择功能的同时,并通过门控机制实现特征压缩
  • 计算效率较高

1.1 输入

输入数据为df格式,列可分为下面六类

  • 目标变量:预测目标变量
  • 观测型输入变量:观测型输入变量(如上一时刻的时间序列值等无法提前预知的)
  • 已知型输入变量:已知型输入变量(如日期、节假日等可提前预知的信息)
  • 静态型输入变量:静态型输入变量(如商店地址等固定不变的信息)
  • 时间索引标识符:时间索引标识符不作为模型输入使用仅用于内部索引标识
  • 时间索引标识符不作为模型输入使用仅用于内部索引标识

1.2 输出

各分位数的预测值:
比如:quantiles = [0.1, 0.5, 0.9]
模型就会给出0.1,0.5,0.9分位数的预测值

2 损失函数

分位数损失函数

在这里插入图片描述

3 模型结构

在这里插入图片描述

说明:

  • Feature Selection is aimed at intelligent feature extraction from input data.
  • Gated Residual Network blocks support effective information transmission via skip connections and gating layers.
  • Temporal Processing Mechanism integrates information through LSTM networks for handling localized temporal patterns and employs multi-head attention to capture cross-temporal dependencies.

3.1 基本结构

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

3.2 整体结构

在这里插入图片描述

4 模型参数说明

4.1 所有参数

  • 丢弃率:0.1 ----防止过拟合
  • 隐藏层单元数:5 ----隐藏层大小
  • 学习率:0.001 ----学习率
  • 最大梯度范数:0.01 ----Adam优化器中的梯度裁剪方法(clipnorm),即裁剪后的标准化值不超过该阈值。
  • 批次大小:64 ----批次大小
  • 模型存储路径:./outputs/ ----模型保存路径
  • 注意力头数量:4 ----Transformer中的注意力头个数
  • 注意力堆叠层数:1 ----自注意力层堆叠层数,默认设置为1,在代码中未被使用也不会产生实际影响
  • 编码器与解码器总时间步数:192 ----编码器长度与解码器长度之和
  • 编码器时间步数:168 ----编码器长度
  • 训练周期数:1 ----训练轮次数
  • 早停等待次数:10 ----早停策略中的等待期次数
  • 多进程运行数目:5 ----多进程数量
  • 列定义列表:
    (id, 0, 4), # 列定义
    (小时起始时间, 0, 5),
    (功率消耗, 0, 0),
    (小时, 0, 2),
    (星期几, 0, 2),
    (小时起始时间, 0, 2),
    (类别ID, 1,3)

input_dimension: 5, ----输入维度
output_dimension: 1,----输出维度
category_information: [369],----所有类别信息
observed_input_location: [0], ----观察位置
static_input_location: [4], ----静态位置
known_non_categorical_inputs: [1, 2, 3], ----已知非类别位置
known_categorical_inputs: [0] ----已知类别位置

4.2 列定义参数

column_definition 输入格式为(列名,数据类型, 输入类型)],

  • 其中数据类型 有三类
    0 - 实数
    1 - 类别
    2 - 日期

数据类型分为六种不同的类别:

  • 第1类为目标列(0- TARGET)
  • 第2类为观测输入(1 - OBSERVED_INPUT)
  • 第3类为已知输入(2 - KNOWN_INPUT)
  • 第4类为静态输入(3 - STATIC_INPUT)
  • 第5类用于标识时间序列的编号(4 - ID)
  • 第6类为时间索引(5-TIME)
  • input_size = 全体特征 - 2(id, time), 因为id和time不参与输入处理, 仅用于索引标识。
  • 剩余的所有输入字段需划分为两种类型: 类别型字段与常规字段。
  • 类别型字段数量等于self.category_counts的长度:
  • 常规字段数量等于input_size减去类别型字段数量:
  • regular Inputs等于all Inputs在前(num_regular)个维度上的切片: 它们都会经过real_to_embedding转换:
  • 类别型Inputs等于all Inputs在后(num_categoryical)个维度上的切片: 它们都会被嵌入化处理:

4.3 处理后输入

数据处理中会将所有输入分为四类:

  • obs_inputs, 观测输入
    = input_obs_loc 所有观测的值

static_inputs, 静态输入
= regular_inputs 中属于 静态类别 + categorical_inputs 中属于 静态类别

  • known_combined_layer, 已知输入
    = known_non_static_regular_inputs + known_non_static_categorical_inputs

  • unknown_inputs, 未知输入
    = regular inputs中属于非识别且未被观察到的数据 + categorical inputs 中属于非识别且未被观察到的数据

5 TFT贡献

  • 门控机制:通常与add&norm协同工作。该层结构有助于模型适应不同深度和复杂度的需求,并因此适用于多样化的数据类型和场景。
  • 特征选择网络:有助于模型在每个阶段聚焦于相关特征。
  • 静态协变量编码:用于将静态变量编码整合为辅助上下文信息加入网络。
  • seq2seq层结构用于捕获短期依存关系。
  • 多头注意力机制用于捕获长期依存关系,并通过多头并行处理不同粒度的信息。
  • 采用百分位数确定预测区间范围,并据此生成置信区间估计。

6 其他总结

门控结构的作用
----门控结构的功能:通过门控机制实现对有用信息特征的专注以及对无关干扰信息的抑制,在保证模型性能的同时使其具备根据不同任务需求进行计算深度与复杂度调节的能力,并因此使得该模型能够适应多样化的数据集与应用场景。

GRN的作用体现在以下几个方面:首先,它能够促进信息流通更加高效;其次,在网络结构中通过采用跳跃连接和门控机制;第三,从而能够有效捕捉关键特征;第四,并确保不会遗漏任何重要信息。

  1. 为什么lstm捕捉的是局部信息? 为什么多头注意力模块捕捉的是全局信息?
    ----lstm是直接处理输入数据的,会处理historical_features, future_features,还要间接处理static特征。最终得到状态temporal_feature_layer。
    ----而多头注意力模块没有直接处理输入的特征数据,而是处理的temporal_feature_layer和静态特征数据。多头注意力模块的输出为decoder。还有相应的注意力self_att。
    ----输出值是如何得到的呢?输出值是将多头注意力模块的输出decoder和lstm的输出temporal_feature_layer一起做add&norm, 然后再做Dense得到的。
    ----因为lstm离输入数据近,受输入数据影响大,所以最可能捕捉的是局部信息。
    ----因为多头注意力模块处理的是lstm的输出状态,离原始输入较远,所以抽象层次更高,受原始输入影响较小,所以捕捉到的最可能是全局性的信息。
    ----所以,输出值是由lstm代表的局据信息+多头注意力模块代表的全局信息共同得到的。二者缺一不可。
    ----通过观察多头注意力模块的输出self_att,我们可以看到,确实捕捉的是类似周期性这种的全局信息。
    ----因为self_att的维度也是固定的,如果有周期性的话,那么self_att中的注意力数据应该也会呈现出周期性,因为i时刻的模式肯定与i-T时刻的模式很像,i时刻对i-T时刻的关注肯定会多一点,有点类似acf自相关函数。
    如下图所示:天周期每24h会有一个峰值,星期周期每7d会有一个峰值。
在这里插入图片描述
  1. 特征重要性
    ----特征重要性由static_weights, historical_flags, future_flags表示为,它们体现了特征选择权重.

参考资料

TFT相关文章链接
Google Research GitHub仓库实现代码

全部评论 (0)

还没有任何评论哟~