Transformer language models can generate strikingly natural text by modeling language as a sequence of tokens. Yet, by relying primarily on surface-level co-occurrence statistics, they fail to form globally consistent latent representations of entities and events, a gap that contributes to brittleness in relational direction (e.g., the reversal curse), contextualization errors, and data inefficiency. By contrast, cognitive science shows that human comprehension converts the incoming linguistic stream into compact, event-like representations that persist in memory while verbatim form is short-lived. Motivated by this view, we introduce the Thought Gestalt (TG) model, a recurrent Transformer that models language at two levels of abstraction: tokens and sentence-level "thought" states. TG generates tokens one sentence at a time while cross-attending to a memory of prior sentence representations. In TG, token and sentence representations are produced by the same set of model parameters and trained with a single objective, next-token cross-entropy: by retaining the computation graph of sentence representations written to memory, gradients from future token losses flow backward through cross-attention to optimize the parameters that generated earlier sentence vectors. In scaling experiments, TG consistently improves efficiency over matched GPT-2 runs and other baselines, with scaling fits indicating that GPT-2 requires ~5-8% more data and ~33-42% more parameters to match TG's loss. TG also reduces errors in relational-direction generalization on a father-son reversal-curse probe.
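To make the two-level training scheme concrete, the following is a minimal PyTorch sketch of the idea described above: tokens of each sentence cross-attend to a memory of prior sentence vectors, the memory is extended without detaching, and a single next-token cross-entropy loss trains everything. It is an illustration under assumptions, not the paper's implementation; names such as ThoughtGestaltSketch and the mean-pooling used to form sentence vectors are hypothetical, and the actual layer sharing, pooling, and memory format may differ.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ThoughtGestaltSketch(nn.Module):
    """Toy two-level recurrent Transformer: token states + sentence 'thought' memory."""

    def __init__(self, vocab_size=1000, d_model=128, n_heads=4, n_layers=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        # A single shared decoder stack produces both token states and, via pooling,
        # the sentence-level "thought" vector (one parameter set, one objective).
        layer = nn.TransformerDecoderLayer(
            d_model, n_heads, dim_feedforward=4 * d_model, batch_first=True
        )
        self.decoder = nn.TransformerDecoder(layer, n_layers)
        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, sentences):
        """sentences: list of LongTensors, each of shape (batch, seq_len)."""
        batch = sentences[0].size(0)
        d_model = self.embed.embedding_dim
        # Memory starts with a single zero "thought" so cross-attention is well-defined.
        memory = torch.zeros(batch, 1, d_model)
        total_loss, n_tokens = 0.0, 0
        for sent in sentences:
            x = self.embed(sent)
            causal = nn.Transformer.generate_square_subsequent_mask(sent.size(1))
            # Token states of the current sentence cross-attend to prior sentence vectors.
            h = self.decoder(x, memory, tgt_mask=causal)
            logits = self.lm_head(h[:, :-1])
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                sent[:, 1:].reshape(-1),
                reduction="sum",
            )
            total_loss = total_loss + loss
            n_tokens += sent[:, 1:].numel()
            # Append the sentence vector (here: mean-pooled token states) WITHOUT
            # detaching, so losses on later sentences backpropagate through
            # cross-attention into the parameters that produced this earlier "thought".
            memory = torch.cat([memory, h.mean(dim=1, keepdim=True)], dim=1)
        return total_loss / n_tokens

# Usage: three sentences of 12 tokens each, batch of 2.
model = ThoughtGestaltSketch()
sents = [torch.randint(0, 1000, (2, 12)) for _ in range(3)]
loss = model(sents)
loss.backward()  # gradients reach earlier sentence vectors via the retained graph
```

The key point the sketch illustrates is the absence of any detach on the memory: the only training signal is next-token cross-entropy, yet earlier sentence vectors are optimized because later losses flow back through cross-attention.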