Abstract: In decoder-only (causal) transformers, the computation graph created by causal masking routes information through both direct-path attention and indirect paths formed by intermediate tokens. We call the indirect paths between a pair of tokens its runways. We argue that certain failure modes of causal transformers, observed in a growing body of recent work, are likely exacerbated by a misalignment between these two modes of information propagation. We formalize runway cascade as the phenomenon whereby this misalignment causes redundant and irrelevant information to cascade into token representations despite adequately learned attention patterns. As a solution, we propose runway-aware rewiring, a more explicit way of incorporating runway context directly into each token's direct-path attention. This mechanism rewires the attention pattern of each token based on a summary of its runway landscape, making it aware of accumulating representational influences and allowing for more balanced information propagation. Our method introduces no additional parameters and can be integrated seamlessly into the standard attention mechanism. Empirically, our rewired transformer yields steady improvements in general language modeling, as well as noticeably stronger information retrieval and extrapolation abilities, compared to standard transformers.
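To make the distinction between direct and indirect paths concrete, the following minimal sketch counts both in the graph induced by a causal mask. It is an illustration under stated assumptions, not the paper's formal definition of a runway: it treats every earlier token as connected to every later token, ignores self-attention and attention weights, and the helper names `causal_adjacency` and `path_counts` are hypothetical.

```python
import numpy as np

def causal_adjacency(n):
    # A[i, j] = 1 when token i can attend to a strictly earlier token j (j < i).
    # Assumption: self-attention edges are left out to keep the path count finite.
    return np.tril(np.ones((n, n), dtype=np.int64), k=-1)

def path_counts(n, hops):
    # Entry [i, j] is the number of paths from token j to token i
    # that use exactly `hops` edges of the causal graph.
    return np.linalg.matrix_power(causal_adjacency(n), hops)

n = 6
direct = path_counts(n, 1)                                # single-edge (direct) paths
indirect = sum(path_counts(n, h) for h in range(2, n))    # paths through intermediate tokens

print(direct[5, 0], indirect[5, 0])
```

Under these assumptions, a pair of tokens has exactly one direct path, while the number of indirect paths grows exponentially with the distance between them (2^(d-1) - 1 for tokens d positions apart), which gives a rough sense of how much more routing capacity the indirect paths carry than the direct one.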
Abstract: Transformers, and more specifically decoder-only transformers, dominate modern LLM architectures. While they have been shown to work exceptionally well, they are not without issues, which result in surprising failure modes and predictably asymmetric performance degradation. This article studies many of these observed failure modes through the lens of graph neural network (GNN) theory. We first make the case that much of deep learning, including transformers, is about learnable information mixing and propagation. This makes the study of model failure modes a study of bottlenecks in information propagation, which leads naturally to GNN theory, where there is already a rich literature on information propagation bottlenecks and theoretical failure modes of models. We then make the case that many issues faced by GNNs are also experienced by transformers. In addition, we analyze how the causal nature of decoder-only transformers creates interesting geometric properties in information propagation, resulting in predictable and potentially devastating failure modes. Finally, we observe that existing solutions in transformer research tend to be ad hoc and driven by intuition rather than grounded in theory. We therefore unify many such solutions under a more theoretical perspective, providing insight into why they work, what problem they actually solve, and how they can be further improved to target specific failure modes of transformers. Overall, this article is an attempt to bridge the gap between the failure modes observed in transformers and the general lack of theoretical understanding of them.