Linear RNNs with gating recently demonstrated competitive performance compared to Transformers in language modeling. Although their linear compute scaling in sequence length offers theoretical runtime advantages over Transformers, realizing these benefits in practice requires optimized custom kernels, as Transformers rely on the highly efficient Flash Attention kernels (Dao, 2024). Leveraging the chunkwise-parallel formulation of linear RNNs, Flash Linear Attention (FLA) (Yang & Zhang, 2024) shows that linear RNN kernels are faster than Flash Attention, by parallelizing over chunks of the input sequence. However, since the chunk size of FLA is limited, many intermediate states must be materialized in GPU memory. This leads to low arithmetic intensity and causes high memory consumption and IO cost, especially for long-context pre-training. In this work, we present Tiled Flash Linear Attention (TFLA), a novel kernel algorithm for linear RNNs, that enables arbitrary large chunk sizes and high arithmetic intensity by introducing an additional level of sequence parallelization within each chunk. First, we apply TFLA to the xLSTM with matrix memory, the mLSTM (Beck et al., 2024). Second, we propose an mLSTM variant with sigmoid input gate and reduced computation for even faster kernel runtimes at equal language modeling performance. In our speed benchmarks, we show that our new mLSTM kernels based on TFLA outperform highly optimized Flash Attention, Linear Attention and Mamba kernels, setting a new state of the art for efficient long-context sequence modeling primitives.


翻译:具有门控机制的线性循环神经网络近期在语言建模任务中展现出与Transformer相竞争的性能。尽管其在序列长度上的线性计算复杂度相较于Transformer具有理论上的运行时优势,但在实践中实现这些优势需要优化的定制内核,因为Transformer依赖于高度高效的闪存注意力内核(Dao, 2024)。基于线性循环神经网络的块式并行化形式,闪存线性注意力(FLA)(Yang & Zhang, 2024)通过并行处理输入序列的块,证明了线性循环神经网络内核比闪存注意力更快。然而,由于FLA的块大小受限,许多中间状态必须存储在GPU内存中。这导致算术强度较低,并引发高内存消耗和输入/输出成本,尤其是在长上下文预训练场景中。在本工作中,我们提出了分块式闪存线性注意力(TFLA),一种用于线性循环神经网络的新型内核算法。该算法通过在每个块内引入额外的序列并行化层级,实现了任意大的块尺寸和高算术强度。首先,我们将TFLA应用于具有矩阵记忆的xLSTM,即mLSTM(Beck et al., 2024)。其次,我们提出了一种采用Sigmoid输入门并减少计算的mLSTM变体,在保持同等语言建模性能的同时实现更快的内核运行速度。在我们的速度基准测试中,我们展示了基于TFLA的新型mLSTM内核性能优于高度优化的闪存注意力、线性注意力及Mamba内核,为高效长上下文序列建模原语树立了新的技术标杆。

0
下载
关闭预览

相关内容

Python图像处理,366页pdf,Image Operators Image Processing in Python
论文浅尝 | GEOM-GCN: Geometric Graph Convolutional Networks
开放知识图谱
14+阅读 · 2020年4月8日
国家自然科学基金
1+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2015年12月31日
国家自然科学基金
4+阅读 · 2015年12月31日
国家自然科学基金
3+阅读 · 2015年12月31日
VIP会员
相关基金
国家自然科学基金
1+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2015年12月31日
国家自然科学基金
4+阅读 · 2015年12月31日
国家自然科学基金
3+阅读 · 2015年12月31日
Top
微信扫码咨询专知VIP会员