We explore Multi-Head FFN (MH-FFN) as a replacement of FFN in the Transformer architecture, motivated by the structural similarity between single-head attention and FFN. While multi-head mechanisms enhance expressivity in attention, naively applying them to FFNs faces two challenges: memory consumption scaling with the head count, and an imbalanced ratio between the growing intermediate size and the fixed head dimension as models scale, which degrades scalability and expressive power. To address these challenges, we propose Flash Multi-Head FFN (FlashMHF), with two key innovations: an I/O-aware fused kernel computing outputs online in SRAM akin to FlashAttention, and a design using dynamically weighted parallel sub-networks to maintain a balanced ratio between intermediate and head dimensions. Validated on models from 128M to 1.3B parameters, FlashMHF consistently improves perplexity and downstream task accuracy over SwiGLU FFNs, while reducing peak memory usage by 3-5x and accelerating inference by up to 1.08x. Our work establishes the multi-head design as a superior architectural principle for FFNs, presenting FlashMHF as a powerful, efficient, and scalable alternative to FFNs in Transformers.


翻译:我们探索了多头前馈网络(MH-FFN)作为 Transformer 架构中 FFN 的替代方案,其动机在于单头注意力机制与前馈网络之间的结构相似性。虽然多头机制增强了注意力模块的表达能力,但将其直接应用于 FFN 面临两个挑战:内存消耗随头数线性增长,以及模型规模扩大时不断增长的中间维度与固定头维度之间的比例失衡,这降低了可扩展性和表达能力。为解决这些挑战,我们提出了 Flash 多头前馈网络(FlashMHF),其包含两项关键创新:一种类似于 FlashAttention 的 I/O 感知融合内核,可在 SRAM 中在线计算输出;以及一种采用动态加权并行子网络的设计,以保持中间维度与头维度之间的平衡比例。在参数规模从 1.28 亿到 13 亿的模型上进行验证,FlashMHF 相比 SwiGLU FFN 在困惑度和下游任务准确率上持续提升,同时将峰值内存使用降低 3-5 倍,推理速度最高提升 1.08 倍。我们的工作确立了多头设计作为 FFN 的优越架构原则,并将 FlashMHF 呈现为 Transformer 中 FFN 的一种强大、高效且可扩展的替代方案。

0
下载
关闭预览

相关内容

FlowQA: Grasping Flow in History for Conversational Machine Comprehension
专知会员服务
34+阅读 · 2019年10月18日
Stabilizing Transformers for Reinforcement Learning
专知会员服务
60+阅读 · 2019年10月17日
Transferring Knowledge across Learning Processes
CreateAMind
29+阅读 · 2019年5月18日
Unsupervised Learning via Meta-Learning
CreateAMind
43+阅读 · 2019年1月3日
STRCF for Visual Object Tracking
统计学习与视觉计算组
15+阅读 · 2018年5月29日
Focal Loss for Dense Object Detection
统计学习与视觉计算组
12+阅读 · 2018年3月15日
IJCAI | Cascade Dynamics Modeling with Attention-based RNN
KingsGarden
13+阅读 · 2017年7月16日
国家自然科学基金
2+阅读 · 2015年12月31日
国家自然科学基金
3+阅读 · 2015年12月31日
国家自然科学基金
2+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2014年12月31日
国家自然科学基金
2+阅读 · 2014年12月31日
Arxiv
10+阅读 · 2021年12月9日
Arxiv
15+阅读 · 2019年11月26日
VIP会员
相关资讯
Transferring Knowledge across Learning Processes
CreateAMind
29+阅读 · 2019年5月18日
Unsupervised Learning via Meta-Learning
CreateAMind
43+阅读 · 2019年1月3日
STRCF for Visual Object Tracking
统计学习与视觉计算组
15+阅读 · 2018年5月29日
Focal Loss for Dense Object Detection
统计学习与视觉计算组
12+阅读 · 2018年3月15日
IJCAI | Cascade Dynamics Modeling with Attention-based RNN
KingsGarden
13+阅读 · 2017年7月16日
相关论文
Arxiv
10+阅读 · 2021年12月9日
Arxiv
15+阅读 · 2019年11月26日
相关基金
国家自然科学基金
2+阅读 · 2015年12月31日
国家自然科学基金
3+阅读 · 2015年12月31日
国家自然科学基金
2+阅读 · 2015年12月31日
国家自然科学基金
0+阅读 · 2014年12月31日
国家自然科学基金
2+阅读 · 2014年12月31日
Top
微信扫码咨询专知VIP会员