Abstract: There has been a flurry of activity around using pretrained diffusion models as informed data priors for solving inverse problems, and more generally around steering these models using reward models. Training-free methods like diffusion posterior sampling (DPS) and its many variants offer flexible heuristic algorithms for these tasks, but when the reward is not informative enough, e.g., in hard inverse problems with low signal-to-noise ratio, these techniques veer off the data manifold and fail to produce realistic outputs. In this work, we devise a simple wrapper, ReGuidance, for boosting both the sample realism and the reward achieved by these methods. Given a candidate solution $\hat{x}$ produced by an algorithm of the user's choice, we propose inverting the solution by running the unconditional probability flow ODE in reverse starting from $\hat{x}$, and then using the resulting latent as an initialization for DPS. We evaluate our wrapper on hard inverse problems like large box in-painting and super-resolution with high upscaling factors. Whereas state-of-the-art baselines visibly fail, we find that applying our wrapper on top of these baselines significantly boosts sample quality and measurement consistency. We complement these findings with theory proving that on certain multimodal data distributions, ReGuidance simultaneously boosts the reward and brings the candidate solution closer to the data manifold. To our knowledge, this constitutes the first rigorous algorithmic guarantee for DPS.
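The two-step recipe in the abstract above (invert the candidate $\hat{x}$ with the unconditional probability flow ODE, then rerun DPS from the resulting latent) can be sketched on a one-dimensional toy problem. Everything concrete below is an assumption made for illustration and is not taken from the paper: the two-mode Gaussian mixture prior with a closed-form score, the noise schedule $\sigma(t) = t$, the Euler discretization, a deterministic ODE form of DPS, the scalar measurement model $y = x + \text{noise}$, and all function and constant names.

```python
import math

# Hypothetical toy setup (not from the paper): 1-D data prior with two modes.
MODES = (-2.0, 2.0)   # centers of two equally weighted Gaussian modes
S2 = 0.25             # within-mode variance
T, STEPS = 10.0, 4000 # final noise level and number of Euler steps (assumed)
DT = T / STEPS

def score(x, t):
    """Exact score of p_t, the toy prior convolved with N(0, t^2)."""
    v = S2 + t * t
    logits = [-(x - m) ** 2 / (2 * v) for m in MODES]
    top = max(logits)  # subtract the max for numerical stability
    w = [math.exp(l - top) for l in logits]
    return sum(wi * (m - x) for wi, m in zip(w, MODES)) / (v * sum(w))

def denoise(x, t):
    """Tweedie estimate of E[x_0 | x_t = x] for the schedule sigma(t) = t."""
    return x + t * t * score(x, t)

def invert(x_hat):
    """Integrate the unconditional probability flow ODE dx/dt = -t * score
    from data (t = 0) to noise (t = T), starting at the candidate x_hat."""
    x = x_hat
    for i in range(STEPS):
        t = i * DT
        x += DT * (-t * score(x, t))
    return x

def sample(z, y=None, sigma_y=1.0, eps=1e-4):
    """Integrate the ODE from t = T back to t = 0 starting at latent z.
    If y is given, add a DPS-style guidance term: the gradient in x_t of
    -(y - denoise(x_t))^2 / (2 * sigma_y^2), via a finite-difference Jacobian."""
    x = z
    for i in range(STEPS, 0, -1):
        t = i * DT
        drift = score(x, t)
        if y is not None:
            jac = (denoise(x + eps, t) - denoise(x - eps, t)) / (2 * eps)
            drift += (y - denoise(x, t)) / sigma_y ** 2 * jac
        x -= DT * (-t * drift)
    return x

x_hat = 1.0                      # candidate lying between the two modes
z = invert(x_hat)                # step 1: latent obtained by inversion
x_roundtrip = sample(z)          # unguided: should land near x_hat again
x_reguided = sample(z, y=2.0)    # step 2: guided sampling from that latent
```

In this sketch the unguided round trip serves as a sanity check on the inversion, and the guided run shows how the latent from step 1 is reused as the starting point. The toy says nothing about the image-scale experiments or the guarantees described in the abstract.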
Abstract: Given a predictor and a loss function, how well can we predict the loss that the predictor will incur on an input? This is the problem of loss prediction, a key computational task associated with uncertainty estimation for a predictor. In a classification setting, a predictor will typically predict a distribution over labels and hence have its own estimate of the loss that it will incur, given by the entropy of the predicted distribution. Should we trust this estimate? In other words, when does the predictor know what it knows and what it does not know? In this work, we study the theoretical foundations of loss prediction. Our main contribution is to establish tight connections between nontrivial loss prediction and certain forms of multicalibration, a multigroup fairness notion that asks for calibrated predictions across computationally identifiable subgroups. Formally, we show that a loss predictor that is able to improve on the self-estimate of a predictor yields a witness to a failure of multicalibration, and vice versa. This implies that nontrivial loss prediction is in effect no easier or harder than auditing for multicalibration. We support our theoretical results with experiments that show a robust positive correlation between the multicalibration error of a predictor and the efficacy of training a loss predictor.
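The self-estimate mentioned in the abstract above can be made concrete with a small sketch. It assumes log loss, for which a predictor's own expected loss is the entropy of its predicted distribution, and it measures "improving on the self-estimate" by mean squared error against the realized losses. That metric, the two-group toy dataset, and all names are assumptions for illustration rather than the paper's definitions.

```python
import math

def log_loss(probs, label):
    """Realized log loss of a predicted distribution on the true label."""
    return -math.log(probs[label])

def self_estimate(probs):
    """The predictor's own expected log loss: the entropy of its prediction."""
    return -sum(p * math.log(p) for p in probs if p > 0)

def mse(estimates, losses):
    """Mean squared error of loss estimates against realized losses."""
    return sum((e - l) ** 2 for e, l in zip(estimates, losses)) / len(losses)

# Hypothetical data: the predictor outputs (0.8, 0.2) everywhere. It is
# calibrated overall (16 of 20 labels are 0) but not within each group.
PRED = (0.8, 0.2)
data = [("A", 0)] * 10 + [("B", 0)] * 6 + [("B", 1)] * 4

losses = [log_loss(PRED, label) for _, label in data]
self_est = [self_estimate(PRED) for _ in data]

# A group-aware loss predictor: the in-sample mean realized loss per group.
group_mean = {}
for g in ("A", "B"):
    in_group = [l for (grp, _), l in zip(data, losses) if grp == g]
    group_mean[g] = sum(in_group) / len(in_group)
group_est = [group_mean[g] for g, _ in data]

# Positive advantage: the group-aware predictor beats the self-estimate.
advantage = mse(self_est, losses) - mse(group_est, losses)
```

Here the group label is what lets the loss predictor beat the entropy self-estimate, and the same group label exposes the predictor as miscalibrated within groups A and B. This is a single example of the kind of correspondence the abstract describes, not a proof of it.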
Abstract: Large language models (LLMs) can exhibit undesirable and unexpected behavior in the blink of an eye. In a recent Anthropic demo, Claude switched from coding to Googling pictures of Yellowstone, and these sudden shifts in behavior have also been observed in reasoning patterns and jailbreaks. This phenomenon is not unique to autoregressive models: in diffusion models, key features of the final output are decided in narrow ``critical windows'' of the generation process. In this work, we develop a simple, unifying theory to explain this phenomenon. We show that it emerges generically as the generation process localizes to a sub-population of the distribution it models. While critical windows have been studied at length in diffusion models, existing theory relies heavily on strong distributional assumptions and the particulars of Gaussian diffusion. In contrast to existing work, our theory (1) applies to autoregressive and diffusion models; (2) makes no distributional assumptions; (3) quantitatively improves previous bounds even when specialized to diffusions; and (4) requires only basic tools, with no stochastic calculus or statistical physics-based machinery. We also identify an intriguing connection to the all-or-nothing phenomenon from statistical inference. Finally, we validate our predictions empirically for LLMs and find that critical windows often coincide with failures in problem solving for various math and reasoning benchmarks.
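One informal way to picture the localization described in the abstract above is a toy calculation on a Gaussian diffusion. The sketch assumes two unit-variance Gaussian sub-populations centered at $\pm\mu$ and an Ornstein-Uhlenbeck forward process, so that at noise time $t$ the two noised sub-populations are $N(\pm\mu e^{-t}, 1)$ and their total variation distance is $\mathrm{erf}(\mu e^{-t}/\sqrt{2})$. The setup, the thresholds 0.9 and 0.1, and all names are illustrative assumptions and not the paper's definitions or bounds.

```python
import math

MU = 6.0  # hypothetical half-separation between the two sub-populations

def tv_between_noised_modes(t, mu=MU):
    """Total variation distance between N(+m, 1) and N(-m, 1), m = mu * exp(-t)."""
    m = mu * math.exp(-t)
    return math.erf(m / math.sqrt(2.0))

def window(hi=0.9, lo=0.1, t_max=8.0, steps=8000):
    """Return the first grid times at which the distance drops below hi and lo.
    Between them the sub-populations go from distinguishable to nearly merged."""
    times = [t_max * i / steps for i in range(steps + 1)]
    start = next(t for t in times if tv_between_noised_modes(t) < hi)
    end = next(t for t in times if tv_between_noised_modes(t) < lo)
    return start, end

start, end = window()
```

Read in the reverse (generation) direction, the interval between `end` and `start` is the span of noise levels over which a sample in this toy model commits to one of the two sub-populations. Outside that interval the choice is either already made or not yet meaningfully started.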