We develop an algorithm that, given a trained transformer model $\mathcal{M}$, a string of tokens $s$ of length $n_{fix}$, and an integer $n_{free}$, can generate a mathematical proof that $\mathcal{M}$ is ``overwhelmed'' by $s$, in time and space $\widetilde{O}(n_{fix}^2 + n_{free}^3)$. We say that $\mathcal{M}$ is ``overwhelmed'' by $s$ when the output of the model on $s$ followed by any additional string $t$, written $\mathcal{M}(s + t)$, is completely insensitive to the value of $t$ whenever the length of $t$ is at most $n_{free}$. Along the way, we prove a particularly strong worst-case form of ``over-squashing'', which we use to bound the model's behavior. Our technique uses computer-aided proofs to establish this type of operationally relevant guarantee about transformer models. We empirically test our algorithm on a single-layer transformer complete with an attention head, layer-norm, MLP/ReLU layers, and RoPE positional encoding. We believe that this work is a stepping stone towards the difficult task of obtaining useful guarantees for trained transformer models.
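
The definition of ``overwhelmed'' can be written as a single condition. The display below is a sketch of one possible formalization: the symbols $\Sigma$ (the token vocabulary) and $|t|$ (the length of $t$) are notation assumed here for illustration, and reading ``completely insensitive'' as exact equality of outputs is likewise an assumption.
\[
\mathcal{M} \text{ is overwhelmed by } s
\quad\Longleftrightarrow\quad
\mathcal{M}(s + t) = \mathcal{M}(s + t')
\quad \text{for all } t, t' \in \Sigma^{*} \text{ with } |t|, |t'| \leq n_{free}.
\]
Under this reading, checking the condition by enumeration would take on the order of $|\Sigma|^{n_{free}}$ model evaluations when $|\Sigma| \geq 2$, which is the exponential cost that a proof produced in $\widetilde{O}(n_{fix}^2 + n_{free}^3)$ time and space avoids.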