Recent advances in reasoning domains with neural networks have primarily been enabled by a training recipe that optimizes Large Language Models, previously trained to predict the next token in a sequence, with reinforcement learning algorithms. We introduce a framework to study the success of this paradigm, and we theoretically expose the optimization mechanisms by which reinforcement learning improves over next-token prediction in this setting. We study learning from mixture distributions of short and long ``chain-of-thought'' sequences encoding a single task. In particular, when the task consists of predicting the parity of $d$ bits and long sequences are rare, we show how reinforcement learning after next-token prediction enables autoregressive transformers to generalize, whereas next-token prediction alone requires prohibitive statistical or computational resources to do so. We further explain how reinforcement learning leverages increased test-time computation, manifested in longer responses, to facilitate this learning process. In a simplified setting, we theoretically prove that autoregressive linear models following this training recipe can efficiently learn to predict the parity of $d$ bits, provided the proportion of long demonstrations in the data mixture is not exponentially small in the input dimension $d$. Finally, we demonstrate the same phenomena in other settings, including the post-training of Llama-series models on mixture variations of common mathematical reasoning benchmarks.
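For concreteness, the following is a minimal sketch (our own illustration, not the paper's code) of the kind of mixture distribution described above: each demonstration encodes the parity of $d$ bits, and with a small probability it additionally spells out the running prefix parities, i.e., a long chain of thought. The demonstration format and the names \texttt{sample\_parity\_demo} and \texttt{p\_long} are hypothetical.
\begin{verbatim}
import random

def sample_parity_demo(d, p_long, rng):
    """Sample one demonstration for the d-bit parity task.

    With probability p_long the demonstration is long: it spells out
    the running parity of every prefix before the final answer.
    Otherwise it is short: the input bits are followed directly by
    the answer, with no intermediate chain of thought.
    """
    bits = [rng.randint(0, 1) for _ in range(d)]
    answer = sum(bits) % 2
    prompt = " ".join(map(str, bits))
    if rng.random() < p_long:
        partial, steps = 0, []
        for b in bits:
            partial ^= b          # running parity of the prefix
            steps.append(str(partial))
        return f"{prompt} -> {' '.join(steps)} -> {answer}"
    return f"{prompt} -> {answer}"

rng = random.Random(0)
for _ in range(5):
    print(sample_parity_demo(d=6, p_long=0.3, rng=rng))
\end{verbatim}
Under this assumed format, short demonstrations reveal only the input-output map, while the rare long demonstrations expose the step-by-step computation that the paper argues reinforcement learning can amplify at test time.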