Although transformer architectures have achieved state-of-the-art performance across diverse domains, their quadratic computational complexity with respect to sequence length remains a significant bottleneck, particularly for latency-sensitive long-context applications. While recent linear-complexity alternatives are increasingly powerful, effectively training them from scratch is still resource-intensive. To overcome these limitations, we propose LAWCAT (Linear Attention with Convolution Across Time), a novel linearization framework designed to efficiently transfer the capabilities of pre-trained transformers into a performant linear attention architecture. LAWCAT integrates causal Conv1D layers to enhance local dependency modeling and employs normalized gated linear attention to improve generalization across varying context lengths. Our comprehensive evaluations demonstrate that, distilling Mistral-7B with only 1K-length sequences yields over 90\% passkey retrieval accuracy up to 22K tokens, significantly extending its effective context window. Similarly, Llama3.2-1B LAWCAT variant achieves competitive performance on S-NIAH 1\&2\&3 tasks (1K-8K context length) and BABILong benchmark (QA2\&QA3, 0K-16K context length), requiring less than 0.1\% pre-training tokens compared with pre-training models. Furthermore, LAWCAT exhibits faster prefill speeds than FlashAttention-2 for sequences exceeding 8K tokens. LAWCAT thus provides an efficient pathway to high-performance, long-context linear models suitable for edge deployment, reducing reliance on extensive long-sequence training data and computational resources. Code is released at: https://github.com/zeyuliu1037/LAWCAT


翻译:尽管Transformer架构已在多个领域取得了最先进的性能,但其随序列长度呈二次增长的计算复杂度仍是显著瓶颈,尤其在对延迟敏感的长上下文应用中。尽管近年来线性复杂度的替代方案日益强大,但从头开始有效训练这些模型仍需要大量资源。为克服这些限制,我们提出了LAWCAT(跨时间卷积的线性注意力),这是一种新颖的线性化框架,旨在高效地将预训练Transformer的能力迁移至高性能的线性注意力架构中。LAWCAT整合了因果Conv1D层以增强局部依赖建模,并采用归一化门控线性注意力来提升对不同上下文长度的泛化能力。我们的全面评估表明,仅使用1K长度的序列对Mistral-7B进行蒸馏,即可在高达22K令牌的范围内实现超过90%的密钥检索准确率,显著扩展了其有效上下文窗口。类似地,Llama3.2-1B的LAWCAT变体在S-NIAH 1&2&3任务(1K-8K上下文长度)和BABILong基准测试(QA2&QA3,0K-16K上下文长度)上取得了有竞争力的性能,且所需的预训练令牌量不到预训练模型的0.1%。此外,对于超过8K令牌的序列,LAWCAT的预填充速度比FlashAttention-2更快。因此,LAWCAT为适用于边缘部署的高性能、长上下文线性模型提供了一条高效路径,减少了对大量长序列训练数据和计算资源的依赖。代码发布于:https://github.com/zeyuliu1037/LAWCAT

0
下载
关闭预览

相关内容

【ICML2024】TIMEX++: 通过信息瓶颈学习时间序列解释
专知会员服务
17+阅读 · 2024年5月16日
【AAAI2024】KAM-CoT: 知识增强的多模态思维链推理
专知会员服务
45+阅读 · 2024年1月24日
【MIT】硬负样本的对比学习
专知
13+阅读 · 2020年10月15日
MNIST入门:贝叶斯方法
Python程序员
23+阅读 · 2017年7月3日
国家自然科学基金
46+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2015年12月31日
国家自然科学基金
6+阅读 · 2014年12月31日
国家自然科学基金
6+阅读 · 2014年12月31日
VIP会员
相关资讯
【MIT】硬负样本的对比学习
专知
13+阅读 · 2020年10月15日
MNIST入门:贝叶斯方法
Python程序员
23+阅读 · 2017年7月3日
相关基金
国家自然科学基金
46+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2015年12月31日
国家自然科学基金
6+阅读 · 2014年12月31日
国家自然科学基金
6+阅读 · 2014年12月31日
Top
微信扫码咨询专知VIP会员