Dheeraj Nagaraj

Near Optimal Heteroscedastic Regression with Symbiotic Learning

Jul 01, 2023
Dheeraj Baby, Aniket Das, Dheeraj Nagaraj, Praneeth Netrapalli

We consider the problem of heteroscedastic linear regression, where, given $n$ samples $(\mathbf{x}_i, y_i)$ from $y_i = \langle \mathbf{w}^{*}, \mathbf{x}_i \rangle + \epsilon_i \cdot \langle \mathbf{f}^{*}, \mathbf{x}_i \rangle$ with $\mathbf{x}_i \sim N(0,\mathbf{I})$, $\epsilon_i \sim N(0,1)$, we aim to estimate $\mathbf{w}^{*}$. Beyond classical applications of such models in statistics, econometrics, and time series analysis, they are also particularly relevant in machine learning when data is collected from multiple sources of varying but a priori unknown quality. Our work shows that we can estimate $\mathbf{w}^{*}$ in squared norm up to an error of $\tilde{O}\left(\|\mathbf{f}^{*}\|^2 \cdot \left(\frac{1}{n} + \left(\frac{d}{n}\right)^2\right)\right)$ and prove a matching lower bound (up to log factors). This represents a substantial improvement upon the previous best known upper bound of $\tilde{O}\left(\|\mathbf{f}^{*}\|^2\cdot \frac{d}{n}\right)$. Our algorithm is an alternating minimization procedure with two key subroutines: (1) an adaptation of the classical weighted least squares heuristic to estimate $\mathbf{w}^{*}$, for which we provide the first non-asymptotic guarantee; (2) a nonconvex pseudogradient descent procedure for estimating $\mathbf{f}^{*}$, inspired by phase retrieval. As corollaries, we obtain fast non-asymptotic rates for two important problems, linear regression with multiplicative noise and phase retrieval with multiplicative noise, both of which are of independent interest. Beyond this, the proof of our lower bound, which involves a novel adaptation of Le Cam's method for handling infinite mutual information quantities (thereby preventing a direct application of standard techniques like Fano's method), could also be of broader interest for establishing lower bounds for other heteroscedastic or heavy-tailed statistical problems.

* To appear in Conference on Learning Theory 2023 (COLT 2023) 
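To make the alternating structure concrete, here is a minimal, purely illustrative Python sketch: the particular weighting, the squared-residual pseudogradient objective for $\mathbf{f}$, and all constants are assumptions chosen for exposition, not the paper's exact subroutines or guarantees.

```python
import numpy as np

def alternating_heteroscedastic_fit(X, y, n_rounds=10, f_steps=50, lr=1e-3, floor=1e-6):
    """Toy sketch of alternating estimation for y = <w*, x> + eps * <f*, x>.

    NOT the authors' algorithm: the weighting, the pseudogradient for f, and all
    step sizes are illustrative placeholders.
    """
    n, d = X.shape
    w = np.linalg.lstsq(X, y, rcond=None)[0]      # start from ordinary least squares
    f = np.ones(d) / np.sqrt(d)                   # arbitrary initial noise direction

    for _ in range(n_rounds):
        # (1) Weighted least squares for w: down-weight samples with large predicted noise.
        noise2 = (X @ f) ** 2 + floor
        Xw = X / noise2[:, None]
        w = np.linalg.solve(X.T @ Xw, Xw.T @ y)

        # (2) Phase-retrieval-style pseudogradient steps for f:
        #     fit <f, x_i>^2 to the squared residuals.
        r2 = (y - X @ w) ** 2
        for _ in range(f_steps):
            pred = (X @ f) ** 2
            grad = (4.0 / n) * X.T @ ((pred - r2) * (X @ f))
            f = f - lr * grad
    return w, f
```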

Stochastic Re-weighted Gradient Descent via Distributionally Robust Optimization

Jun 15, 2023
Ramnath Kumar, Kushal Majmundar, Dheeraj Nagaraj, Arun Sai Suggala

We develop a re-weighted gradient descent technique for boosting the performance of deep neural networks. Our algorithm involves importance weighting of data points during each optimization step. Our approach is inspired by distributionally robust optimization with $f$-divergences, which is known to result in models with improved generalization guarantees. Our re-weighting scheme is simple, computationally efficient, and can be combined with popular optimization algorithms such as SGD and Adam. Empirically, we demonstrate our approach's superiority on various tasks, including vanilla classification, classification with label imbalance, noisy labels, domain adaptation, and tabular representation learning. Notably, we obtain improvements of +0.7% and +1.44% over SOTA on the DomainBed and Tabular benchmarks, respectively. Moreover, our algorithm boosts the performance of BERT on the GLUE benchmark by +1.94%, and ViT on ImageNet-1K by +0.9%. These results demonstrate the effectiveness of the proposed approach, indicating its potential for improving performance in diverse domains.
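As a rough illustration of the re-weighting idea, the sketch below tilts a mini-batch by a softmax of per-example losses, which is the weighting suggested by KL-divergence DRO; the exact $f$-divergence, temperature schedule, and any clipping used in the paper may differ.

```python
import torch

def reweighted_sgd_step(model, loss_fn, optimizer, x_batch, y_batch, temperature=1.0):
    """One mini-batch step of loss-based importance re-weighting.

    The softmax-of-losses weighting is the tilting suggested by KL-divergence DRO;
    the paper's exact f-divergence and weighting scheme may differ.
    """
    optimizer.zero_grad()
    per_example_loss = loss_fn(model(x_batch), y_batch)          # shape: (batch,)
    weights = torch.softmax(per_example_loss.detach() / temperature, dim=0)
    (weights * per_example_loss).sum().backward()                # weighted surrogate loss
    optimizer.step()
```

For instance, with `loss_fn = torch.nn.CrossEntropyLoss(reduction="none")`, a very large `temperature` recovers the usual uniformly weighted step.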

Provably Fast Finite Particle Variants of SVGD via Virtual Particle Stochastic Approximation

May 27, 2023
Aniket Das, Dheeraj Nagaraj

Stein Variational Gradient Descent (SVGD) is a popular variational inference algorithm which simulates an interacting particle system to approximately sample from a target distribution, with impressive empirical performance across various domains. Theoretically, its population (i.e., infinite-particle) limit dynamics are well studied, but the behavior of SVGD in the finite-particle regime is much less understood. In this work, we design two computationally efficient variants of SVGD, namely VP-SVGD (which is conceptually elegant) and GB-SVGD (which is empirically effective), with provably fast finite-particle convergence rates. We introduce the notion of \emph{virtual particles} and develop novel stochastic approximations of population-limit SVGD dynamics in the space of probability measures, which are exactly implementable using a finite number of particles. Our algorithms can be viewed as specific random-batch approximations of SVGD, which are computationally more efficient than ordinary SVGD. We show that the $n$ particles output by VP-SVGD and GB-SVGD, run for $T$ steps with batch size $K$, are at least as good as i.i.d. samples from a distribution whose Kernel Stein Discrepancy to the target is at most $O\left(\tfrac{d^{1/3}}{(KT)^{1/6}}\right)$ under standard assumptions. Our results also hold under a mild growth condition on the potential function, which is much weaker than the isoperimetric (e.g., Poincaré inequality) or information-transport conditions (e.g., Talagrand's inequality $\mathsf{T}_1$) generally considered in prior works. As a corollary, we consider the convergence of the empirical measure (of the particles output by VP-SVGD and GB-SVGD) to the target distribution and demonstrate a \emph{double exponential improvement} over the best known finite-particle analysis of SVGD.

* 34 Pages, 2 Figures 
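For intuition about the random-batch flavour of these variants, here is an illustrative NumPy sketch of a single SVGD step in which each particle interacts only with a random subset of the others. It uses an RBF kernel with a fixed bandwidth and does not implement the virtual-particle construction, so it should be read as a caricature rather than as VP-SVGD or GB-SVGD themselves.

```python
import numpy as np

def rbf_kernel_and_grad(x, y, bandwidth=1.0):
    # k(x, y) = exp(-||x - y||^2 / (2 h^2)) and its gradient in the first argument.
    diff = x - y
    k = np.exp(-np.dot(diff, diff) / (2 * bandwidth ** 2))
    return k, -k * diff / bandwidth ** 2

def random_batch_svgd_step(particles, grad_log_p, step=1e-2, batch_size=8, rng=None):
    """One random-batch SVGD update: each particle is driven only by a random
    subset of the others (no virtual particles; purely illustrative)."""
    rng = np.random.default_rng() if rng is None else rng
    n, d = particles.shape
    new = particles.copy()
    for i in range(n):
        batch = rng.choice(n, size=min(batch_size, n), replace=False)
        drift = np.zeros(d)
        for j in batch:
            k, grad_k = rbf_kernel_and_grad(particles[j], particles[i])
            drift += k * grad_log_p(particles[j]) + grad_k
        new[i] = particles[i] + step * drift / len(batch)
    return new
```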

Indexability is Not Enough for Whittle: Improved, Near-Optimal Algorithms for Restless Bandits

Oct 31, 2022
Abheek Ghosh, Dheeraj Nagaraj, Manish Jain, Milind Tambe

We study the problem of planning restless multi-armed bandits (RMABs) with multiple actions. This is a popular model for multi-agent systems, with applications in multi-channel communication, monitoring and machine maintenance tasks, and healthcare. Whittle index policies, which are based on Lagrangian relaxations, are widely used in these settings due to their simplicity and near-optimality under certain conditions. In this work, we first show that Whittle index policies can fail in simple and practically relevant RMAB settings, \textit{even when} the RMABs are indexable. We discuss why the optimality guarantees fail and why asymptotic optimality may not translate well to practically relevant planning horizons. We then propose an alternate planning algorithm based on the mean-field method, which can provably and efficiently obtain near-optimal policies with a large number of arms, without the stringent structural assumptions required by Whittle index policies. Our approach borrows ideas from existing research with some improvements: it is hyper-parameter free, and we provide an improved non-asymptotic analysis with (a) no requirement for exogenous hyper-parameters and a tighter polynomial dependence on known problem parameters; (b) high-probability bounds which show that the reward of the policy is reliable; and (c) matching sub-optimality lower bounds for this algorithm with respect to the number of arms, demonstrating the tightness of our bounds. Our extensive experimental analysis shows that the mean-field approach matches or outperforms other baselines.

* 19 pages 
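For context, one standard way to write a mean-field (fluid) relaxation for $N$ homogeneous arms is the occupancy-measure linear program below, where $\mu_t(s,a)$ is the fraction of arms in state $s$ taking action $a$ at time $t$, $\rho_0$ is the initial state distribution, $B$ is the per-step budget, and action $0$ is the passive action. The paper's planner is in this spirit, but its exact formulation, rounding, and analysis differ.

$$
\begin{aligned}
\max_{\mu \ge 0}\quad & \sum_{t=1}^{H} \sum_{s,a} \mu_t(s,a)\, r(s,a) \\
\text{s.t.}\quad & \sum_{s} \sum_{a \neq 0} \mu_t(s,a) \le B/N && \forall\, t, \\
& \sum_{a} \mu_{t+1}(s',a) = \sum_{s,a} \mu_t(s,a)\, P(s' \mid s,a) && \forall\, s',\, t, \\
& \sum_{a} \mu_1(s,a) = \rho_0(s) && \forall\, s.
\end{aligned}
$$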

Finite time analysis of temporal difference learning with linear function approximation: Tail averaging and regularisation

Oct 12, 2022
Gandharv Patil, Prashanth L. A., Dheeraj Nagaraj, Doina Precup

We study the finite-time behaviour of the popular temporal difference (TD) learning algorithm when combined with tail-averaging. We derive finite-time bounds on the parameter error of the tail-averaged TD iterate under a step-size choice that does not require information about the eigenvalues of the matrix underlying the projected TD fixed point. Our analysis shows that tail-averaged TD converges at the optimal $O\left(1/t\right)$ rate, both in expectation and with high probability. In addition, our bounds exhibit a sharper rate of decay for the initial error (bias), which is an improvement over averaging all iterates. We also propose and analyse a variant of TD that incorporates regularisation. From our analysis, we conclude that the regularised version of TD is useful for problems with ill-conditioned features.
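A minimal sketch of tail-averaged (and optionally regularised) TD(0) with linear function approximation is given below; the constant step size, the tail fraction, and the ridge-style regularisation term are illustrative placeholders rather than the tuned choices analysed in the paper.

```python
import numpy as np

def td0_tail_averaged(trajectory, phi, gamma=0.99, step=0.1, tail_frac=0.5, reg=0.0):
    """TD(0) with linear function approximation, returning the tail-averaged iterate.

    `trajectory` is a list of (state, reward, next_state) tuples and `phi` maps a
    state to its feature vector. Step size, tail fraction, and the optional
    ridge-style regularisation `reg` are illustrative choices only.
    """
    d = len(phi(trajectory[0][0]))
    theta = np.zeros(d)
    iterates = []
    for s, r, s_next in trajectory:
        f, f_next = phi(s), phi(s_next)
        td_error = r + gamma * f_next @ theta - f @ theta
        theta = theta + step * (td_error * f - reg * theta)   # regularised TD update
        iterates.append(theta.copy())
    tail_start = int(len(iterates) * (1 - tail_frac))
    return np.mean(iterates[tail_start:], axis=0)             # average only the tail
```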

Multi-User Reinforcement Learning with Low Rank Rewards

Oct 11, 2022
Naman Agarwal, Prateek Jain, Suhas Kowshik, Dheeraj Nagaraj, Praneeth Netrapalli

In this work, we consider the problem of collaborative multi-user reinforcement learning. In this setting, there are multiple users with the same state-action space and transition probabilities but with different rewards. Under the assumption that the reward matrix of the $N$ users has a low-rank structure -- a standard and practically successful assumption in the offline collaborative filtering setting -- the question is whether we can design algorithms with significantly lower sample complexity than those that learn the MDP individually for each user. Our main contribution is an algorithm which explores rewards collaboratively with $N$ user-specific MDPs and can learn rewards efficiently in two key settings: tabular MDPs and linear MDPs. When $N$ is large and the rank is constant, the sample complexity per MDP depends logarithmically on the size of the state space, which represents an exponential reduction (in the state-space size) compared to standard "non-collaborative" algorithms.
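To illustrate how a low-rank reward matrix lets users pool their observations, here is a toy sketch that fills in a partially observed user-by-(state, action) reward matrix with a rescaled truncated SVD; the paper's collaborative exploration algorithm is considerably more involved, so this is only meant to convey the matrix-completion flavour.

```python
import numpy as np

def complete_low_rank_rewards(observed, n_users, n_sa, rank):
    """Fill in an n_users x n_sa reward matrix from sparse noisy observations via a
    simple rescaled truncated SVD. Purely illustrative of how low-rank structure
    lets users share observations; not the paper's exploration scheme.

    `observed` is a dict mapping (user, state_action_index) -> sampled reward.
    """
    mask = np.zeros((n_users, n_sa))
    R = np.zeros((n_users, n_sa))
    for (u, i), r in observed.items():
        mask[u, i] = 1.0
        R[u, i] = r
    p = mask.mean()                               # fraction of observed entries
    U, s, Vt = np.linalg.svd(R / max(p, 1e-12), full_matrices=False)
    return (U[:, :rank] * s[:rank]) @ Vt[:rank]   # best rank-`rank` approximation
```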

Entropic Convergence of Random Batch Methods for Interacting Particle Diffusion

Jun 08, 2022
Dheeraj Nagaraj

We propose a covariance-corrected random batch method for interacting particle systems. By establishing a certain entropic central limit theorem, we provide entropic convergence guarantees for the law of the entire trajectories of all particles of the proposed method to the law of the trajectories of the discrete-time interacting particle system whenever the batch size $B \gg (\alpha n)^{\frac{1}{3}}$ (where $n$ is the number of particles and $\alpha$ is the time discretization parameter). This in turn implies that the outputs of these methods are nearly \emph{statistically indistinguishable} when $B$ is even moderately large. Previous works mainly considered convergence in Wasserstein distance, and either required stringent assumptions on the potentials or obtained bounds with an exponential dependence on the time horizon. This work makes minimal assumptions on the interaction potentials and, in particular, establishes that even when the particle trajectories diverge to infinity, they do so in the same way for both methods. Such guarantees are very useful in light of recent advances in interacting-particle-based algorithms for sampling.

* No figures, like usual. 
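For background, the sketch below shows one Euler-Maruyama step of a plain (uncorrected) random batch method for a pairwise-interacting diffusion; the covariance correction that is the paper's main algorithmic contribution is deliberately omitted, so this only illustrates the baseline batching idea.

```python
import numpy as np

def random_batch_step(particles, grad_V, grad_W, dt=1e-2, batch_size=4, rng=None):
    """One Euler-Maruyama step of a plain random batch method for the diffusion
    dX_i = -grad_V(X_i) dt - (1/n) sum_j grad_W(X_i - X_j) dt + sqrt(2) dB_i,
    with the pairwise interaction estimated from a random batch instead of all
    n particles. The paper's covariance correction is NOT included here.
    """
    rng = np.random.default_rng() if rng is None else rng
    n, d = particles.shape
    new = particles.copy()
    for i in range(n):
        others = np.delete(np.arange(n), i)
        batch = rng.choice(others, size=min(batch_size, n - 1), replace=False)
        interaction = np.mean([grad_W(particles[i] - particles[j]) for j in batch], axis=0)
        drift = -grad_V(particles[i]) - interaction
        new[i] = particles[i] + dt * drift + np.sqrt(2 * dt) * rng.standard_normal(d)
    return new
```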

Look Back When Surprised: Stabilizing Reverse Experience Replay for Neural Approximation

Jun 07, 2022
Ramnath Kumar, Dheeraj Nagaraj

Experience replay methods, which are an essential part of reinforcement learning (RL) algorithms, are designed to mitigate spurious correlations and biases while learning from temporally dependent data. Roughly speaking, these methods allow us to draw batched data from a large buffer such that these temporal correlations do not hinder the performance of descent algorithms. In this experimental work, we consider the recently developed and theoretically rigorous reverse experience replay (RER), which has been shown to remove such spurious biases in simplified theoretical settings. We combine RER with optimistic experience replay (OER) to obtain RER++, which is stable under neural function approximation. We show via experiments that this performs better than techniques like prioritized experience replay (PER) on various tasks, with significantly smaller computational complexity. It is well known in the RL literature that choosing examples greedily with the largest TD error (as in OER) or forming mini-batches with consecutive data points (as in RER) leads to poor performance. However, our method, which combines these techniques, works very well.
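One plausible reading of the OER + RER combination is sketched below: anchor the mini-batch at the transition with the largest TD error and replay the consecutive transitions leading up to it in reverse order. The actual RER++ sampler may differ in its details.

```python
import numpy as np

def rer_plus_plus_batch(buffer, td_errors, batch_size):
    """Sketch of an RER++-style batch selection: pick the transition with the
    largest TD error (the optimistic / OER part) and return it together with the
    transitions immediately preceding it, ordered from most recent to oldest
    (the reverse-replay / RER part). The paper's exact sampler may differ.

    `buffer` is a list of transitions stored in collection order and `td_errors`
    holds the latest TD-error estimate for each of them.
    """
    anchor = int(np.argmax(td_errors))               # most "surprising" transition
    start = max(0, anchor - batch_size + 1)
    return buffer[start:anchor + 1][::-1]            # consecutive slice, reversed in time
```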

Online Target Q-learning with Reverse Experience Replay: Efficiently finding the Optimal Policy for Linear MDPs

Oct 19, 2021
Naman Agarwal, Syomantak Chaudhuri, Prateek Jain, Dheeraj Nagaraj, Praneeth Netrapalli

Q-learning is a popular Reinforcement Learning (RL) algorithm which is widely used in practice with function approximation (Mnih et al., 2015). In contrast, existing theoretical results are pessimistic about Q-learning. For example, (Baird, 1995) shows that Q-learning does not converge even with linear function approximation for linear MDPs. Furthermore, even for tabular MDPs with synchronous updates, Q-learning was shown to have sub-optimal sample complexity (Li et al., 2021; Azar et al., 2013). The goal of this work is to bridge the gap between the practical success of Q-learning and the relatively pessimistic theoretical results. The starting point of our work is the observation that, in practice, Q-learning is used with two important modifications: (i) training with two networks, called the online network and the target network, simultaneously (online target learning, or OTL), and (ii) experience replay (ER) (Mnih et al., 2015). While they have been observed to play a significant role in the practical success of Q-learning, a thorough theoretical understanding of how these two modifications improve the convergence behavior of Q-learning has been missing from the literature. By carefully combining Q-learning with OTL and reverse experience replay (RER) (a form of experience replay), we present the novel methods Q-Rex and Q-RexDaRe (Q-Rex + data reuse). We show that Q-Rex efficiently finds the optimal policy for linear MDPs (or, more generally, for MDPs with zero inherent Bellman error with linear approximation (ZIBEL)) and provide non-asymptotic bounds on sample complexity -- the first such result for a Q-learning method for this class of MDPs under standard assumptions. Furthermore, we demonstrate that Q-RexDaRe in fact achieves near-optimal sample complexity in the tabular setting, improving upon the existing results for vanilla Q-learning.

* Under Review, V2 has updated acknowledgements 
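A highly simplified sketch of the OTL + RER structure with linear function approximation is given below: the target parameters stay frozen while one buffer of consecutive transitions is replayed in reverse, and the updated parameters then serve as the target for the next buffer. The constants and the precise update rule are illustrative assumptions, not the exact Q-Rex specification.

```python
def q_rex_style_buffer_update(buffer, phi, n_actions, theta, gamma=0.99, step=0.1):
    """Simplified OTL + RER sketch with linear function approximation.

    `buffer` is a list of consecutive transitions (s, a, r, s_next), `phi(s, a)`
    returns a NumPy feature vector, and `theta` is the current NumPy parameter
    vector. The target parameters are frozen for the whole buffer (online target
    learning) and the buffer is replayed in reverse order (RER). Illustrative
    only; not the exact Q-Rex update.
    """
    theta = theta.copy()
    theta_target = theta.copy()                          # frozen target network for this buffer
    for (s, a, r, s_next) in reversed(buffer):           # reverse experience replay
        q_next = max(phi(s_next, b) @ theta_target for b in range(n_actions))
        td_error = r + gamma * q_next - phi(s, a) @ theta
        theta = theta + step * td_error * phi(s, a)      # online network update
    return theta                                         # becomes the next buffer's target
```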