SMM在机器学习和深度学习中的应用
状态空间方程
状态空间方程详细阐述了一种如何将系统的数学模型转化为一组一阶微分方程的方法,并对系统状态的演变过程及其输出方程进行了深入分析。
由两个主要方程组成:
状态方程 :
阐述系统状态随时间演变的过程。在线性系统中,则可将其数学表达式写作:
\dot{x}(t) = A x(t) + B u(t)
其中符号\dot{x}(t)代表状态向量x关于时间t的变化率;矩阵A代表系统的内在动态关系;而矩阵B则反映了外部输入变量u(t)对系统的影响机制。
Output Equation :
This equation mathematically describes the relationship between a system's output and its state. For linear systems, the output equation can be expressed as:
y(t) = Cx(t) + Du(t)
where y(t) represents the system's output vector, C denotes the output matrix that maps the state variables to the outputs, and D is the direct transmission matrix that captures the direct influence of inputs on outputs.
状态空间方程在机器学习和深度学习中的应用
循环神经网络(RNNs)
RNN的核心思想与_state_space_model_密切相关,在_RNN_中。系统在处理序列数据时的状态特征表现为动态系统的行为特征,并由其权重参数决定其动态行为的变化规律及信息传递机制。具体而言,在不同时间点上系统的隐藏状态可被看作是其内部动力学行为的表现形式,并通过数学模型的方式构建起从当前输入到下一时刻信息传递的知识体系框架
在基本RNN单元中,在时刻t时输入xtx_t以及上一时刻的状态ht−1h_{t-1}都会影响当前状态值ht,h_t的变化情况。其状态更新方程可表示为:其中WhhW_{hh}和WxhW_{xh}分别代表隐藏到隐藏以及输入到隐藏的权重矩阵,bhb_h是该层的偏置项
生成yty_t于时刻tt基于当前时刻的状态hth_t利用一组新的权重WhyW_{hy}和偏置byb_y进行计算
序列模型和时间预测
序列模型中的状态空间方程:
状态方程:
xt+1=FtXt+Gtut+wt x_{t+1}=F_tX_t+G_tu_t+w_t
观测方程:
yt=Htxt+vty_t=H_tx_t+v_t
SNN中的状态空间方程
τmdV(t)dt=−[V(t)−Vrest]+RI(t) τ_m\frac{dV(t)}{dt}=-[V(t)-V_{rest}]+RI(t)
- V(t) 是在时间 t 上的膜电位。
- \tau_m 决定了该神经元膜电位的变化速度。
- V_{\text{rest}} 代表无外界刺激时神经元的静息状态。
- R_m 代表神经元细胞膜的总电阻值。
- I(t) 代表作用于神经元的动态输入电流强度。
当神经元膜电位达到阈值时(即V(t)\geq V_{\text{thresh}}),该神经元触发动作电位并将其膜电位归零;这一过程可用以下条件和重置方程描述:
深度序列模型
基于序列数据的深度学习架构可被视为一种通过循环神经网络、卷积神经网络或自注意力机制等基础模块构建的序列到序列转换方法。
定义 1.1(非正式)。作者采用了序列模型来代表在序列 y = f_θ(x) 上的参数化映射关系。其中,特征向量序列 x 和 y 都属于 R^D,并且长度为 L。θ 通过梯度下降算法被确定下来。

