本文是对 Delta Rule 背后的数学推导的重新梳理。其实 DeltaNet 原文已经讲得很清楚了,但是有点乱,在这里从我的思路来重新写一遍如何从原始的递归形式推导到全并行的矩阵形式。
下文向量表示均为列表示。
正文
我们的起点是递归形式的表达式:
St=St−1(I−βtktktT)+βtvtktT.
事实上,一切并行化的依据都还是最原始的 linear attention 形式,即
ot=i=1∑tvikiTqt=Stqt
具有等价的并行形式
O=(QKT⊙M)V.
这就意味着St必须要写成∑i=1taibiT的形式。为了向这个形式转换,我们假设对于 Delta Rule,有
St=i=1∑tuikiT
这样的形式,然后尝试求出ui。为了求出ui,我们可以从 Delta Rule 的状态演化方程中得到:
St−St−1=βt(vt−St−1kt)ktT=utktT.
注意到,我们在这里已经利用了βt是标量的性质,所以当βt是向量的时候,这里的推导可能是需要重新考虑的(比如 RWKV7)。于是我们得到
ut=βt(vt−St−1kt).
这里同时消掉了两边的ktT,考虑两侧逐元素相同,这是显然的。于是,我们就解得了需要的ut。但是这个表达式仍然和St−1相关,看上去不是那么好看,我们可以再次利用St−1=∑i=1t−1uikiT,得到:
ut=βt(vt−i=1∑t−1uikiTkt).
如此,我们就获得了关于ut的递推式。于是,利用原始的 linear attention 模式,有
O=(QKT⊙M)U.
接下来的事情就转变成如何求得U的矩阵表达。把ut的递推式移项,得到
i=1∑t−1βt(kiTkt)ui+ut=βtvt.
我们可以看到,这个式子实际上是对{u1,u2,...,ut}的一个线性组合。我们还可以把t换成{t−1,t−2,...,1},并且把这一系列线性组合写成矩阵形式,得到
⎝⎜⎜⎜⎜⎛1β2k1Tk2β3k1Tk3⋮βtk1Tkt1β3k2Tk3⋮βtk2Tkt1⋮…⋱…1⎠⎟⎟⎟⎟⎞U=diag(β)V.
左边这个矩阵看着比较锉,但是如果把对角线元素换成{β1kiTk1,β2k2Tk2,...,βtktTkt},这个矩阵就是diag(β)KKT。然后我们使用下三角掩码M−I处理他,再把单位矩阵加回去即可。所以
(I+(diag(β)KKT)⊙(M−I))U=diag(β)V.
接下来,我们求一次逆,就得到U的表示了。这个矩阵是主对角元素都不为零的下三角矩阵,因此可逆。我们就得到了
U=(I+(diag(β)KKT)⊙(M−I))−1diag(β)V.
于是,我们就得到了完整的矩阵表达:
O=(QKT⊙M)(I+(diag(β)KKT)⊙(M−I))−1diag(β)V.
log-linear attention 给出的结果中,在最后面还加了一个M。这是不必要的,因为下三角矩阵的逆仍然是下三角矩阵,且两个下三角矩阵的乘积仍然是下三角矩阵。计算上,三角矩阵的逆是一个相对简单的迭代过程,详细可以参考这里。
这是完全矩阵化的表达,更常用的 chunk 表达思路是一样的。chunk 表达实际上可以看作一个初始状态不为零的过程,会在以上的过程中引入额外的两项和一个类似与ut的变量。另外会多出St自身按照 chunk 单位更新的过程,见原文的 equation 6-9。
扩展:DeltaFormer
写到这里的时候发现把 DeltaFormer 顺便推导了是很容易的事情,就顺手写一下。DeltaFormer 实际上就是在 Delta Rule 上使用 exponential kernel;对比之下,softmax attention 实际上就是在 Hebb Rule 上使用 exponential kernel。状态更新如下:
St=St−1(I−βtϕ(kt)ϕ(ktT))+vtϕ(ktT).
原文没有加学习率,实际上也可以加。这个形式由于带着无穷维的 kernel,所以没有计算意义。整个推导过程和上面基本一致,只不过现在我们假设St=∑i=1tuiϕ(kiT)。我们也会得到一个变量ut:
ut=vt−i=1∑t−1uiϕ(ki)Tϕ(kt)=vt−i=1∑t−1uisoftmax(kiTkt).
读出过程自然地变成:
ot=Stϕ(qt)=i=1∑tuiϕ(kiT)ϕ(qt)=i=1∑tuisoftmax(kiTqt).
因此不可计算的东西就奇迹般地消掉了。对比 Transformer,其实可以发现唯一的区别就是把V换成了U,引用作者的话说:
相比于Transformer和Linear attention无视之前的状态,无脑写入或者append key和value的做法不同,Delta rule会考虑每次写入的时候,根据之前的状态进行修改。这个事情在上个世纪研究的还挺多的,包括Schmidhuber[1], Sutton[2]还有Hinton[3], 虽然那个时候的名称叫做fast weight programming,但内核是一致的。2021年Schimidbuber还重提了一次[4]。
这里的softmax的写法是一个很不标准的写法,分母的部分可以自行意会一下。另外,softmax并不一定是最优的kernel。上面的式子中有两个kernel,文中对这两个kernel的作用也进行了ablation。第二个kernel,也就是传统的attention中的kernel,被证实与retrieval能力密切相关;非线性越强的kernel将会导致更强的retrieval能力,文中尝试了Linear/Round/ReLU/Softmax四种,后两者的MQAR retrieval能力显著更好,甚至ReLU比Softmax要好。值得注意的是,ReLU也不是一个可解耦的kernel,所以这个结果也有一定的道理。第一个kernel则和state tracking能力密切相关,文中理论证明了deltaformer在所有的key之间距离足够远时具有完整的state tracking能力。state tracking的定义是给定序列[(k_{1}, v_{1}), (k_{2}, v_{2}), \dots, (k_{n}, v_{n})],在t≥n+1时,通过给定某个(kt,vt),能够交换kt1和kt2目前关联的值,从而使得使用qt=kt1能够取出ot=vt2,反之亦然。在证明时,文中使用的两个kernel均具有向下取整的性质,但在实验上仍然使第二个kernel取softmax以对应经典attention。
另外,如果读者读过 Memory Mosaics 这篇文章,或许会觉得第一个递推式与之有些神似。可能存在某种奇妙的关联?
讨论
(一些杂谈)
attention 和 linear attention 的本质差距在哪里?
我一直觉得这是一个很微妙的话题。因为,从 softmax attention 到 linear attention 仅仅是去掉了一个 softmax 操作,就发生了巨大的性质上的变化。softmax attention 和 linear attention 被解读的视角实际上是不一样的,我们解读 attention 最常用的视角是 attention map,而解读 linear attention 最常用的视角是联想记忆。从 attention map 的角度来看,差距在于概率性质吗?并不是,linear attention 也可以使用恒正的 kernel,也可以有归一化。差距在于“全序列的可见性”吗?这个定义并不好,linear attention 也有QKT的 map,这不算是一种全序列可见性吗?
现在看来,更统一的视角可能在于联想记忆。softmax attention 实际上从未引入过常数以外的推理时空间开销,KV Cache 只是为了降低计算量采取的工程措施。如果愿意的话,attention 完全可以在一次计算中以纯 RNN 的方式计算;换言之,attention 与 RNN 相比,没有使用更多的空间,而是使用了更多的计算和对序列的多次阅读。其中的关键区别在于,写入联想记忆的 key 是否与当前的 query 相关。如果相关,就会导致各种平方的性质;如果不相关,就可以被线性化。
Comments NOTHING