论文阅读笔记:Multi-Agent Tensor Fusion for Contextual Trajectory prediction
研究者旨在将场景纹理带来的约束与行人行为的随机性纳入统一建模框架。研究者认为,在应对未知情境时行人的轨迹预测面临挑战:这源于智能体运动具有的随机性特征,并受到目标识别、社交互动以及场景纹理等多种因素的影响;此外,在这种复杂环境下还需要具备泛化能力以适应不同情境的变化情况——因为其他智能体的数量及其配置并非固定不变。针对这一问题,在以往的研究中主要采用基于代理或基于空间位置的不同编码策略;其中基于代理的方法通过聚合函数(如Social-LSTM)处理多个智能体的特征向量;而基于空间的位置方法则直接操作场景的自上而下表示。
该研究者设计了一种基于多智能体张量融合(MATF)的编码解码架构。该架构以其为中心视角方法著称,并能在空间维度上保留所有智能体及其环境的空间布局信息。研究者将每个主体的历史行为进行独立编码,并将其与场景整体进行对齐处理以保证各主体及其环境的空间特征对应关系。随后利用全卷积神经网络构建了一个融合后的多主体张量并在此基础上提取各主体之间的空间交互特征。为了捕捉未来轨迹预测中的不确定性问题在模型中引入了条件生成对抗训练机制来描述有限样本集下的轨迹分布情况。
(1)MATF编码-解码

在MATF框架中存在两个协同工作的编码流系统:第一系统为每个智能体单独拥有一个长短期记忆网络(LSTM),专门用于捕获其自身的历史运动轨迹信息;所有这些LSTM网络通过共享一组权重参数实现了统一的学习机制;第二系统则负责编码固定环境中的静态场景上下文图像信息

的CNN。LSTM输入该代理的过去轨迹,输出一个一维的状态向量

,场景上下文编码器接收一个鸟瞰视图的原始图像(或包含所有静态对象的分割图像),并生成一个缩放特征图以保留场景空间结构信息

。
随后

其嵌入到场景特征图中的一个空间特征图。各智能体通过其LSTM输出生成的状态向量依据其坐标位置进行排列,从而形成代理通道。该代理通道与场景上下文相关联后形成的场景特征图实现对齐,从而保留了智能体与场景间的空间结构信息。其中,每个智能体的LSTM输出的状态编码向量

其在轨迹末端的时间步长处的空间坐标被放置于一个初始化为空(即全零)的俯视图空间张量中,并与场景编码后的特征图进行匹配

相同的宽度和高度(即相同尺寸)。随后将此张量与场景特征图进行融合处理后得到一个综合张量;在此过程中若出现多个智能体向量被放置在同一张量单元的情况,则对该单元执行元素的最大值聚合操作。
多代理张量被输入至全卷积层中,在该层中模型能够学习多个代理间的互动关系以及各代理与场景背景之间的相互作用信息,并通过保持局部空间特性来生成综合性的多代理张量表示。具体而言,在不同空间尺度上模仿U-Net架构以建模交互关系。融合模块输出特征图

具有和场景特征图

相同的宽高,以保持空间结构。 最后,把

中每个代理对应的向量

基于它们的坐标定位后,该系统整合了围绕i区域内的代理与场景相关信息,将其作为残差项补充到原有代理特征向量中,从而构建出最终的编码代理向量

,独立解码以获取未来T’时刻的预测轨迹

其中值得注意的是,在这种情况下, 因为每个智能体LSTM共享相同的权重参数, 在经过正向传播后形成的每个智能体对应的融合向量都得到了生成.
作者采用了条件生成对抗网络来训练一个随机生成机制,并以捕捉预测的不确定性为目标。此外,在这一过程中,作者采用了条件生成模块。

以多个代理为基础生成其未来轨迹,在分析其过去轨迹、静态场景上下文以及随机噪声输入后,能够创建出一系列随机输出轨迹。作者采用了鉴别器

来区分生成的轨迹是真实的还是假的(生成的)。

和

在编码部分与该段落中提到的确定性的模型共享完全相同的架构结构,并用于推理静态场景下的上下文信息以及多个代理之间的相互作用关系。

和

采用前面训练得到的确定性的模型来进行参数初始化过程。详细的体系结构和损失的具体描述如下:
生成器

通过给定的过去轨迹

和静态场景上下文

(其中

它代表的是第i个代理。它的结构与之前介绍的LSTM+CNN+UNet的MATF编码-解码器完全一致,但主要区别在于它在输出联合编码向量时采用了不同的策略。

后和一个随机采样高斯噪声向量

连接,连接后的向量作为LSTM的输入来预测未来轨迹

。
判别器

观察所有智能体的历史轨迹+生成轨迹组

和历史轨迹+真实轨迹组

,输出他们的真/假标签,如果轨迹为假输出标签

,否则

除了这些差异之外该结构与MATF encoder-decoder架构相似:(1)该系统采用过去及未来行为序列作为输入而非仅限于过去的序列;(2)作为一个分类器它不会通过LSTM处理最终代理状态

解码为未来轨迹。相反,最终的代理编码被输入到全连接层中,以便归类。
由生成对抗网络的损失函数

构成损失函数
。
