Optimizer as Associative Memory

chensy 发布于 6 天前 9 次阅读


这是摘自Nested Learning中的片段。尽管整篇文章依旧很难评价,并且长篇的理论说明中不乏左右脑互搏和理论滥用的成分,但这个推导还是有一些有趣之处的。只不过这个结果应该如何理解,似乎还有些一言难尽。

本篇中会尽量采用严格的数学记号(我也有这一天),不加下标的表示变量,加了下标的表示具体的值。

Gradient Descent / SGD

因为容易混淆,所以先重新陈述一下DL中梯度下降的定义。数学上的概念很简单,对于一个f(w)f(w),为了找到它的极小值点,我们沿着梯度的相反方向不断地更新当前持有的坐标wiw_{i},使其最终到达极小值点。对于DL而言实际上也是一样的,只不过这里的ff替换成了损失函数,并且要求出损失函数当前的值,我们理论上需要遍历全部的数据。换句话说,在理论的状态下,我们有一个数据区域xDxx\in D_{x},并且对每一个xx都有一个对应的真值yy,我们希望通过一个参数化的函数去逼近真实的映射,即

fw(x)y.f_{w}(x)\approx y.

为了实现这个目标,我们定义损失函数L\mathcal{L}

L(w)=Dxfw(x)y2dx.\mathcal{L}(w)=\int_{D_{x}}||f_{w}(x)-y||^2dx.

这个函数对于ww是可微的,因此我们理论上能够对其进行梯度下降。但是,在实际操作中,我们往往没有能力遍历完整的数据分布去评估这个函数,甚至于我们根本不知道完整的数据分布;因此,我们使用少部分的(x,y)(x,y)对构成的子集去对L\mathcal{L}进行替代。极端的情况下(online learning,单样本),我们仅使用一个样本对(xi,yi)(x_{i}, y_{i})去构成近似的评估LiL\mathcal{L}_{i}\approx \mathcal{L}。我们在这里将SGD中的损失函数视作对理想的损失函数的近似,这可能不完全正确,不过不影响后面的理解。

从这个角度来看,参数(权重)的更新步数和在线学习中样本到达的步数完全是无关的。我们完全可以使用同一个样本估计损失函数,然后多次更新权重;也完全可以积累多个样本,再使用它们作为估计更新一次权重(实际上就是batch形式)。但在下面的推导中,我们默认每到达一个样本就更新一次权重,从而使它们使用完全相同的下标。这不是必要的,只是为了和linear attention保持一致。

From SGD to Associative Memory

现在假设我们有参数WW,它完成的运算是yi=Wxiy_{i}=Wx_{i},现代网络中的大部分参数都是这样。按照每到达一个样本就更新一次权重的SGD规则,我们写出它的更新规则:

ΔW=ηLi(W)W.\Delta W= -\eta\frac{\partial \mathcal{L}_{i}(W)}{\partial W}.

其中

Li(W)W=L(W,xi)W=L(W,xi)yiyiW=Li(W)yixiT.\frac{\partial \mathcal{L}_{i}(W)}{\partial W}=\frac{\partial L(W, x_{i})}{\partial W}=\frac{\partial L(W, x_{i})}{\partial y_{i}}\cdot \frac{\partial y_{i}}{\partial W}=\frac{\partial \mathcal{L}_{i}(W)}{\partial y_{i}}\cdot x_{i}^T.

LL是我们真正计算时使用的损失函数。代入具体的值WiW_i,我们得到

Wi+1=WiηiLi(Wi)yixiT.W_{i+1}=W_{i}-\eta_{i}\frac{\partial \mathcal{L}_{i}(W_{i})}{\partial y_{i}}\cdot x_{i}^T.

注意到它有类似Hebb的形式,我们再次定义

Li(W)=Wxi,Li(Wi)yi.\mathcal{L}_{i}'(W)=\langle Wx_{i},\frac{\partial \mathcal{L}_{i}(W_{i})}{\partial y_{i}} \rangle.

这是一个线性函数。我们得到

Li(W)W=Li(Wi)yixiT=const.\frac{\partial \mathcal{L}_{i}'(W)}{\partial W}=\frac{\partial \mathcal{L}_{i}(W_{i})}{\partial y_{i}}\cdot x_{i}^T=\text{const}.

换言之,

Li(W)W=Li(W)WW=Wi=const.\frac{\partial \mathcal{L}_{i}'(W)}{\partial W}=\left. \frac{\partial \mathcal{L}_{i}(W)}{\partial W} \right|_{W=W_{i}}=\text{const}.

或者说,我们在更新的时刻找到了一个函数Li\mathcal{L}_{i}',这个函数的梯度和原始函数的梯度在更新的这一点恰好相同。他们并不是处处相同,假如两个函数的梯度处处相同,那么他们之间应当只差一个常数(应该是这样吧)。从优化的角度来看,写成Li(W)W\frac{\partial \mathcal{L}_{i}'(W)}{\partial W}的形式可以说是完全没有什么意义,因为这个函数每一步都在变,每一步求的是不同函数的梯度。但是从实际的角度考虑,单样本的SGD本身就不指望Li\mathcal{L}_{i}和理论的损失函数有多接近,所以其实也可以看成我们在优化Li\mathcal{L}_{i}'

而这个新的损失是什么呢?注意到

Li(W)=Wxi,Li(Wi)yi=Wxi,ηΔyiWxi,Δyi=yi,Δyi.\mathcal{L}_{i}'(W)=\langle Wx_{i},\frac{\partial \mathcal{L}_{i}(W_{i})}{\partial y_{i}} \rangle=\langle Wx_{i},-\eta\Delta y_{i}\rangle\propto -\langle Wx_{i},\Delta y_{i}\rangle=-\langle y_{i},\Delta y_{i}\rangle.

考虑到Hebb规则一般只有对齐方向的作用,WW的更新实际上会尝试将当前线性变换的输出方向与梯度指导下输出的更新方向保持一致。至于这是什么含义,就不清楚了。但从这个意义上,我们可以将SGD也视为一个联想记忆:一个将模型的输出映射到它的梯度指导更新值的记忆。

Delta Gradient Descent

在此基础上可以进一步引入delta规则。整体的思路就是将损失函数替换为类似Hebb的形式之后,再次把点积损失替换成均方损失,从而导出delta规则。这也是很常见的推导方式了。

最终,我们应该得到

Wi+1=Wi(IηixixiT)ηiLi(Wi)yixiT.W_{i+1}=W_{i}(I-\eta_ix_ix_i^T)-\eta_{i}\frac{\partial \mathcal{L}_{i}(W_{i})}{\partial y_{i}}\cdot x_{i}^T.

这个规则用于替代传统的SGD,作为一种优化的手段。至于有没有效可能就见仁见智了,理论上似乎看不出什么东西。

此作者没有提供个人介绍。
最后更新于 2026-01-21