Jason D. Lee

Settling the Sample Complexity of Online Reinforcement Learning

Jul 25, 2023
Zihan Zhang, Yuxin Chen, Jason D. Lee, Simon S. Du

A central issue lying at the heart of online reinforcement learning (RL) is data efficiency. While a number of recent works achieved asymptotically minimal regret in online RL, the optimality of these results is only guaranteed in a "large-sample" regime, imposing enormous burn-in cost in order for their algorithms to operate optimally. How to achieve minimax-optimal regret without incurring any burn-in cost has been an open problem in RL theory. We settle this problem for the context of finite-horizon inhomogeneous Markov decision processes. Specifically, we prove that a modified version of Monotonic Value Propagation (MVP), a model-based algorithm proposed by Zhang et al. (2020), achieves a regret on the order of (modulo log factors) $\min\big\{ \sqrt{SAH^3K}, \,HK \big\}$, where $S$ is the number of states, $A$ is the number of actions, $H$ is the planning horizon, and $K$ is the total number of episodes. This regret matches the minimax lower bound for the entire range of sample size $K\geq 1$, essentially eliminating any burn-in requirement. It also translates to a PAC sample complexity (i.e., the number of episodes needed to yield $\varepsilon$-accuracy) of $\frac{SAH^3}{\varepsilon^2}$ up to log factors, which is minimax-optimal for the full $\varepsilon$-range. Further, we extend our theory to unveil the influences of problem-dependent quantities like the optimal value/cost and certain variances. The key technical innovation lies in the development of a new regret decomposition strategy and a novel analysis paradigm to decouple complicated statistical dependency, a long-standing challenge facing the analysis of online RL in the sample-hungry regime.
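To get a feel for the two regimes in the regret bound above, the following minimal sketch evaluates the stated orders numerically. It ignores constants and log factors, and the function names are hypothetical helpers introduced here, not part of the paper.

```python
import math


def regret_order(S, A, H, K):
    """Order of the regret bound min{sqrt(S*A*H^3*K), H*K}.

    Constants and log factors are ignored; this is only the leading order.
    """
    return min(math.sqrt(S * A * H**3 * K), H * K)


def pac_episodes_order(S, A, H, eps):
    """Order of the PAC sample complexity S*A*H^3 / eps^2 (log factors ignored)."""
    return S * A * H**3 / eps**2


def crossover_episodes(S, A, H):
    """Number of episodes K at which sqrt(S*A*H^3*K) equals H*K.

    Solving S*A*H^3*K = H^2*K^2 gives K = S*A*H.
    """
    return S * A * H
```

For example, with `S=10, A=5, H=20` the two terms meet at `K = 1000` episodes: below that the trivial `H*K` term is the smaller one, and above it the `sqrt(S*A*H^3*K)` term takes over.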

Teaching Arithmetic to Small Transformers

Jul 07, 2023
Nayoung Lee, Kartik Sreenivasan, Jason D. Lee, Kangwook Lee, Dimitris Papailiopoulos

Large language models like GPT-4 exhibit emergent capabilities across general-purpose tasks, such as basic arithmetic, when trained on extensive text data, even though these tasks are not explicitly encoded by the unsupervised, next-token prediction objective. This study investigates how small transformers, trained from random initialization, can efficiently learn arithmetic operations such as addition, multiplication, and elementary functions like square root, using the next-token prediction objective. We first demonstrate that conventional training data is not the most effective for arithmetic learning, and simple formatting changes can significantly improve accuracy. This leads to sharp phase transitions as a function of training data scale, which, in some cases, can be explained through connections to low-rank matrix completion. Building on prior work, we then train on chain-of-thought style data that includes intermediate step results. Even in the complete absence of pretraining, this approach significantly and simultaneously improves accuracy, sample complexity, and convergence speed. We also study the interplay between arithmetic and text data during training and examine the effects of few-shot prompting, pretraining, and model scale. Additionally, we discuss length generalization challenges. Our work highlights the importance of high-quality, instructive data that considers the particular characteristics of the next-word prediction objective for rapidly eliciting arithmetic capabilities.
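As a rough illustration of what "formatting changes" and "chain-of-thought style data" could look like for addition, the sketch below builds three training strings for the same problem. The specific formats (plain, reversed-output, and a digit-by-digit scratchpad) are assumptions chosen for illustration; the abstract does not specify the exact formats used.

```python
def plain_format(a, b):
    """Conventional sample: operands followed by the sum, most significant digit first."""
    return f"{a}+{b}={a + b}"


def reversed_format(a, b):
    """Illustrative variant: the sum is written least significant digit first."""
    return f"{a}+{b}={str(a + b)[::-1]}"


def scratchpad_format(a, b):
    """Illustrative chain-of-thought sample listing each digit-wise step with its carry-in."""
    digits_a, digits_b = str(a)[::-1], str(b)[::-1]
    steps, carry = [], 0
    for i in range(max(len(digits_a), len(digits_b))):
        x = int(digits_a[i]) if i < len(digits_a) else 0
        y = int(digits_b[i]) if i < len(digits_b) else 0
        total = x + y + carry
        steps.append(f"{x}+{y}+{carry}={total}")
        carry = total // 10
    # The final answer is appended directly; a leftover carry shows up only there.
    return f"{a}+{b}: " + " , ".join(steps) + f" => {a + b}"
```

For `a=128, b=367`, the scratchpad string records the steps `8+7+0=15`, `2+6+1=9`, and `1+3+0=4` before the final answer `495`.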

Scaling In-Context Demonstrations with Structured Attention

Jul 05, 2023
Tianle Cai, Kaixuan Huang, Jason D. Lee, Mengdi Wang

The recent surge of large language models (LLMs) highlights their ability to perform in-context learning, i.e., "learning" to perform a task from a few demonstrations in the context without any parameter updates. However, their in-context learning capabilities are limited by the model architecture: 1) the use of demonstrations is constrained by a maximum sentence length due to positional embeddings; 2) the quadratic complexity of attention hinders users from using more demonstrations efficiently; 3) LLMs are shown to be sensitive to the order of the demonstrations. In this work, we tackle these challenges by proposing a better architectural design for in-context learning. We propose SAICL (Structured Attention for In-Context Learning), which replaces full attention with a structured attention mechanism designed for in-context learning, and removes unnecessary dependencies between individual demonstrations, while making the model invariant to the permutation of demonstrations. We evaluate SAICL in a meta-training framework and show that SAICL achieves comparable or better performance than full attention while obtaining up to 3.4x inference speed-up. SAICL also consistently outperforms a strong Fusion-in-Decoder (FiD) baseline which processes each demonstration independently. Finally, thanks to its linear nature, we demonstrate that SAICL can easily scale to hundreds of demonstrations with continuous performance gains with scaling.
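One way to picture "removing dependencies between individual demonstrations" is a block-structured attention mask. The sketch below is an assumption made for illustration, not the SAICL architecture itself: each demonstration token may attend only to tokens of the same demonstration, while query tokens may attend to everything.

```python
def structured_mask(demo_lengths, query_length):
    """Build a boolean mask where mask[i][j] is True if token i may attend to token j.

    Hypothetical layout: demonstrations are concatenated in order, followed by
    the query. Demonstration tokens see only their own demonstration; query
    tokens see all tokens.
    """
    groups = []
    for demo_index, length in enumerate(demo_lengths):
        groups.extend([demo_index] * length)
    groups.extend([-1] * query_length)  # -1 marks query tokens

    n = len(groups)
    return [
        [groups[i] == -1 or groups[i] == groups[j] for j in range(n)]
        for i in range(n)
    ]


def allowed_pairs(mask):
    """Count the attended (i, j) pairs, a proxy for attention cost."""
    return sum(sum(row) for row in mask)
```

Under this layout, `k` demonstrations of length `L` and a query of length `q` give `k*L*L + q*(k*L + q)` attended pairs, which grows linearly in `k`, whereas full attention over the same tokens grows quadratically.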

Sample Complexity for Quadratic Bandits: Hessian Dependent Bounds and Optimal Algorithms

Jun 22, 2023
Qian Yu, Yining Wang, Baihe Huang, Qi Lei, Jason D. Lee

In stochastic zeroth-order optimization, a problem of practical relevance is understanding how to fully exploit the local geometry of the underlying objective function. We consider a fundamental setting in which the objective function is quadratic, and provide the first tight characterization of the optimal Hessian-dependent sample complexity. Our contribution is twofold. First, from an information-theoretic point of view, we prove tight lower bounds on Hessian-dependent complexities by introducing a concept called energy allocation, which captures the interaction between the searching algorithm and the geometry of objective functions. A matching upper bound is obtained by solving the optimal energy spectrum. Then, algorithmically, we show the existence of a Hessian-independent algorithm that universally achieves the asymptotic optimal sample complexities for all Hessian instances. The optimal sample complexities achieved by our algorithm remain valid for heavy-tailed noise distributions, which are enabled by a truncation method.

Solving Robust MDPs through No-Regret Dynamics

May 30, 2023
Etash Kumar Guha, Jason D. Lee

Reinforcement Learning is a powerful framework for training agents to navigate different situations, but it is susceptible to changes in environmental dynamics. Solving Markov Decision Processes that are robust to such changes is difficult due to nonconvexity and the size of the action or state spaces. While most works have analyzed this problem under different assumptions, a general and efficient theoretical analysis is still missing. We propose a simple framework for improving robustness by solving a minimax iterative optimization problem in which a policy player and an environmental dynamics player play against each other. Leveraging recent results in online nonconvex learning and techniques for improving policy gradient methods, we obtain an algorithm that maximizes the robustness of the value function on the order of $\mathcal{O}\left(\frac{1}{T^{\frac{1}{2}}}\right)$, where $T$ is the number of iterations of the algorithm.

How to Query Human Feedback Efficiently in RL?

May 29, 2023
Wenhao Zhan, Masatoshi Uehara, Wen Sun, Jason D. Lee

Reinforcement Learning with Human Feedback (RLHF) is a paradigm in which an RL agent learns to optimize a task using pair-wise preference-based feedback over trajectories, rather than explicit reward signals. While RLHF has demonstrated practical success in fine-tuning language models, existing empirical work does not address the challenge of how to efficiently sample trajectory pairs for querying human feedback. In this study, we propose an efficient sampling approach to acquiring exploratory trajectories that enable accurate learning of hidden reward functions before collecting any human feedback. Theoretical analysis demonstrates that our algorithm requires less human feedback for learning the optimal policy under preference-based models with linear parameterization and unknown transitions, compared to the existing literature. Specifically, our framework can incorporate linear and low-rank MDPs. Additionally, we investigate RLHF with action-based comparison feedback and introduce an efficient querying algorithm tailored to this scenario.

Reward Collapse in Aligning Large Language Models

May 28, 2023
Ziang Song, Tianle Cai, Jason D. Lee, Weijie J. Su

The extraordinary capabilities of large language models (LLMs) such as ChatGPT and GPT-4 are in part unleashed by aligning them with reward models that are trained on human preferences, which are often represented as rankings of responses to prompts. In this paper, we document the phenomenon of reward collapse, an empirical observation where the prevailing ranking-based approach results in an identical reward distribution regardless of the prompts during the terminal phase of training. This outcome is undesirable as open-ended prompts like "write a short story about your best friend" should yield a continuous range of rewards for their completions, while specific prompts like "what is the capital of New Zealand" should generate either high or low rewards. Our theoretical investigation reveals that reward collapse is primarily due to the insufficiency of the ranking-based objective function to incorporate prompt-related information during optimization. This insight allows us to derive closed-form expressions for the reward distribution associated with a set of utility functions in an asymptotic regime. To overcome reward collapse, we introduce a prompt-aware optimization scheme that provably admits a prompt-dependent reward distribution within the interpolating regime. Our experimental results suggest that our proposed prompt-aware utility functions significantly alleviate reward collapse during the training of reward models.

Fine-Tuning Language Models with Just Forward Passes

May 27, 2023
Sadhika Malladi, Tianyu Gao, Eshaan Nichani, Alex Damian, Jason D. Lee, Danqi Chen, Sanjeev Arora

Fine-tuning language models (LMs) has yielded success on diverse downstream tasks, but as LMs grow in size, backpropagation requires a prohibitively large amount of memory. Zeroth-order (ZO) methods can in principle estimate gradients using only two forward passes but are theorized to be catastrophically slow for optimizing large models. In this work, we propose a memory-efficient zeroth-order optimizer (MeZO), adapting the classical ZO-SGD method to operate in-place, thereby fine-tuning LMs with the same memory footprint as inference. For example, with a single A100 80GB GPU, MeZO can train a 30-billion parameter model, whereas fine-tuning with backpropagation can train only a 2.7B LM with the same budget. We conduct comprehensive experiments across model types (masked and autoregressive LMs), model scales (up to 66B), and downstream tasks (classification, multiple-choice, and generation). Our results demonstrate that (1) MeZO significantly outperforms in-context learning and linear probing; (2) MeZO achieves comparable performance to fine-tuning with backpropagation across multiple tasks, with up to 12x memory reduction; (3) MeZO is compatible with both full-parameter and parameter-efficient tuning techniques such as LoRA and prefix tuning; (4) MeZO can effectively optimize non-differentiable objectives (e.g., maximizing accuracy or F1). We support our empirical findings with theoretical insights, highlighting how adequate pre-training and task prompts enable MeZO to fine-tune huge models, despite classical ZO analyses suggesting otherwise.
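The idea of a two-forward-pass, in-place zeroth-order update can be sketched on a toy parameter list. This is a minimal illustration, not the MeZO implementation: it assumes the random perturbation is regenerated from a seed each time it is needed, so no copy of the perturbation or of the parameters is stored. The names `perturb_in_place` and `zo_sgd_step` are hypothetical.

```python
import random


def perturb_in_place(params, seed, scale):
    """Add scale * z to params in place, where z is Gaussian noise regenerated from seed."""
    rng = random.Random(seed)
    for i in range(len(params)):
        params[i] += scale * rng.gauss(0.0, 1.0)


def zo_sgd_step(params, loss_fn, lr, eps, seed):
    """One zeroth-order SGD step using two loss evaluations and no gradient storage.

    Returns the scalar projected-gradient estimate (L(theta+eps*z) - L(theta-eps*z)) / (2*eps).
    """
    perturb_in_place(params, seed, eps)          # theta + eps * z
    loss_plus = loss_fn(params)
    perturb_in_place(params, seed, -2.0 * eps)   # theta - eps * z
    loss_minus = loss_fn(params)
    perturb_in_place(params, seed, eps)          # back to theta (up to rounding)
    projected_grad = (loss_plus - loss_minus) / (2.0 * eps)
    perturb_in_place(params, seed, -lr * projected_grad)  # theta -= lr * g * z
    return projected_grad
```

A usage example on a toy quadratic: starting from `params = [1.0, -2.0, 3.0]` with `loss_fn = lambda p: sum(x * x for x in p)`, repeatedly calling `zo_sgd_step(params, loss_fn, lr=0.01, eps=1e-3, seed=step)` with a fresh seed per step drives the loss toward zero. Only the seed and one scalar per step are kept, which is what keeps the memory footprint close to that of a forward pass in this sketch.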

* Code available at 
Provable Offline Reinforcement Learning with Human Feedback

May 24, 2023
Wenhao Zhan, Masatoshi Uehara, Nathan Kallus, Jason D. Lee, Wen Sun

In this paper, we investigate the problem of offline reinforcement learning with human feedback where feedback is available in the form of preference between trajectory pairs rather than explicit rewards. Our proposed algorithm consists of two main steps: (1) estimate the implicit reward using Maximum Likelihood Estimation (MLE) with general function approximation from offline data and (2) solve a distributionally robust planning problem over a confidence set around the MLE. We consider the general reward setting where the reward can be defined over the whole trajectory and provide a novel guarantee that allows us to learn any target policy with a polynomial number of samples, as long as the target policy is covered by the offline data. This guarantee is the first of its kind with general function approximation. To measure the coverage of the target policy, we introduce a new single-policy concentrability coefficient, which can be upper bounded by the per-trajectory concentrability coefficient. We also establish lower bounds that highlight the necessity of such concentrability and the difference from standard RL, where state-action-wise rewards are directly observed. We further extend and analyze our algorithm when the feedback is given over action pairs.
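Step (1), estimating a reward from trajectory preferences by maximum likelihood, can be illustrated with a small sketch. The abstract allows general function approximation; the code below instead assumes a Bradley-Terry preference model with a linear reward over hypothetical trajectory features, purely for illustration, and it does not cover the robust planning step (2).

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def fit_reward_mle(pairs, dim, lr=0.5, steps=500):
    """Fit theta by gradient ascent on the preference log-likelihood.

    pairs: list of (phi0, phi1, y), where phi0 and phi1 are feature vectors of
    two trajectories and y = 1 if trajectory 1 was preferred, else 0.
    Assumed model: P(y = 1) = sigmoid(theta . (phi1 - phi0)).
    """
    theta = [0.0] * dim
    for _ in range(steps):
        grad = [0.0] * dim
        for phi0, phi1, y in pairs:
            diff = [b - a for a, b in zip(phi0, phi1)]
            p = sigmoid(sum(t * d for t, d in zip(theta, diff)))
            for k in range(dim):
                grad[k] += (y - p) * diff[k]
        # Average-gradient ascent step on the log-likelihood.
        theta = [t + lr * g / len(pairs) for t, g in zip(theta, grad)]
    return theta
```

With offline pairs in which the preferred trajectory always has the larger first feature, the fitted `theta` assigns that feature a higher weight than the other, so the estimated reward ranks trajectories consistently with the observed preferences.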
