Discretization and Implementation of State Space Models

chensy 发布于 18 天前 30 次阅读


本篇包括经典SSM从连续形式到离散形式的参数变换过程推导,Mamba2的周围架构和部分实现细节,以及SSM到经典Linear Attention的对应关系。

免责声明:本篇包含大量标量数学公式在矩阵上的滥用,可能引起不适

Discretization of SSM

参考此处。Mamba/Mamba2的原型是一个时变系统,即

dSdt=A(t)S(t)+B(t)u(t),B(t)Rdstate,u(t)Rdhead.y(t)=C(t)S(t)+D(t).\begin{aligned} \frac{dS}{dt}&=A(t)S(t)+B(t)u(t), B(t) \in \mathbb{R}^{d_{state}}, u(t)\in \mathbb{R}^{d_{head}}. \\ y(t)&=C(t)S(t)+D(t). \end{aligned}

在推导时,我们将时间窗口固定到一个短区间[tk,tk+1][t_{k},t_{k+1}]中,使得A,B,C,DA, B, C, D四个矩阵保持近似恒定,从而转变成一个时不变系统。

接下来,我们求解第一式中的微分方程。我们构造α(t)S(t)\alpha(t)S(t),得到

d(α(t)S(t))dt=α(t)dSdt+S(t)dα(t)dt=α(t)(AtS(t)+Btu(t))+S(t)dα(t)dt=(α(t)At+α(t)dt)S(t)+α(t)Btu(t).\begin{aligned} \frac{d(\alpha(t)S(t))}{dt}&=\alpha(t) \frac{dS}{dt}+S(t) \frac{d\alpha(t)}{dt} \\ &=\alpha(t)(A_{t}S(t)+B_{t}u(t))+S(t)\frac{d\alpha(t)}{dt} \\ &=\left( \alpha(t)A_{t}+\frac{\alpha(t)}{dt} \right)S(t)+\alpha(t)B_{t}u(t). \end{aligned}

α(t)At+α(t)dt=0\alpha(t)A_{t}+\frac{\alpha(t)}{dt}=0

得到α(t)=CeAt\alpha(t)=Ce^{-At}。代回上式,得到(常数消去)

d(eAttS(t))dt=eAttBtu(t).\frac{d(e^{-A_{t}t}S(t))}{dt}=e^{-A_{t}t}B_{t}u(t).

在指定区间内,得到

eAttk+1Sk+1eAttkSk=tktk+1eAttBtu(t)dtSk+1=eAt(tk+1tk)Sk+tktk+1eAt(tk+1t)Btu(t)dt.\begin{aligned} e^{-A_{t}t_{k+1}}S_{k+1}-e^{-A_{t}t_{k}}S_{k}=\int_{t_{k}}^{t_{k+1}} e^{-A_{t}t}B_{t}u(t) dt \\ S_{k+1}=e^{A_{t}(t_{k+1}-t_{k})}S_{k} + \int_{t_{k}}^{t_{k+1}} e^{A_{t}(t_{k+1}-t)}B_{t}u(t) dt. \end{aligned}

Mamba1-2在此处求解时均使用了零阶保持方法(ZOH),说人话就是固定u(t)u(t)u(tk)u(t_{k}),当作常数求解。Mamba3在此处使用的为梯形法则。总之,就是需要借用数值求解的方法近似地算出这个积分。

使用零阶保持方法之后,我们得到

Sk+1eΔAkSk+(eΔAkI)Ak1Bkuk:=A¯kSk+B¯kuk.S_{k+1}\approx e^{\Delta A_{k}}S_{k}+(e^{\Delta A_{k}}-I)A_{k}^{-1}B_{k}u_{k}:=\bar{A}_{k}S_{k}+\bar{B}_{k}u_{k.}

这里我们同步修改了下标表示,以得到完整的离散化公式。最终,我们得到

A¯k=eΔAk,B¯k=(eΔAkI)Ak1Bk,C¯=C.\begin{aligned} \bar{A}_{k}=e^{\Delta A_{k}}, \bar{B}_{k}=(e^{\Delta A_{k}}-I)A_{k}^{-1}B_{k}, \bar{C}=C. \end{aligned}

这个结果和Mamba1中给出的结果应该是一致的。不过Mamba2的代码中并没有直接这种包含求逆的形式,而是使用了B¯k=ΔBk\bar{B}_{k}=\Delta B_{k}。我们可以对这个式子进行麦克劳林展开并舍弃高阶项,得到

(eΔAkI)Ak1(I+ΔAkI)Ak1=Δ.(e^{\Delta A_{k}}-I)A_{k}^{-1}\approx (I+\Delta A_{k}-I)A_{k}^{-1}=\Delta.

Implementation in Mamba2