例如RNN、CNN、Transformer等都属于深度序列模型;这类模型普遍面临诸多挑战。其中RNN的收敛速度较慢且容易出现梯度消失的问题;CNN侧重于局部上下文处理但可能导致序列推理成本过高且上下文长度受限;Transformer在处理长序列时会呈现二次增长的问题(即n²d复杂度);神经 ordinary differential equations (NODE)虽然理论上能够解决连续时间建模与长距离依赖问题但其计算效率相对较低。
所以深度序列模型面临几大挑战:
通用能力
RNN:必须迅速更新其隐藏状态的状态配置包括实时处理的任务以及强化学习。
CNN:对音频、图像和视频等均匀采样的感知信号进行建模
Transformers:对语言等领域中密集、复杂的交互进行建模
NODE:处理非典型世时间序列设置,如缺失或不规则采样数据。
计算效率
在训练阶段, 通常可用整个输入序列的损失函数来表示其目标;而在推理过程中, 模型的配置可能会有所变化;例如, 在在线处理或自回归生成的情况下中, 输入仅显示单个时间步, 模型需能够依次高效处理这些输入.
RNN作为一种序列型模型,在当前主流的GPU与TPU等先进硬件上难以实现高效的训练效率;相比之下,
传统的CNN架构与Transformer架构不具备高效的自回归推理能力,
因为它们本质上都是基于无状态的设计;
此外,
每个新增输入所需的时间成本可能与其所依赖的整个上下文信息量呈正相关;
值得注意的是,
更为复杂的模型架构可能能够提供额外的功能特性,
然而其计算复杂度随之增加可能导致运算速度变慢;
例如某些高级模型可能需要调用高成本的微分方程求解器来完成复杂的运算。
长程依赖
现实世界中的困难可能源于难以捕捉数据间的相互作用;同样地,这也可以归因于优化问题中的梯度消失现象。
状态空间序列模型(SSM)
SSM 被定义为一种简单的序列模型,并基于一个隐式的潜在状态将一维函数或序列进行映射:


SSM是一种基础且简单的模型,并具备多种特性。这些模型与NDE、RNN以及CNN等类别紧密相关,并且能够灵活多变地实现通常需要专门模型才能完成的任务(挑战一)。SSM由状态方程和观测方程组成;基于前一步骤推导的状态方程可结合RNN结构进行设计;观测方程则描述了系统状态如何转化为可测量的数据,并与其相似之处在于采用CNN式的局部捕捉机制,并具有一定的局部捕捉能力。
- SSM 具有连续性特征。其数学本质是微分方程,在实际应用中可应用于模拟连续过程、处理缺失数据以及适应不同采样率的需求。
- SSM 具有循环特性。通过标准技术将其转换为线性递推关系,在推理过程中模拟成状态循环模型结构,并保证每个时间步的内存和计算量保持一致。
- SSM 属于卷积系统的范畴。其本质是线性时不变系统,并可明确表示为其连续卷积形式,在研究离散时间版本时发现可通过离散卷积实现并行训练以显著提升效率。
由此可见,SSM 属于一类通用的序列模型,在多种环境下均展现出很强的能力,并覆盖了包括语音信号、图像数据以及时间序列等多种领域。
SSM 的通用性并非没有局限性。原始 SSM 面临着两个主要挑战——可能相较于其他模型更为显著——这使得它们难以作为高效的深度序列模型得到广泛应用。这些挑战主要包括:(1)与同样规模的 RNN 和 CNN 相比,SSM 效率较低;(2)它们在处理具有长依赖关系的任务时会遇到困难,并继承了 RNN 遗传的梯度消失问题。
利用结构化 SSM 进行高效计算(S4)
考虑到状态变量x(t)∈R^N在计算资源和内存使用方面存在过高的消耗(挑战二),通用的 SSM 无法应用到深度序列模型中。
考虑 SSM 中的状态维度为 N 以及序列长度为 L,则仅需计算完整潜在状态 xx 即需进行 O(N²L) 次运算及占用 O(NL) 空间——与其总输出量下限 Ω(L+N) 进行比较。由此可见,在处理具有约 N ≈ 100 的规模时
S4的前身
Hippo:假设 t0t_0时刻我们看到了信号 u(t)u(t)的之前部分:
为了以在有限内存预算内缩减前面这一段输入数据来提取特征信息为目标,在探索如何有效利用内存资源的过程中

当我们在接收更多的 signal 时,在现有的内存预算范围内致力于压缩这一整段 signal,并将该多项式的各项系数进行相应更新(如上图所示)。

以上,会涌现出两个问题:
1. 如何找到这些最优的近似?
2. 如何快速地更新多项式的参数?
为了应对这两个问题, 我们需引入一个度量标准, 以量化评估这个系统的性能优劣. 比如, 可以采用EDM方法.
由此提出了HiPPO(High-order Polynomial Projection Operator)的正式定义式:该算子由两组信号与两组矩阵共同作用所形成的。

HiPPO相当于将函数映射到函数,这里给个通俗的例子解释一下:

