Abstract:Protein-protein interactions (PPIs) govern nearly all cellular processes, yet computational methods for identifying binding partners typically produce ranked predictions without mechanistic justification. This creates a fundamental barrier to adoption because biologists cannot assess whether predictions reflect genuine biochemical insight or spurious correlations. We present \textbf{Protein Thoughts}, a framework that reformulates PPI discovery as an interpretable search problem with explicit reasoning. The system decomposes binding evidence into four biologically meaningful signals: sequence similarity reflecting evolutionary relationships, structural complementarity capturing geometric fit, interface balance, and chemical compatibility encoding residue-level interactions. Rather than collapsing these signals into an opaque score, we preserve their individual contributions through a transparent value function that enables both ranking and auditing. To navigate large candidate spaces efficiently, we introduce hypothesis-guided entropy-regularized Tree-of-Thoughts search. A fine-tuned language model generates search directives from embedding-derived features, classifying candidates as high-priority, exploratory, or skippable. These directives condition a Boltzmann policy that balances exploitation with entropy-driven exploration, while hypothesis-aware pruning prevents premature abandonment of promising candidates. For candidates exhibiting score disagreement, hypothesis-conditioned embedding-space flow matching transports protein embeddings toward the binder manifold. On the SHS148k benchmark, Protein Thoughts achieves mean best-binder rank of 11.2 versus 47.7 for an entropic tree search baseline, a 76% improvement, and for binding prediction the trained value function achieves $91.08 \pm 0.19$ Micro-F1, outperforming existing PPI methods on the same dataset.
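The Boltzmann policy with entropy-driven exploration mentioned in this abstract can be illustrated with a minimal sketch. This is not the paper's implementation: the function names (`boltzmann_policy`, `policy_entropy`, `sample_candidate`), the scalar candidate scores, and the single `temperature` parameter are all assumptions made for illustration. It only shows how a softmax over candidate values trades off exploitation (low temperature) against exploration (high temperature, higher entropy).

```python
import math
import random


def boltzmann_policy(scores, temperature=1.0):
    """Softmax (Boltzmann) probabilities over hypothetical candidate scores."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    top = max(scores)  # subtract the max for numerical stability
    weights = [math.exp((s - top) / temperature) for s in scores]
    total = sum(weights)
    return [w / total for w in weights]


def policy_entropy(probs):
    """Shannon entropy (in nats) of a probability vector."""
    return -sum(p * math.log(p) for p in probs if p > 0)


def sample_candidate(scores, temperature=1.0, rng=random):
    """Draw one candidate index according to the Boltzmann policy."""
    probs = boltzmann_policy(scores, temperature)
    return rng.choices(range(len(scores)), weights=probs, k=1)[0]
```

Raising `temperature` flattens the distribution and increases `policy_entropy`, so lower-scoring candidates are visited more often; lowering it concentrates the search on the current best candidate.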
Abstract:A critical step for reliable large language model (LLM) use in healthcare is to attribute predictions to their training data, akin to a medical case study. This requires token-level precision: pinpointing not just which training examples influence a decision, but which tokens within them are responsible. While influence functions offer a principled framework for this, prior work is restricted to autoregressive settings and relies on an implicit assumption of token independence, rendering their identified influences unreliable. We introduce a flexible framework that infers token-level influence through a latent mediation approach for general prediction tasks. Our method attaches sparse autoencoders to any layer of a pretrained LLM to learn a basis of approximately independent latent features. Unlike prior methods where influence decomposes additively across tokens, influence computed over latent features is inherently non-decomposable. To address this, we introduce a novel method using Jacobian-vector products. Token-level influence is obtained by propagating latent attributions back to the input space via token activation patterns. We scale our approach using efficient inverse-Hessian approximations. Experiments on medical benchmarks show our approach identifies sparse, interpretable sets of tokens that jointly influence predictions. Our framework enhances trust and enables model auditing, generalizing to high-stakes domains requiring transparent and accountable decisions.
Abstract:We derive finite-particle rates for the regularized Stein variational gradient descent (R-SVGD) algorithm introduced by He et al. (2024), which corrects the constant-order bias of SVGD by applying a resolvent-type preconditioner to the kernelized Wasserstein gradient. For the resulting interacting $N$-particle system, we establish explicit non-asymptotic bounds for time-averaged (annealed) empirical measures, illustrating convergence in the \emph{true} (non-kernelized) Fisher information and, under a $\mathrm{W}_1\mathrm{I}$ condition on the target, corresponding $\mathrm{W}_1$ convergence for a large class of smooth kernels. Our analysis covers both continuous- and discrete-time dynamics and yields principled tuning rules for the regularization parameter, step size, and averaging horizon that quantify the trade-off between approximating the Wasserstein gradient flow and controlling finite-particle estimation error.
Abstract:Single-cell RNA sequencing (scRNA-seq) enables the study of cellular heterogeneity. Yet clustering accuracy, and with it downstream analyses based on cell labels, remains challenging due to measurement noise and biological variability. In standard latent spaces (e.g., obtained through PCA), data from different cell types can be projected close together, making accurate clustering difficult. We introduce a latent plug-and-play diffusion framework that separates the observation and denoising spaces. This separation is operationalized through a novel Gibbs sampling procedure: the learned diffusion prior is applied in a low-dimensional latent space to perform denoising, while to steer this process, noise is reintroduced into the original high-dimensional observation space. This unique "input-space steering" ensures the denoising trajectory remains faithful to the original data structure. Our approach offers three key advantages: (1) adaptive noise handling via a tunable balance between prior and observed data; (2) uncertainty quantification through principled uncertainty estimates for downstream analysis; and (3) generalizable denoising by leveraging clean reference data to denoise noisier datasets and, via averaging, improve quality beyond the training set. We evaluate robustness on both synthetic and real single-cell genomics data. Our method improves clustering accuracy on synthetic data across varied noise levels and dataset shifts. On real-world single-cell data, our method demonstrates improved biological coherence in the resulting cell clusters, with cluster boundaries that better align with known cell type markers and developmental trajectories.
Abstract:Efficient matrix trace estimation is essential for scalable computation of log-determinants, matrix norms, and distributional divergences. In many large-scale applications, the matrices involved are too large to store or access in full, making even a single matrix-vector (mat-vec) product infeasible. Instead, one often has access only to small subblocks of the matrix or localized matrix-vector products on restricted index sets. Hutch++ achieves the optimal convergence rate but relies on randomized SVD and assumes full mat-vec access, making it difficult to apply in these constrained settings. We propose the Block-Orthonormal Stochastic Lanczos Quadrature (BOLT), which matches Hutch++ accuracy with a simpler implementation based on orthonormal block probes and Lanczos iterations. BOLT builds on the Stochastic Lanczos Quadrature (SLQ) framework, which combines random probing with Krylov subspace methods to efficiently approximate traces of matrix functions, and performs better than Hutch++ in near flat-spectrum regimes. To address memory limitations and partial access constraints, we introduce Subblock SLQ, a variant of BOLT that operates only on small principal submatrices. As a result, this framework yields a proxy KL divergence estimator and an efficient method for computing the Wasserstein-2 distance between Gaussians, both compatible with low-memory and partial-access regimes. We provide theoretical guarantees and demonstrate strong empirical performance across a range of high-dimensional settings.
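For readers unfamiliar with the random probing that stochastic trace estimators build on, the sketch below shows the classical Hutchinson estimator with Rademacher probes. It is neither BOLT nor Hutch++; it only illustrates the basic idea of estimating a trace from matrix-vector products alone. The function name `hutchinson_trace` and its parameters (`matvec`, `n`, `num_probes`, `seed`) are assumptions introduced for this example.

```python
import numpy as np


def hutchinson_trace(matvec, n, num_probes=100, seed=0):
    """Estimate tr(A) using only products v -> A @ v (classical Hutchinson).

    matvec: callable returning A @ v for a length-n vector v.
    Each probe z has i.i.d. +/-1 entries, so E[z^T A z] = tr(A).
    """
    rng = np.random.default_rng(seed)
    estimates = []
    for _ in range(num_probes):
        z = rng.choice([-1.0, 1.0], size=n)  # Rademacher probe vector
        estimates.append(z @ matvec(z))      # one quadratic form per probe
    return float(np.mean(estimates))
```

The matrix never has to be stored explicitly: only the `matvec` callable is needed. The variance of this plain estimator is driven by the off-diagonal mass of the matrix, which is the kind of cost that refined schemes such as those named in the abstract aim to reduce.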




Abstract:In-context learning (ICL), the ability of transformer-based models to perform new tasks from examples provided at inference time, has emerged as a hallmark of modern language models. While recent works have investigated the mechanisms underlying ICL, its feasibility under formal privacy constraints remains largely unexplored. In this paper, we propose a differentially private pretraining algorithm for linear attention heads and present the first theoretical analysis of the privacy-accuracy trade-off for ICL in linear regression. Our results characterize the fundamental tension between optimization and privacy-induced noise, formally capturing behaviors observed in private training via iterative methods. Additionally, we show that our method is robust to adversarial perturbations of training prompts, unlike standard ridge regression. All theoretical findings are supported by extensive simulations across diverse settings.




Abstract:We address the challenge of causal discovery in structural equation models with additive noise without imposing additional assumptions on the underlying data-generating process. We introduce local search in additive noise model (LoSAM), which generalizes an existing nonlinear method that leverages local causal substructures to the general additive noise setting, allowing for both linear and nonlinear causal mechanisms. We show that LoSAM achieves polynomial runtime, and improves runtime and efficiency by exploiting new substructures to minimize the conditioning set at each step. Further, we introduce a variant of LoSAM, LoSAM-UC, that is robust to unmeasured confounding among roots, a property that is often not satisfied by functional-causal-model-based methods. We numerically demonstrate the utility of LoSAM, showing that it outperforms existing benchmarks.
Abstract:We provide finite-particle convergence rates for the Stein Variational Gradient Descent (SVGD) algorithm in the Kernel Stein Discrepancy ($\mathsf{KSD}$) and Wasserstein-2 metrics. Our key insight is the observation that the time derivative of the relative entropy between the joint density of $N$ particle locations and the $N$-fold product target measure, starting from a regular initial distribution, splits into a dominant `negative part' proportional to $N$ times the expected $\mathsf{KSD}^2$ and a smaller `positive part'. This observation leads to $\mathsf{KSD}$ rates of order $1/\sqrt{N}$, providing a near optimal double exponential improvement over the recent result by~\cite{shi2024finite}. Under mild assumptions on the kernel and potential, these bounds also grow linearly in the dimension $d$. By adding a bilinear component to the kernel, the above approach is used to further obtain Wasserstein-2 convergence. For the case of `bilinear + Mat\'ern' kernels, we derive Wasserstein-2 rates that exhibit a curse-of-dimensionality similar to the i.i.d. setting. We also obtain marginal convergence and long-time propagation of chaos results for the time-averaged particle laws.
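As a concrete picture of the interacting $N$-particle system this abstract analyzes, here is a minimal sketch of one discrete SVGD update. It assumes a Gaussian (RBF) kernel with a fixed bandwidth, which is a choice made for illustration and not the `bilinear + Mat\'ern' kernel studied in the abstract; the names `svgd_step`, `grad_log_p`, `step_size`, and `bandwidth` are likewise hypothetical.

```python
import numpy as np


def svgd_step(particles, grad_log_p, step_size=0.1, bandwidth=1.0):
    """One SVGD update for an (n, d) array of particles, RBF kernel assumed.

    x_i <- x_i + (step_size / n) * sum_j [ k(x_j, x_i) * grad_log_p(x_j)
                                           + grad_{x_j} k(x_j, x_i) ]
    """
    n = particles.shape[0]
    diffs = particles[:, None, :] - particles[None, :, :]  # diffs[i, j] = x_i - x_j
    sq_dists = np.sum(diffs ** 2, axis=-1)
    K = np.exp(-sq_dists / (2.0 * bandwidth ** 2))         # K[i, j] = k(x_i, x_j)
    scores = grad_log_p(particles)                         # shape (n, d)
    drive = K @ scores                                     # pulls toward high density
    # For the RBF kernel, grad_{x_j} k(x_j, x_i) = k(x_i, x_j) * (x_i - x_j) / h^2.
    repulse = np.sum(K[:, :, None] * diffs, axis=1) / bandwidth ** 2
    return particles + step_size * (drive + repulse) / n
```

For a standard Gaussian target one could pass `grad_log_p=lambda x: -x`. The first term moves particles toward regions of high target density, while the kernel-gradient term pushes nearby particles apart so that they do not collapse onto a single mode.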




Abstract:Learning the unique directed acyclic graph corresponding to an unknown causal model is a challenging task. Methods based on functional causal models can identify a unique graph, but either suffer from the curse of dimensionality or impose strong parametric assumptions. To address these challenges, we propose a novel hybrid approach for global causal discovery in observational data that leverages local causal substructures. We first present a topological sorting algorithm that leverages ancestral relationships in linear structural equation models to establish a compact top-down hierarchical ordering, encoding more causal information than linear orderings produced by existing methods. We demonstrate that this approach generalizes to nonlinear settings with arbitrary noise. We then introduce a nonparametric constraint-based algorithm that prunes spurious edges by searching for local conditioning sets, achieving greater accuracy than current methods. We provide theoretical guarantees for correctness and worst-case polynomial time complexities, with empirical validation on synthetic data.




Abstract:We conduct a comprehensive investigation into the dynamics of gradient descent using large-order constant step-sizes in the context of quadratic regression models. Within this framework, we reveal that the dynamics can be encapsulated by a specific cubic map, naturally parameterized by the step-size. Through a fine-grained bifurcation analysis concerning the step-size parameter, we delineate five distinct training phases: (1) monotonic, (2) catapult, (3) periodic, (4) chaotic, and (5) divergent, precisely demarcating the boundaries of each phase. As illustrations, we provide examples involving phase retrieval and two-layer neural networks employing quadratic activation functions and constant outer-layers, utilizing orthogonal training data. Our simulations indicate that these five phases also manifest with generic non-orthogonal data. We also empirically investigate the generalization performance when training in the various non-monotonic (and non-divergent) phases. In particular, we observe that performing an ergodic trajectory averaging stabilizes the test error in non-monotonic (and non-divergent) phases.