Mamba2完整的计算流程在modules/mamba2.py以及ops/triton/selective_state_update.py中。后者的selective_state_update_ref中提供了纯torch的算法实现。

值得一提的是,与其他主流架构不同,Mamba2的周围架构是并行性质的,即attention与MLP并行执行。它的模块输入会一次性投射到MLP的输入、MLP的gate、SSM的uu(代码中的标记是x)、SSM的BB、SSM的CC、SSM的Δ\Delta和output gating这么多的量。Mamba2的ngroups参数用于实现类似GQA的操作,单个token的BBCC会被投射到[ngroups, dstate]的形状,随后expand到[nheads, dstate]。其他的部分与主流linear模型均相同,内部维护的state大小为[nheads, dhead, dstate]

Equivalence to Linear Attention Paradigm

SSM的核心出装是这两个式子:

Sk+1=A¯kSk+B¯kukA¯k=eΔAk,B¯k=ΔBk,C¯=C\begin{aligned} S_{k+1}&=\bar{A}_{k}S_{k}+\bar{B}_{k}u_{k} \\ \bar{A}_{k}=e^{\Delta A_{k}}, \bar{B}_{k}&=\Delta B_{k}, \bar{C}=C \end{aligned}

在实际的操作中,AA并不是一个时变的矩阵,即不是由当前的token决定的,而是一个可学习的parameter。这个parameter对于每个head而言是一个标量。实际的Δ\Delta对于每一个head而言也是一个标量(至少Mamba2里是),决定的方式是

dt = self.proj(u)
dt = F.softplus(dt + self.dt_bias)

B,CB, C这两个量都是由当前的token决定的。所以,实际上BB对应于key,uu对应于value,CC对应于query。抛去遗忘门的部分,Mamba2就是一个Hebb memory。

那么,现在SSM和标准的LA看上去似乎就只有遗忘门不一样了,但其实不然。事实上,GDN使用的遗忘门和Mamba2的初始化方式是完全一样的。他们使用的初始化过程均为

# Initialize log dt bias
dt = torch.exp(
	torch.rand(self.nheads, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
    + math.log(dt_min)
)
dt = torch.clamp(dt, min=dt_init_floor)
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
inv_dt = dt + torch.log(-torch.expm1(-dt))
self.dt_bias = nn.Parameter(inv_dt)
# Just to be explicit. Without this we already don't put wd on dt_bias because of the check
# name.endswith("bias") in param_grouping.py
self.dt_bias._no_weight_decay = True

assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(*A_init_range)
A_log = torch.log(A).to(dtype=dtype)
self.A_log = nn.Parameter(A_log)
self.A_log._no_weight_decay = True

即定义了dt_biasA_log两个参数。后者实际上是将一个均匀分布映射到对数空间作为可学习的参数,使用时再映射回常规的空间,目前不是很理解为什么要在对数空间中学习。前者看一看实际上很容易理解,它限制了F.softplus(dt + self.dt_bias)的初始范围,即通过初始化来约束这个gate的范围。先log后exp的初始化保证了大多数值分布在下界附近,而不是上界附近。

在Mamba2中,遗忘门的决定方式是

A = -torch.exp(self.A_log.float())  # (nheads,)
...
dt = dt + dt_bias
dt = F.softplus(dt) if dt_softplus else dt
dA = torch.exp(rearrange(dt, "b h d -> b h d 1") * A)  # (batch, nheads, dim, dstate)
...
state.copy_(state * dA + dB * rearrange(x, "b h d -> b h d 1"))  # (batch, dim, dstate

exp(-self.A_log.exp() * F.softplus(dt + self.dt_bias))。而GDN事实上使用了相同的方式来决定(https://github.com/fla-org/flash-linear-attention/blob/main/fla/layers/gated_deltanet.py#L270):

g = -self.A_log.float().exp() * F.softplus(self.a_proj(hidden_states).float() + self.dt_bias)

指数操作被移动到了kernel中(https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gated_delta_rule/fused_recurrent.py#L104):

# [BK, BV]
if USE_G:
    b_g = tl.load(p_g).to(tl.float32)
    b_h *= exp(b_g)

因此,GDN虽然在范式上没有AAΔ\Delta这些量,但是决定遗忘门的方式实际上是借鉴了Mamba2的模式,是完全一样的。

但是话又说回来,GDN和Mamba2在门控上还是有一些细微的差别。GDN的输入门β\beta是一个完全独立的量,但是Mamba2的输入门仍然使用的是Δ\Delta。后者的遗忘门和输入门之间存在一定的耦合关系,即输入门越大,遗忘门越小(或者说当前的写入强度越大,对历史的遗忘程度就越高)。此外,Δ\Delta本身并不是一个[0,1][0,1]中的值,所以Mamba2的输入门是没有上界的。这构成了两个最主要的差别。

此作者没有提供个人介绍。
最后更新于 2026-01-09