As efficient alternatives to softmax attention, linear state-space models (SSMs) achieve constant memory and linear compute, but maintain only a lossy, fading summary of the past, which often leads to inferior performance on recall-oriented tasks. We propose Gated KalmaNet (GKA), a layer that reduces this gap by accounting for the full past when predicting the next token, while retaining SSM-style efficiency. GKA achieves this by solving an online ridge regression problem at test time, with constant memory and compute that is linear in the sequence length. Drawing inspiration from the Kalman filter, we solve the online ridge regression problem iteratively. A critical insight, however, is that the standard Kalman filter equations are numerically unstable in low-precision environments (such as bfloat16) and difficult to parallelize on modern hardware. We address both challenges through two key innovations: (1) an adaptive regularization strategy with input-dependent gating that controls the condition number of the ridge regression, ensuring numerical stability while balancing memory retention; and (2) the use of Chebyshev iteration in place of conventional iterative solvers, which we demonstrate to be more stable in low-precision settings. To further improve scalability, we develop a hardware-aware chunk-wise implementation of Chebyshev iteration, along with custom kernels for backpropagating through our adaptive regularization and gating mechanisms. Empirically, GKA shows strong language understanding capabilities on short-context tasks, outperforming existing SSM layers such as Mamba2, GLA, and Gated DeltaNet. On long context, GKA excels at real-world RAG and LongQA tasks with up to 128k tokens, achieving more than a $10\%$ relative improvement over other fading-memory baselines.
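Concretely, one standard way to write the test-time objective described above is the following (a sketch only: the symbols $k_i, v_i, q_t$ and the per-token ridge strength $\lambda_t$ are illustrative, and the paper's exact gated parameterization is not reproduced here):
\[
S_t \;=\; \arg\min_{S \in \mathbb{R}^{d \times d}} \; \sum_{i=1}^{t} \lVert S k_i - v_i \rVert_2^2 \;+\; \lambda_t \lVert S \rVert_F^2, \qquad o_t = S_t\, q_t.
\]
The normal equations give $S_t = \big(\sum_{i \le t} v_i k_i^\top\big)\big(\sum_{i \le t} k_i k_i^\top + \lambda_t I\big)^{-1}$. Both sums are $d \times d$ rank-one accumulations, so they occupy constant memory in the sequence length; GKA approximates the inverse iteratively rather than forming it explicitly.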
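Since the abstract singles out Chebyshev iteration as the solver, the following is a minimal NumPy sketch of that solver applied to the ridge normal equations. It is illustrative only: the function name `chebyshev_solve`, the spectral bounds obtained via `np.linalg.eigvalsh`, and the constant regularizer `lam` are assumptions made for this sketch; the actual layer uses an input-dependent gated regularizer, chunk-wise computation, and custom low-precision kernels, none of which are reproduced here.

```python
# Minimal sketch (NOT the paper's fused kernel): solve the ridge normal
# equations (K^T K + lam*I) S = K^T V with classical Chebyshev iteration.
import numpy as np

def chebyshev_solve(A, B, lmin, lmax, num_iters=20):
    """Chebyshev iteration for SPD A with eigenvalues in [lmin, lmax]."""
    d = (lmax + lmin) / 2.0            # center of the spectrum
    c = (lmax - lmin) / 2.0            # half-width of the spectrum
    X = np.zeros_like(B)
    R = B - A @ X                      # initial residual
    P = np.zeros_like(B)
    alpha = 0.0
    for k in range(num_iters):
        if k == 0:
            P = R.copy()
            alpha = 1.0 / d
        else:
            if k == 1:
                beta = 0.5 * (c * alpha) ** 2
            else:
                beta = (c * alpha / 2.0) ** 2
            alpha = 1.0 / (d - beta / alpha)
            P = R + beta * P
        X = X + alpha * P
        R = R - alpha * (A @ P)        # cheap residual update
    return X

rng = np.random.default_rng(0)
t, dim = 256, 64
K = rng.standard_normal((t, dim)) / np.sqrt(dim)   # keys seen so far
V = rng.standard_normal((t, dim))                  # values seen so far
lam = 1.0   # constant here; input-dependent and gated in the actual layer

# min_S ||K S - V||_F^2 + lam ||S||_F^2, i.e. (K^T K + lam*I) S = K^T V.
A = K.T @ K + lam * np.eye(dim)
B = K.T @ V
# The regularizer lower-bounds the spectrum (eigs(A) >= lam), which keeps
# the system well-conditioned -- the role of the abstract's point (1).
eigs = np.linalg.eigvalsh(A)
S = chebyshev_solve(A, B, eigs[0], eigs[-1])
print(np.max(np.abs(S - np.linalg.solve(A, B))))   # tiny vs. direct solve
```

One property visible even in this sketch: unlike conjugate gradient, Chebyshev iteration needs only scalar spectral bounds rather than per-step inner products, so each step is free of global reductions, which is what makes a chunk-wise, hardware-friendly implementation natural.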