类似地,在这一节中,请注意以下几点:其中u代表原始信号序列(Signal),而x代表被压缩后的信号序列(Compressed Signal)。当u持续增长时,在线更新算法则允许在线更新被压缩后的信息流x。具体而言,在这种情况下(当使用64个units(每个units能够完全表示1万单位的信息)时),我们发现通过多项式编码器实现的有效性很高,并且成功保留了大量原始信息:

其中红色线条等价于输入信息的一种重建方式(可以看出,在离当前时间最近的那个时刻上进行记录的信息最为精准;而对于时间跨度较大的历史点,则信息的表现不够精准)。这里需要注意的是,在HiPPO模型中仅需关注当前时刻所使用的多项式(polynomial)参数以及在此之前的时间点上进行记录的信息u即可;无需关心之前各个时间段内使用的多项式参数。
都应用EDM这一measure。然而,在学习过程中我们通常会涉及不止一个measure(例如时间变化的measure可能会随着时间而变化)。那么如何进行建模呢?最终得出结论:HiPPO能够在多种不同的measure方面成立。

S4的推出
我们正式定义下S4
在开始部分采用state space model系统,并将其简称为SSM。此外,在下图所示的两个方程中引入特定的矩阵值。

- 学习对应的参数
S4的性质:连续的表示、用Recurrent快速infer、用Convolutional快速训练

第一个特性表现为持续性表达。值得注意的是即使基于离散数据构建起来的方法依然能够捕获底层隐藏着持续性的信息因为从SSM的角度看在其中(sequence)不过是由continuous signal(信号)通过采样得到的一种discrete形式(即序列)。或者可以理解为持续性信号模型是对离散序列模型的一种概括

第二个特性是高效的在线计算能力,在此之前HiPPO也提及过相关内容。具体而言,在下一时刻确定状态x′x'时,仅需当前状态xx的信息以及全局输入uu。

尽管要求全局输入,但这种全局计算的时间复杂度为O(1),这一特性与RNN一致,而与Transformer和CNN不同
其时间复杂度为恒定,并与RNN模型具有相同的特性。由于存在中间状态(如图中的蓝色线条),仅依赖于前一状态以及全局输入。
SSM的一个缺点在于,在预知未来信号的情况下进行训练会导致效率低下。是否有一种方法能够并行化SSM?作者提出了一种创新的方法:采用一个卷积核KK来直接从输入uu连接到输出yy(而不是经过中间态xx)。

选择机制的SSM算法(S6)
我们认为序列建模中的一个根本性挑战在于如何将复杂多样的上下文信息浓缩为更为紧凑的状态(We believe that a fundamental challenge in sequence modeling lies in effectively compressing contextual complexity into more compact representations)
从这一角度看,在关注点上其效果虽有亮点但整体效率仍显不足。因为必须显式存储整个上下文(即KV缓存),这直接导致训练与推理时所消耗的算力较大
Take for instance, attention exhibits a dual nature, being both highly effective yet also inefficient, as it fails to employ explicit compression techniques. Such as this illustrates, auto-regressive inference necessitates explicit storage of the entire context (i.e., KV cache), resulting in slow linear-time processing and quadratic-time training for Transformers.
例如,在处理每个输入时
RNN的推理和训练效率高,但性能容易受到对上下文压缩程度的限制
Another point to consider is that recurrent models are effective because of their finite memory capacity, enabling them to perform tasks with constant computational overhead. However, this limitation constrains their ability to capture long-range dependencies within sequential data.
好比,RNN每次只参考前面固定的字数,写的快,但容易忘掉更前面的内容
Mamba机制通过主动筛选信息并着重处理关键内容,并可选择性地跳过不相关的部分。即使在状态大小固定的情况下也能有效压缩上下文信息。
比如,在Mamba的学习过程中,在每一次参考之前的所有内容时,在每一次深入学习的过程中,在后面的内容中对前面的内容进行更加深刻的总结和提炼。
总体而言,在序列模型中实现效率与效果之间的平衡问题主要取决于它们如何处理状态压缩的问题。在这一过程中,高效的模型需要设计一个简化的状态表示方式,并且能够通过有限的状态捕捉到足够的上下文信息;而有效的模型则需要确保其状态不仅能够反映当前的信息还能够保留来自上下文的所有必要关联性。为了在两者之间取得更好的平衡,mamba采用了选择性地关注关键信息并过滤掉非关键细节的方式
在其前身结构化状态空间模型S4中,其有4个参数(∆,A,B,C)(∆, A, B, C)

