Abstract: Weak-to-strong generalization (W2SG) refers to the phenomenon where a strong student model, trained on a dataset labeled by a weak teacher, ultimately outperforms the teacher on the target task. Recent studies attribute this performance gain to the prediction misfit between the student and teacher models. In this work, we theoretically investigate the emergence of W2SG through a generalized bias-variance decomposition of Bregman divergence. Specifically, we show that the expected population risk gap between the student and teacher is quantified by the expected misfit between the two models. While this aligns with previous results, our analysis removes several restrictive assumptions required in earlier works, most notably the convexity of the student's hypothesis class. Moreover, we show that W2SG is more likely to emerge when the student model approximates its posterior mean teacher, rather than mimicking an individual teacher. Using a concrete example, we demonstrate that if the student model has significantly larger capacity than the teacher, it can indeed converge to this posterior mean. Our analysis also suggests that avoiding overfitting to the teacher's supervision and reducing the entropy of the student's predictions further facilitate W2SG. In addition, we show that the reverse cross-entropy loss, unlike the standard forward cross-entropy, is less sensitive to the predictive uncertainty of the teacher. Finally, we empirically verify our theoretical insights and demonstrate that incorporating the reverse cross-entropy loss consistently improves student performance.
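The contrast between forward and reverse cross-entropy in the abstract above can be illustrated with a small sketch. It rests on assumptions the abstract does not spell out: forward cross-entropy is taken to use the teacher's soft label as the target, reverse cross-entropy is taken to swap the two roles, and the teacher label `(0.6, 0.4)` and the search grid are purely illustrative. The sketch finds which binary student prediction minimizes each loss for one uncertain teacher label.

```python
import math

def forward_ce(teacher, student):
    # Assumed convention: the teacher's soft label is the target distribution.
    return -sum(t * math.log(s) for t, s in zip(teacher, student))

def reverse_ce(teacher, student):
    # Assumed convention: roles swapped, the student's prediction is the target.
    return -sum(s * math.log(t) for t, s in zip(teacher, student))

# Illustrative uncertain teacher label for a binary task (hypothetical values).
teacher = (0.6, 0.4)

# Candidate student probabilities for class 0, kept away from 0 and 1
# so that every logarithm is finite.
grid = [i / 100 for i in range(1, 100)]

# Student prediction that minimizes each loss on this single example.
best_forward = min(grid, key=lambda p: forward_ce(teacher, (p, 1 - p)))
best_reverse = min(grid, key=lambda p: reverse_ce(teacher, (p, 1 - p)))
```

Under these assumptions, the forward loss is minimized when the student copies the teacher's uncertainty (`best_forward` is 0.6), whereas the reverse loss is linear in the student's prediction and is minimized at the most confident grid point on the teacher's preferred class (`best_reverse` is 0.99). This is only a single-example intuition for the reduced sensitivity the abstract describes, not the paper's analysis.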
Abstract: Large Language Models (LLMs), built on Transformer architectures, exhibit remarkable generalization across a wide range of tasks. However, fine-tuning these models for specific tasks remains resource-intensive due to their extensive parameterization. In this paper, we investigate two notable phenomena observed during the fine-tuning of LLMs, focusing on the attention mechanism: (1) Different Impact: optimizing the $\mathbf{W}_v$ matrix significantly improves performance over optimizing the $\mathbf{W}_k$ matrix. Fine-tuning only the $\mathbf{W}_q$ and $\mathbf{W}_v$ matrices is computationally efficient, delivering results that are comparable to, or even better than, fine-tuning all three matrices $\mathbf{W}_q$, $\mathbf{W}_k$, and $\mathbf{W}_v$. (2) Efficient Convergence: employing distinct learning rates for these matrices is crucial for optimal performance, with a higher learning rate for the $\mathbf{W}_v$ matrix expediting convergence. However, theoretical analyses of these phenomena are still relatively limited. We present a theoretical analysis of these phenomena from two perspectives: (i) Generalization, where we demonstrate that fine-tuning only $\mathbf{W}_q$ and $\mathbf{W}_v$ improves generalization bounds and enhances memory efficiency; and (ii) Optimization, where we emphasize that the feature learning of the attention mechanism is efficient, particularly when using distinct learning rates for the matrices, which leads to more effective fine-tuning. Building on these insights, we propose a new strategy that improves fine-tuning efficiency in terms of both storage and time. Experimental results on benchmark datasets validate the effectiveness of this approach, supporting our theoretical findings. Our analysis lays the theoretical groundwork for configuring and improving lightweight algorithms in LLM fine-tuning.
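The two observations in the abstract above (train only $\mathbf{W}_q$ and $\mathbf{W}_v$, and give $\mathbf{W}_v$ a higher learning rate) can be sketched as a parameter-grouping step. This is not the authors' proposed strategy, only a minimal illustration under stated assumptions: the name fragments `q_proj` and `v_proj`, the helper `build_attention_param_groups`, the base learning rate, and the ratio `v_lr_ratio` are all hypothetical.

```python
def build_attention_param_groups(named_params, base_lr=1e-4, v_lr_ratio=4.0):
    """Split parameters into a W_q group, a W_v group, and a frozen remainder.

    named_params: iterable of (name, parameter) pairs.
    The name fragments, base_lr, and v_lr_ratio are hypothetical choices.
    """
    q_params, v_params, frozen = [], [], []
    for name, param in named_params:
        if "q_proj" in name:
            q_params.append(param)
        elif "v_proj" in name:
            v_params.append(param)
        else:
            # Everything else, including W_k, is left untrained in this sketch.
            frozen.append(param)
    groups = [
        {"params": q_params, "lr": base_lr},
        # W_v receives a larger learning rate than W_q.
        {"params": v_params, "lr": base_lr * v_lr_ratio},
    ]
    return groups, frozen


# Stand-in for a model's named parameters; strings replace real tensors here.
toy_params = [
    ("layers.0.attn.q_proj.weight", "Wq0"),
    ("layers.0.attn.k_proj.weight", "Wk0"),
    ("layers.0.attn.v_proj.weight", "Wv0"),
]
groups, frozen = build_attention_param_groups(toy_params)
```

The list of dictionaries with `params` and `lr` keys follows the per-parameter options format that PyTorch optimizers accept, so with real tensors one could pass `groups` to an optimizer such as `torch.optim.AdamW` and set `requires_grad = False` on everything in `frozen`. Which ratio works well in practice is an empirical question that this sketch does not answer.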