本篇包括经典SSM从连续形式到离散形式的参数变换过程推导,Mamba2的周围架构和部分实现细节,以及SSM到经典Linear Attention的对应关系。
免责声明:本篇包含大量标量数学公式在矩阵上的滥用,可能引起不适
Discretization of SSM
参考此处。Mamba/Mamba2的原型是一个时变系统,即
在推导时,我们将时间窗口固定到一个短区间中,使得四个矩阵保持近似恒定,从而转变成一个时不变系统。
接下来,我们求解第一式中的微分方程。我们构造,得到
令
得到。代回上式,得到(常数消去)
在指定区间内,得到
Mamba1-2在此处求解时均使用了零阶保持方法(ZOH),说人话就是固定为,当作常数求解。Mamba3在此处使用的为梯形法则。总之,就是需要借用数值求解的方法近似地算出这个积分。
使用零阶保持方法之后,我们得到
这里我们同步修改了下标表示,以得到完整的离散化公式。最终,我们得到
这个结果和Mamba1中给出的结果应该是一致的。不过Mamba2的代码中并没有直接这种包含求逆的形式,而是使用了。我们可以对这个式子进行麦克劳林展开并舍弃高阶项,得到
Implementation in Mamba2
Mamba2完整的计算流程在modules/mamba2.py以及ops/triton/selective_state_update.py中。后者的selective_state_update_ref中提供了纯torch的算法实现。
值得一提的是,与其他主流架构不同,Mamba2的周围架构是并行性质的,即attention与MLP并行执行。它的模块输入会一次性投射到MLP的输入、MLP的gate、SSM的(代码中的标记是x)、SSM的、SSM的、SSM的和output gating这么多的量。Mamba2的ngroups参数用于实现类似GQA的操作,单个token的和会被投射到[ngroups, dstate]的形状,随后expand到[nheads, dstate]。其他的部分与主流linear模型均相同,内部维护的state大小为[nheads, dhead, dstate]。
Equivalence to Linear Attention Paradigm
SSM的核心出装是这两个式子:
在实际的操作中,并不是一个时变的矩阵,即不是由当前的token决定的,而是一个可学习的parameter。这个parameter对于每个head而言是一个标量。实际的对于每一个head而言也是一个标量(至少Mamba2里是),决定的方式是
dt = self.proj(u)
dt = F.softplus(dt + self.dt_bias)
而这两个量都是由当前的token决定的。所以,实际上对应于key,对应于value,对应于query。抛去遗忘门的部分,Mamba2就是一个Hebb memory。
那么,现在SSM和标准的LA看上去似乎就只有遗忘门不一样了,但其实不然。事实上,GDN使用的遗忘门和Mamba2的初始化方式是完全一样的。他们使用的初始化过程均为
# Initialize log dt bias
dt = torch.exp(
torch.rand(self.nheads, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
+ math.log(dt_min)
)
dt = torch.clamp(dt, min=dt_init_floor)
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
inv_dt = dt + torch.log(-torch.expm1(-dt))
self.dt_bias = nn.Parameter(inv_dt)
# Just to be explicit. Without this we already don't put wd on dt_bias because of the check
# name.endswith("bias") in param_grouping.py
self.dt_bias._no_weight_decay = True
assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
A = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(*A_init_range)
A_log = torch.log(A).to(dtype=dtype)
self.A_log = nn.Parameter(A_log)
self.A_log._no_weight_decay = True
即定义了dt_bias和A_log两个参数。后者实际上是将一个均匀分布映射到对数空间作为可学习的参数,使用时再映射回常规的空间,目前不是很理解为什么要在对数空间中学习。前者看一看实际上很容易理解,它限制了F.softplus(dt + self.dt_bias)的初始范围,即通过初始化来约束这个gate的范围。先log后exp的初始化保证了大多数值分布在下界附近,而不是上界附近。
在Mamba2中,遗忘门的决定方式是
A = -torch.exp(self.A_log.float()) # (nheads,)
...
dt = dt + dt_bias
dt = F.softplus(dt) if dt_softplus else dt
dA = torch.exp(rearrange(dt, "b h d -> b h d 1") * A) # (batch, nheads, dim, dstate)
...
state.copy_(state * dA + dB * rearrange(x, "b h d -> b h d 1")) # (batch, dim, dstate
即exp(-self.A_log.exp() * F.softplus(dt + self.dt_bias))。而GDN事实上使用了相同的方式来决定(https://github.com/fla-org/flash-linear-attention/blob/main/fla/layers/gated_deltanet.py#L270):
g = -self.A_log.float().exp() * F.softplus(self.a_proj(hidden_states).float() + self.dt_bias)
指数操作被移动到了kernel中(https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gated_delta_rule/fused_recurrent.py#L104):
# [BK, BV]
if USE_G:
b_g = tl.load(p_g).to(tl.float32)
b_h *= exp(b_g)
因此,GDN虽然在范式上没有和这些量,但是决定遗忘门的方式实际上是借鉴了Mamba2的模式,是完全一样的。
但是话又说回来,GDN和Mamba2在门控上还是有一些细微的差别。GDN的输入门是一个完全独立的量,但是Mamba2的输入门仍然使用的是。后者的遗忘门和输入门之间存在一定的耦合关系,即输入门越大,遗忘门越小(或者说当前的写入强度越大,对历史的遗忘程度就越高)。此外,本身并不是一个中的值,所以Mamba2的输入门是没有上界的。这构成了两个最主要的差别。

Comments NOTHING