且S4是LTI系统,不随输入变化,这些参数控制了以下两个阶段:

第一阶段(1a 1b),通常会运用固定的公式A = f_A(Δ, A)以及B = f_B(Δ, A, B)。这些公式会被用来将"连续参数"(Δ,A,B)转换为"离散参数"(A,B)。在这里面,(f_A,f_B) 被视为离散化规则,并且可以通过多种不同的方法来实现这一转换过程。例如,在下面所述的方程组中所定义的零阶保持法(ZOH)就可以作为其中一种实现方式

第二阶段(包括2a、2b以及3a、3b),当参数从(∆, A, B, C)转换为(A, B, C)时,在模型中能够以两种不同的方法进行计算:一种是采用线性递归的方式(编号2),另一种则是使用全局卷积的方法(编号3)。
模型通过卷积模式(3)能够快速实现并行化训练(整个输入序列能够预见到),然后转而采用循环模式(2)以快速完成自回归推理(仅能观察当前单个时间步)。
为什么能实现高效的并行处理呢?因为这种模式避免了涉及状态计算,并实现了只依赖于(B,L,D)参数的卷积层(3a)。

每个矩阵都可以由N个数字来表示,在处理一个批量大小为BB、长度为LL且具有DD个通道的输入序列xx时,每个通道上都独立地应用了SSM
考虑到在当前情境下, 每个输入的总隐状态由DNDN维构成, 在基于序列长度进行计算时所需的时间和内存资源为O(BLDN).
各个变量含义:
Δ是一个标量相当于遗忘机制的一种表现。
该量与RNN中的gating模块具有紧密关联其中data-dependent的Δ其功能与神经网络中的遗忘门单元相似。
该标量与RNN中的gating单元存在密切联系其特性尤其体现在数据依赖性方面类似于神经网络中的遗忘门单元
BB,起到的作用类似于:进RNN的memory
CC, 承担的作用如同:从RNN中提取记忆 所以有人认为,data-dependent 的B/CB/C模块与其在RNN中的input/output gate具有相似的功能
AA,这表明针对该维度的SSM,A在各个隐藏状态维度上具有不同的影响,并实现多尺度和精细级联门控机制的作用,其原因在于LSTM网络采用逐元素乘法操作。
而在Mamaba中,作者通过将这些参数B、C、ΔB、C、\Delta指定为输入函数,并使模型能够根据输入内容动态地调整其行为模式。

从S4到S6的过程中, 可以看出BB、CC在数值上经历了变化, 其具体表现为: BB与CC分别从原先为(D,N)(D,N)提升至(B,L,N)(B,L,N), 而Δ\Delta值则相应地由原先的DD降至(B,L,D)(B,L,D). 进一步而言,




来逐一将B,C,ΔB,C,\Delta数据依赖(data dependent)化
至于上面的所谓

