As efficient alternatives to softmax Attention, linear State-Space Models (SSMs) achieve constant memory and linear compute, but maintain only a lossy, fading summary of the past, often leading to inferior performance in recall-oriented tasks. We propose Gated KalmaNet (GKA), a layer that accounts for the full past while maintaining SSM-style efficiency. We ground our approach in the Kalman Filter (KF) framework, which provides a principled solution for optimal inference in dynamical systems. We show that several existing SSM layers (DeltaNet, Gated DeltaNet, and Kimi Delta Attention) are approximations to the KF recurrence that assume identity error covariance, thereby ignoring how past measurements (keys and values) should optimally influence state updates. In contrast, GKA computes the exact Kalman gain by maintaining the full error covariance. Under a steady-state assumption that enables parallelization, this reduces to solving an online ridge regression problem with constant memory and linear compute cost. A critical insight is that standard KF equations are numerically unstable in low-precision environments (like bfloat16) and hard to parallelize on modern hardware. We address this through: (1) adaptive regularization with input-dependent gating to control the condition number of the ridge regression for numerical stability, and (2) Chebyshev Iteration, which we show is more stable than conventional iterative solvers in low-precision settings. We further develop hardware-aware chunk-wise kernels to enable efficient training. Empirically, GKA outperforms existing SSM layers (like Mamba2 and Gated DeltaNet) on short-context tasks and achieves more than 10\% relative improvement on long-context RAG and LongQA tasks up to 128k tokens.


翻译:作为softmax注意力的高效替代方案,线性状态空间模型(SSMs)实现了恒定内存和线性计算,但仅维持对过去信息的损失性衰减摘要,这通常在面向召回的任务中导致性能下降。我们提出了门控卡尔曼网络(GKA),这是一种在保持SSM风格效率的同时完整考虑过去信息的层。我们将方法建立在卡尔曼滤波器(KF)框架上,该框架为动态系统中的最优推断提供了理论解决方案。我们证明,几种现有的SSM层(DeltaNet、门控DeltaNet和Kimi Delta注意力)都是KF递推的近似,它们假设误差协方差为单位矩阵,从而忽略了过去测量值(键和值)应如何最优地影响状态更新。相比之下,GKA通过维护完整的误差协方差来计算精确的卡尔曼增益。在支持并行化的稳态假设下,这简化为求解一个具有恒定内存和线性计算成本的在线岭回归问题。一个关键见解是,标准KF方程在低精度环境(如bfloat16)中数值不稳定,且难以在现代硬件上并行化。我们通过以下方式解决此问题:(1)采用输入依赖的门控自适应正则化来控制岭回归的条件数以确保数值稳定性;(2)切比雪夫迭代法,我们证明其在低精度设置下比传统迭代求解器更稳定。我们进一步开发了硬件感知的分块内核以实现高效训练。实验表明,GKA在短上下文任务上优于现有SSM层(如Mamba2和门控DeltaNet),并在长达128k词元的长上下文RAG和LongQA任务上实现了超过10%的相对性能提升。

0
下载
关闭预览

相关内容

FlowQA: Grasping Flow in History for Conversational Machine Comprehension
专知会员服务
34+阅读 · 2019年10月18日
Keras François Chollet 《Deep Learning with Python 》, 386页pdf
专知会员服务
163+阅读 · 2019年10月12日
Transferring Knowledge across Learning Processes
CreateAMind
29+阅读 · 2019年5月18日
Unsupervised Learning via Meta-Learning
CreateAMind
43+阅读 · 2019年1月3日
STRCF for Visual Object Tracking
统计学习与视觉计算组
15+阅读 · 2018年5月29日
Focal Loss for Dense Object Detection
统计学习与视觉计算组
12+阅读 · 2018年3月15日
IJCAI | Cascade Dynamics Modeling with Attention-based RNN
KingsGarden
13+阅读 · 2017年7月16日
国家自然科学基金
13+阅读 · 2017年12月31日
国家自然科学基金
2+阅读 · 2015年12月31日
国家自然科学基金
2+阅读 · 2015年12月31日
国家自然科学基金
3+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2014年12月31日
Arxiv
13+阅读 · 2023年2月7日
VIP会员
相关资讯
Transferring Knowledge across Learning Processes
CreateAMind
29+阅读 · 2019年5月18日
Unsupervised Learning via Meta-Learning
CreateAMind
43+阅读 · 2019年1月3日
STRCF for Visual Object Tracking
统计学习与视觉计算组
15+阅读 · 2018年5月29日
Focal Loss for Dense Object Detection
统计学习与视觉计算组
12+阅读 · 2018年3月15日
IJCAI | Cascade Dynamics Modeling with Attention-based RNN
KingsGarden
13+阅读 · 2017年7月16日
相关基金
国家自然科学基金
13+阅读 · 2017年12月31日
国家自然科学基金
2+阅读 · 2015年12月31日
国家自然科学基金
2+阅读 · 2015年12月31日
国家自然科学基金
3+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2014年12月31日
Top
微信扫码咨询专知VIP会员