Abstract: We consider the Top-$K$ selection problem, which aims to identify the $K$ largest elements of an array. Top-$K$ selection arises in many machine learning algorithms and often becomes a bottleneck on accelerators, which are optimized for dense matrix multiplications. To address this problem, \citet{chern2022tpuknnknearestneighbor} proposed a fast two-stage \textit{approximate} Top-$K$ algorithm: (i) partition the input array and select the top-$1$ element from each partition, (ii) sort this \textit{smaller subset} and return the top $K$ elements. In this paper, we consider a generalized version of this algorithm, where the first stage selects the top-$K'$ elements, for some $1 \leq K' \leq K$, from each partition. Our contributions are as follows: (i) we derive an expression for the expected recall of this generalized algorithm and show that choosing $K' > 1$ with fewer partitions in the first stage reduces the input size to the second stage more effectively while maintaining the same expected recall as the original algorithm, (ii) we derive a bound on the expected recall of the original algorithm in \citet{chern2022tpuknnknearestneighbor} that is provably tighter by a factor of $2$ than the one in that paper, and (iii) we implement our algorithm on Cloud TPUv5e and achieve roughly an order-of-magnitude speedup over the original algorithm without sacrificing recall on real-world tasks.
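The two-stage procedure described in this abstract can be illustrated with a minimal pure-Python sketch. It is not the authors' TPU implementation: the function names, the strided partitioning scheme (element $i$ goes to partition $i \bmod B$), and the recall helper are assumptions made for illustration only.

```python
import heapq


def exact_top_k(values, k):
    """Return the k largest values in descending order."""
    return heapq.nlargest(k, values)


def two_stage_top_k(values, k, num_partitions, k_prime=1):
    """Approximate top-k in two stages (illustrative sketch).

    Stage 1 keeps the top k_prime values of each partition.
    Stage 2 selects the top k values among the survivors.
    """
    if num_partitions * k_prime < k:
        raise ValueError("need num_partitions * k_prime >= k")
    survivors = []
    for p in range(num_partitions):
        # Assumption: strided partitioning, element i belongs to partition i % num_partitions.
        partition = values[p::num_partitions]
        survivors.extend(heapq.nlargest(k_prime, partition))
    return heapq.nlargest(k, survivors)


def recall(approx, exact):
    """Fraction of the exact top-k recovered (assumes distinct values)."""
    return len(set(approx) & set(exact)) / len(exact)


values = [15, 0, 14, 1]
exact = exact_top_k(values, 2)
# Two partitions with k_prime=1: 15 and 14 share a partition, so 14 is lost.
miss = two_stage_top_k(values, 2, num_partitions=2, k_prime=1)
# One partition with k_prime=2: both top elements survive stage 1.
hit = two_stage_top_k(values, 2, num_partitions=1, k_prime=2)
print(recall(miss, exact), recall(hit, exact))
```

In this toy input both configurations pass two elements to the second stage, but only the $K' = 2$ configuration recovers the full top-2. This single example is not evidence for the abstract's general claim; it only shows how a top element can be lost when it shares a partition with a larger one.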
Abstract: The autoregressive nature of conventional large language models (LLMs) inherently limits inference speed, as tokens are generated sequentially. While speculative and parallel decoding techniques attempt to mitigate this, they face limitations: they either rely on less accurate smaller models for generation or fail to fully leverage the base LLM's representations. We introduce a novel architecture, Tandem transformers, to address these issues. This architecture uniquely combines (1) a small autoregressive model and (2) a large model operating in block mode (processing multiple tokens simultaneously). The small model's predictive accuracy is substantially enhanced by granting it attention to the large model's richer representations. On the PaLM2 pretraining dataset, a Tandem of PaLM2-Bison and PaLM2-Gecko demonstrates a 3.3% improvement in next-token prediction accuracy over a standalone PaLM2-Gecko, offering a 1.16x speedup compared to a PaLM2-Otter model with comparable downstream performance. We further incorporate the Tandem model within the speculative decoding (SPEED) framework, where the large model validates tokens from the small model. This ensures that the Tandem of PaLM2-Bison and PaLM2-Gecko achieves substantial speedup (around 1.14x faster than using vanilla PaLM2-Gecko in SPEED) while maintaining identical downstream task accuracy.
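The draft-then-validate loop mentioned in this abstract can be sketched as generic greedy speculative decoding. This is not the Tandem architecture itself: the sketch omits the small model's attention to the large model's representations, and `target_next`, `draft_next`, and the two toy models are hypothetical stand-ins for real language models.

```python
def greedy_decode(next_token, prompt, num_new_tokens):
    """Plain greedy decoding: one model call per generated token."""
    tokens = list(prompt)
    for _ in range(num_new_tokens):
        tokens.append(next_token(tokens))
    return tokens


def speculative_greedy_decode(target_next, draft_next, prompt,
                              num_new_tokens, block_size=4):
    """Greedy speculative decoding (illustrative sketch).

    The draft model proposes block_size tokens; the target model keeps the
    longest prefix it agrees with and then contributes one token of its own.
    """
    tokens = list(prompt)
    goal = len(prompt) + num_new_tokens
    while len(tokens) < goal:
        # Draft phase: the small model proposes a block autoregressively.
        draft = []
        for _ in range(block_size):
            draft.append(draft_next(tokens + draft))
        # Verify phase: a real system would score the whole block in one
        # pass of the large model; here it is called per position for clarity.
        accepted = []
        for tok in draft:
            expected = target_next(tokens + accepted)
            accepted.append(expected)
            if tok != expected:
                break  # first mismatch: keep the target's token and stop
        else:
            # Every drafted token was accepted: add one extra target token.
            accepted.append(target_next(tokens + accepted))
        tokens.extend(accepted)
    return tokens[:goal]


# Hypothetical toy "models" over the vocabulary {0, ..., 10}.
def toy_target(seq):
    return (3 * sum(seq) + len(seq)) % 11


def toy_draft(seq):
    # Agrees with the target except at every third position.
    guess = toy_target(seq)
    return (guess + 1) % 11 if len(seq) % 3 == 0 else guess


prompt = [1, 2]
assert (speculative_greedy_decode(toy_target, toy_draft, prompt, 20)
        == greedy_decode(toy_target, prompt, 20))
```

Because every emitted token is the target model's own greedy choice, the output matches plain greedy decoding with the target model. This is the sense in which validation of this kind leaves accuracy unchanged, and any speedup comes from verifying a block at once rather than from changing what is generated.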