通过一个线性层将输入向量xx映射到d维空间
而通常情况下 NN与SSM之间的隐藏层维度(hidden dimension)相对较小
同时 在每个位置上使用的参数(如B C Δ等)各不相同 特别地 在S4阶段时所有位置共享这些参数
尽管A仍未成为data dependent;然而经过state space model进行离散化处理后
通过outer product运算将数据转换为(B,L,N,D)的data-dependent tensor,并以参数效率高的方法实现data-dependent的目的
神经常微分方程(NODE)
ResNet
神经网络经常微分方程的动机来自于ResNet。ResNet能够通过一个函数来表示非线性函数、权重矩阵、偏置以及残差连接。
ht+1=f(ht,θt)+ht
h_{t+1}=f(h_t,θ_t)+h_t
将相邻网络层之间的间距缩小至无极限的小数值时,则可将ResNet重构为一个连续型的神经网络体系。此操作使得我们可以对ResNet的离散层与其对应的连续型神经网络表示进行对比分析。观察发现,在连续型神经网络中状态的变化速率由若干非线性函数所决定,并且这些函数在时间维度上保持恒定特性;这与微分方程的形式存在高度相似之处。
h_{t+1} - h_t = f(h_t, θ_t)
h_{t+1} - h_t ⋅ 1 = f(h_t, θ_t)
h_{t+∇} - h_t ⋅ ∇ | _{∇=1} = f(h_t, θ_t)
lim ∇→0 (h_{t+∇} - h_t)/∇ | _{∇=1} = f(h(t), t)]
dh(t)/dt = f(h(t), t, θ)]
ResNet:离散和连续
比较离散和连续的ResNet:
离散:
ht+1=f(ht,θt)+hth_{t+1}=f(h_t,θ_t)+h_t
- LL个离散层
- 潜在状态以离散方式改变
- 潜在状态动态由LL个函数控制
连续:
dh(t)dt=f(h(t),t,θ)\frac{dh(t)}{dt}=f(h(t),t,θ)
- 无限层
- 潜在状态以连续方式改变
- 潜在状态动态由一个函数控制
RNNs
该方法具有显著效果。在RNN架构中,下一个时间步的状态通过一个确定性的数学关系式进行计算。这一计算步骤结合了非线性激活函数(如tanh)的作用,并融合了当前输入与前一层隐藏状态之间的线性变换信息以及当前输入与偏差向量之间的相互作用。假设前一层隐藏状态采用恒等变换且激活函数为向量1,则上述关系式可重新表述为微分方程形式。从而使得神经网络能够有效地处理基于时间的信息。相比之下,在处理连续时间序列时,常微分方程架构通常优于基于离散时间步骤的RNN模型。

神经ODE
在时序神经常微分方程(SNN-ODE)中
- 我们面临一个常微分方程(ODE)求解问题。
- 对于方程 \frac{dh(t)}{dt} = f(h(t), t, \theta) ,我们目前无法直接求解其解析解。
- 通过结合神经网络和反向传播算法(backpropagation),我们的目标是从给定数据中学习微分方程的时间导数 \frac{dh}{dt} 。
神经ODE:前向传播
我们目前致力于探讨神经ODE(神经常微分方程)的正向传递过程。输入状态由h_0代表初始时刻的状态。
状态动态或导数函数可以由神经网络来表示。一般来说,在这类模型中通常具有1至2个隐藏层的结构。
状态动态:dh(t)dt=f(h(t),t,θ)\frac{dh(t)}{dt}=f(h(t),t,θ)。
将输出状态视为基于常微分方程数值解器的时间积分过程。该数值方法通过设定一系列离散的时间点来进行计算,在每个时间段内采用特定算法逐步推进系统状态的变化。其参数包含初始状态、动力学函数及其相关参数、起始时间和终止时间等信息。\n自适应步长求解器与固定步长求解器不同之处在于无需预先指定时间间隔(∇t)。为了实现这一目标,我们需要对神经网络所代表的动力学方程进行训练和更新。

ODEnet 在前向传播中有两个核心特点:
模型的深度 :
由于 ODEnet 是连续的,它没有明确的层级数。在这篇论文中,作者使用 ODE Solver 评估的次数来代表模型的“深度”。
深度与误差控制的关系 :
在 ODEnet 中的"深度"与其设置的 error tolerance 呈正相关关系。当设定较低的 error tolerance 时,在求解过程中需要执行更多的 computation steps 来保证 solution accuracy, 这将导致 model overall computational complexity 的提升。这一特性为我们提供了一种权衡建模精度与计算效率的有效手段, 使我们能够在 accurate 和 efficient 之间找到平衡点, 实现 desired model performance without excessive computational resource consumption.

论文中的图示阐明了以下点:
- a 图:当误差容忍度降低时(或随着误差容忍度减小),前向传播阶段的任务函数评估次数上升。
- b 图:通过对比分析了不同计算模式下的评估次数与计算时间之间的关系。
- c 图:结果显示,在前向传播阶段的任务函数评估次数大约是反向传播阶段的一倍半左右(或两倍)。这表明 adjoint sensitivity 方法在内存和计算效率方面具有显著优势(因为其无需逐一分析前向传播过程中的每一个步骤)。
- d 图:随着训练过程的发展(或持续),任务函数在各阶段的评估次数逐渐上升(或持续增长),这表明模型的整体复杂性也在逐步提升。
总结而言,在准确性和计算成本之间寻求平衡的ODEnet方案不仅具备这一核心优势,并且还展现了其独特的性质以及高效的性能特征
反向传播
关于反向传播的主要困难在于如何将梯度传递给 ODE Solver。一种直接的方法是通过逆向利用前向传播的计算路径来传递梯度;然而这种方法在内存需求和数值精度方面均存在较大局限性。因此研究者采用了将 ODE Solver 视为不可穿透模块的一种策略;无需(或极为困难)直接传递梯度;转而采用另一种方式‘规避’这一问题
这种“绕过”策略被称为 adjoint method。在反向传播过程中,模型利用扩展的ODE求解器来计算梯度. 该方法在计算效率和内存消耗方面表现出色,并且能够精确地控制数值误差.
具体而言,若我们的损失函数为 L(),且它的输入为 ODE Solver 的输出:

我们需要计算L对z(t)的导数,并探讨模型损失的变化与隐藏状态z(t)之间的关系。在这一过程中涉及到了L对z(t₁)这一变量的导数值,在后续运算中起着关键作用。为了便于理解,在此定义adjoint量a(t)=−∂L/∂z(t),该量实际上反映了隐藏层梯度的信息
在链式法则框架下的传统反向传播机制中,在每一层神经网络之间都需要从前一层到后一层依次计算梯度以完成误差传递。然而,在连续化的 ODEnet 框架下,则要求我们针对连续时间变量 t 计算某个特定函数 a(t),这个函数表示损失函数 L 对于隐藏状态 z(t) 的变化率。有趣的是,在这种情况下所涉及的概念与传统的链式法则具有高度相似性。通过以下公式可以看出 a(t) 关于时间 t 的导数计算结果如何影响整个系统的梯度传播过程

当我们获得了每个隐层状态对应的梯度时

综上,具体过程如下:
- 确定损失函数梯度的起始点:对于损失函数 L() 和它的输入(即 ODE Solver 的输出)z(t),我们首先计算损失 L 对 z(t) 的导数,在整个梯度计算中确定了一个起点。这个导数称为 adjoint a(t),相当于隐藏层的梯度。
- 分析连续时间模型中梯度随时间的变化情况:在连续时间模型 ODEnet 中,梯度沿时间连续传播。a(t) 的导数描述了这种梯度如何沿时间变化。
- 计算网络参数相对于损失函数的导数:在计算出每个隐藏状态的梯度后,我们再对网络参数求导。因为时间"层级"是连续的,所以需要进行积分操作来得到损失对网络参数的导数。
网络对比
1. ResNet
-
核心组成部分:
-
f:对应于卷积层,在此语境中f为卷积操作所对应的数学表达式。
-
其中h为上一层输出的特征图,
-
t表示当前卷积层的序号。
-
ResNet:基于残差学习的方法,
-
通过引入跳跃连接实现了更深网络的学习能力,
-
该架构由T个残差模块组成。
- 伪代码:
def f(h, t, θ):
return nnet(h, θ_t)
def resnet(h):
for t in [1:T]:
h = h + f(h, t, θ)
return h
2.ODEnet
-
主要组件:
-
f被定义为神经网络模型,在这里θ被视为一个整体参数,并且同时作为独立的输入也被引入到网络中。这种设计表明整个网络采用了连续层次结构。
-
ODEnet无需采用离散层循环结构,而是通过求解常微分方程(ODE)在时间点t₁处的状态h来实现。
- 伪代码:
def f(h, t, θ):
return nnet([h, t], θ)
def ODEnet(h, θ):
return ODESolver(f, h, t_0, t_1, θ)
为了对这两种网络进行进一步比较研究, 研究团队在MNIST数据集上进行了相关实验. 通过实验分析, 在MNIST数据集上评估了这两种架构的表现特征包括模型性能、参数规模以及运行内存需求等关键指标.

NODE对ResNet的优点
- 内存优化策略
- 自适应计算过程中通过明确地调整数值积分的精度水平, 可以实现对模型运行速度与计算精度之间的有效平衡
- 连续标准化流程
- 时间序列分析中常遇到的非均匀采样问题

