Martin Jaggi

MEDITRON-70B: Scaling Medical Pretraining for Large Language Models

Nov 27, 2023
Zeming Chen, Alejandro Hernández Cano, Angelika Romanou, Antoine Bonnet, Kyle Matoba, Francesco Salvi, Matteo Pagliardini, Simin Fan, Andreas Köpf, Amirkeivan Mohtashami, Alexandre Sallinen, Alireza Sakhaeirad, Vinitra Swamy, Igor Krawczuk, Deniz Bayazit, Axel Marmet, Syrielle Montariol, Mary-Anne Hartley, Martin Jaggi, Antoine Bosselut

Large language models (LLMs) can potentially democratize access to medical knowledge. While many efforts have been made to harness and improve LLMs' medical knowledge and reasoning capacities, the resulting models are either closed-source (e.g., PaLM, GPT-4) or limited in scale (<= 13B parameters), which restricts their abilities. In this work, we improve access to large-scale medical LLMs by releasing MEDITRON: a suite of open-source LLMs with 7B and 70B parameters adapted to the medical domain. MEDITRON builds on Llama-2 (through our adaptation of Nvidia's Megatron-LM distributed trainer), and extends pretraining on a comprehensively curated medical corpus, including selected PubMed articles, abstracts, and internationally-recognized medical guidelines. Evaluations using four major medical benchmarks show significant performance gains over several state-of-the-art baselines before and after task-specific finetuning. Overall, MEDITRON achieves a 6% absolute performance gain over the best public baseline in its parameter class and 3% over the strongest baseline we finetuned from Llama-2. Compared to closed-source LLMs, MEDITRON-70B outperforms GPT-3.5 and Med-PaLM and is within 5% of GPT-4 and 10% of Med-PaLM-2. We release our code for curating the medical pretraining corpus and the MEDITRON model weights to drive open-source development of more capable medical LLMs.

Controllable Topic-Focused Abstractive Summarization

Nov 12, 2023
Seyed Ali Bahrainian, Martin Jaggi, Carsten Eickhoff

Controlled abstractive summarization focuses on producing condensed versions of a source article to cover specific aspects by shifting the distribution of generated text towards a desired style, e.g., a set of topics. Subsequently, the resulting summaries may be tailored to user-defined requirements. This paper presents a new Transformer-based architecture capable of producing topic-focused summaries. The architecture modifies the cross-attention mechanism of the Transformer to bring topic-focus control to the generation process while not adding any further parameters to the model. We show that our model sets a new state of the art on the NEWTS dataset in terms of topic-focused abstractive summarization as well as a topic-prevalence score. Moreover, we show via extensive experiments that our proposed topical cross-attention mechanism can be plugged into various Transformer models, such as BART and T5, improving their performance on the CNN/Dailymail and XSum benchmark datasets for abstractive summarization. This is achieved via fine-tuning, without requiring training from scratch. Finally, we show through human evaluation that our model generates more faithful summaries outperforming the state-of-the-art Frost model.

DoGE: Domain Reweighting with Generalization Estimation

Oct 23, 2023
Simin Fan, Matteo Pagliardini, Martin Jaggi

The coverage and composition of the pretraining data corpus significantly impact the generalization ability of large language models. Conventionally, the pretraining corpus is composed of various source domains (e.g., CommonCrawl, Wikipedia, GitHub) according to certain sampling probabilities (domain weights). However, current methods lack a principled way to optimize domain weights for the ultimate goal of generalization. We propose DOmain reweighting with Generalization Estimation (DoGE), where we reweight the sampling probability from each domain based on its contribution to the final generalization objective, assessed by a gradient-based generalization estimation function. First, we train a small-scale proxy model with a min-max optimization to obtain the reweighted domain weights. At each step, the domain weights are updated to maximize the overall generalization gain by mirror descent. Finally, we use the obtained domain weights to train a larger-scale, full-size language model. On the SlimPajama-6B dataset, with a universal generalization objective, DoGE achieves better average perplexity and zero-shot reasoning accuracy. On out-of-domain generalization tasks, DoGE reduces perplexity on the target domain by a large margin. We further apply a parameter-selection scheme which improves the efficiency of generalization estimation.
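
As an illustration of the mirror-descent update on domain weights mentioned in the abstract, the sketch below performs one exponentiated-gradient step, which is the form mirror descent takes on the probability simplex with an entropy mirror map. It is not the DoGE implementation: the per-domain scores, the step size, and the function name `mirror_ascent_step` are hypothetical stand-ins, and the abstract does not specify how the generalization estimate is computed.

```python
import math


def mirror_ascent_step(weights, scores, step_size=1.0):
    """One exponentiated-gradient step on the probability simplex.

    weights: current domain sampling probabilities (non-negative, sum to 1).
    scores:  hypothetical per-domain generalization estimates; a higher score
             means the domain should be sampled more often.
    """
    # Multiplicative update: domains with higher scores gain probability mass.
    unnormalized = [w * math.exp(step_size * s) for w, s in zip(weights, scores)]
    total = sum(unnormalized)
    # Renormalize so the result is again a probability distribution.
    return [u / total for u in unnormalized]


# Three hypothetical domains, starting from uniform sampling.
weights = [1 / 3, 1 / 3, 1 / 3]
scores = [0.5, 0.0, -0.5]  # assumed values, for illustration only
new_weights = mirror_ascent_step(weights, scores)
```

The multiplicative form keeps every weight strictly positive and on the simplex without an explicit projection step.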

Irreducible Curriculum for Language Model Pretraining

Oct 23, 2023
Simin Fan, Martin Jaggi

Automatic data selection and curriculum design for training large language models is challenging, with only a few existing methods showing improvements over standard training. Furthermore, current schemes focus on domain-level selection, overlooking the more fine-grained contributions of each individual training point. It is difficult to apply traditional datapoint selection methods to large language models: most online batch selection methods perform two forward or backward passes, which introduces considerable extra cost with large-scale models. To mitigate these obstacles, we propose irreducible curriculum as a curriculum learning algorithm for language model pretraining, which prioritizes samples with higher learnability. Specifically, to avoid prohibitive extra computation overhead, we simulate the sample loss along the main model's training trajectory using a small-scale proxy model. Our experiments on the RedPajama-1B dataset demonstrate a consistent improvement in validation perplexity across all 7 domains compared to a random uniform baseline and the anti-curriculum strategy. Our method also reduces the sharpness of the network and achieves better 5-shot accuracy on MMLU benchmarks.

LASER: Linear Compression in Wireless Distributed Optimization

Oct 19, 2023
Ashok Vardhan Makkuva, Marco Bondaschi, Thijs Vogels, Martin Jaggi, Hyeji Kim, Michael C. Gastpar

Data-parallel SGD is the de facto algorithm for distributed optimization, especially for large-scale machine learning. Despite its merits, the communication bottleneck is one of its persistent issues. Most compression schemes to alleviate this either assume noiseless communication links or fail to achieve good performance on practical tasks. In this paper, we close this gap and introduce LASER: LineAr CompreSsion in WirEless DistRibuted Optimization. LASER capitalizes on the inherent low-rank structure of gradients and transmits them efficiently over noisy channels. Whilst enjoying theoretical guarantees similar to those of classical SGD, LASER shows consistent gains over baselines on a variety of practical benchmarks. In particular, it outperforms the state-of-the-art compression schemes on challenging computer vision and GPT language modeling tasks. On the latter, we obtain a 50-64% improvement in perplexity over our baselines for noisy channels.
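
To make the idea of exploiting low-rank gradient structure concrete, the sketch below compresses a matrix into two vectors with a rank-one power iteration and reconstructs it from them. This is a generic illustration, not the LASER algorithm: the abstract does not describe the compression procedure, channel noise is not modeled here, and the names `rank_one_compress` and `decompress` are hypothetical.

```python
def rank_one_compress(matrix, iterations=20):
    """Approximate `matrix` (list of rows) as an outer product p q^T.

    Only the two vectors p and q would need to be transmitted:
    rows + cols numbers instead of rows * cols.
    """
    rows, cols = len(matrix), len(matrix[0])
    p = [0.0] * rows
    q = [1.0] * cols  # arbitrary non-zero starting vector
    for _ in range(iterations):
        # p <- M q, then normalize to unit length.
        p = [sum(matrix[i][j] * q[j] for j in range(cols)) for i in range(rows)]
        norm = sum(x * x for x in p) ** 0.5
        if norm == 0.0:
            # M q vanished (e.g. zero matrix): fall back to a zero approximation.
            return [0.0] * rows, [0.0] * cols
        p = [x / norm for x in p]
        # q <- M^T p, so that p q^T is the projection of M onto direction p.
        q = [sum(matrix[i][j] * p[i] for i in range(rows)) for j in range(cols)]
    return p, q


def decompress(p, q):
    """Rebuild the rank-one approximation p q^T."""
    return [[pi * qj for qj in q] for pi in p]


# A toy "gradient" that is exactly rank one: outer([1, 2, 3], [4, 5]).
gradient = [[4.0, 5.0], [8.0, 10.0], [12.0, 15.0]]
p, q = rank_one_compress(gradient)
approx = decompress(p, q)
```

For this exactly rank-one input the reconstruction matches the original up to floating-point error while sending 5 numbers instead of 6; the saving grows with matrix size.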

CoTFormer: More Tokens With Attention Make Up For Less Depth

Oct 16, 2023
Amirkeivan Mohtashami, Matteo Pagliardini, Martin Jaggi

The race to continually develop ever larger and deeper foundational models is underway. However, techniques like the Chain-of-Thought (CoT) method continue to play a pivotal role in achieving optimal downstream performance. In this work, we establish an approximate parallel between using chain-of-thought and employing a deeper transformer. Building on this insight, we introduce CoTFormer, a transformer variant that employs an implicit CoT-like mechanism to achieve capacity comparable to a deeper model. Our empirical findings demonstrate the effectiveness of CoTFormers, as they significantly outperform larger standard transformers.

MultiModN- Multimodal, Multi-Task, Interpretable Modular Networks

Sep 25, 2023
Vinitra Swamy, Malika Satayeva, Jibril Frej, Thierry Bossy, Thijs Vogels, Martin Jaggi, Tanja Käser, Mary-Anne Hartley

Predicting multiple real-world tasks in a single model often requires a particularly diverse feature space. Multimodal (MM) models aim to extract the synergistic predictive potential of multiple data types to create a shared feature space with aligned semantic meaning across inputs of drastically varying sizes (i.e. images, text, sound). Most current MM architectures fuse these representations in parallel, which not only limits their interpretability but also creates a dependency on modality availability. We present MultiModN, a multimodal, modular network that fuses latent representations in a sequence of any number, combination, or type of modality while providing granular real-time predictive feedback on any number or combination of predictive tasks. MultiModN's composable pipeline is interpretable-by-design, as well as innately multi-task and robust to the fundamental issue of biased missingness. We perform four experiments on several benchmark MM datasets across 10 real-world tasks (predicting medical diagnoses, academic performance, and weather), and show that MultiModN's sequential MM fusion does not compromise performance compared with a baseline of parallel fusion. By simulating the challenging bias of missing not-at-random (MNAR), this work shows that, contrary to MultiModN, parallel fusion baselines erroneously learn MNAR and suffer catastrophic failure when faced with different patterns of MNAR at inference. To the best of our knowledge, this is the first inherently MNAR-resistant approach to MM modeling. In conclusion, MultiModN provides granular insights, robustness, and flexibility without compromising performance.

* Accepted as a full paper at NeurIPS 2023 in New Orleans, USA 
Layerwise Linear Mode Connectivity

Jul 13, 2023
Linara Adilova, Asja Fischer, Martin Jaggi

In the federated setup, one aggregates separate local models multiple times during training in order to obtain a stronger global model; most often, aggregation is a simple averaging of the parameters. Understanding when and why averaging works in a non-convex setup, such as federated deep learning, is an open challenge that hinders obtaining highly performant global models. On i.i.d. datasets, federated deep learning with frequent averaging is successful. The common understanding, however, is that during independent training the models drift away from each other, and thus averaging may no longer work after many local parameter updates. The problem can be seen from the perspective of the loss surface: for points on a non-convex surface, the average can become arbitrarily bad. The assumption of local convexity, often used to explain the success of federated averaging, contradicts the empirical evidence showing that high loss barriers exist between models from the very beginning of learning, even when training on the same data. Based on the observation that the learning process evolves differently in different layers, we investigate the barrier between models in a layerwise fashion. Our conjecture is that the barriers preventing successful federated training are caused by a particular layer or group of layers.

* HLD 2023: 1st Workshop on High-dimensional Learning Dynamics, ICML 2023, Hawaii, USA 
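
The remark in the abstract above that averaging points on a non-convex surface can give a bad result can be seen in a one-dimensional toy case. The sketch below is not taken from the paper: the loss function and the two "local models" are assumptions chosen so that each model sits in a different minimum.

```python
def loss(x):
    """Toy non-convex loss with two global minima, at x = -1 and x = +1."""
    return (x * x - 1.0) ** 2


def average_parameters(models):
    """Plain parameter averaging, as in simple federated aggregation."""
    return sum(models) / len(models)


# Two hypothetical local models, each at a different minimum of the loss.
local_models = [-1.0, 1.0]
global_model = average_parameters(local_models)

local_losses = [loss(x) for x in local_models]
global_loss = loss(global_model)
```

Both local models have zero loss, yet their average lands on the ridge between the two minima, where the loss is 1. Scaling the toy loss by any constant makes the gap as large as desired.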
Shuffle SGD is Always Better than SGD: Improved Analysis of SGD with Arbitrary Data Orders

Jun 15, 2023
Anastasia Koloskova, Nikita Doikov, Sebastian U. Stich, Martin Jaggi

Stochastic Gradient Descent (SGD) algorithms are widely used in optimizing neural networks, with Random Reshuffling (RR) and Single Shuffle (SS) being popular choices for cycling through random or single permutations of the training data. However, the convergence properties of these algorithms in the non-convex case are not fully understood. Existing results suggest that, in realistic training scenarios where the number of epochs is smaller than the training set size, RR may perform worse than SGD. In this paper, we analyze a general SGD algorithm that allows for arbitrary data orderings and show improved convergence rates for non-convex functions. Specifically, our analysis reveals that SGD with random and single shuffling is always faster or at least as good as classical SGD with replacement, regardless of the number of iterations. Overall, our study highlights the benefits of using SGD with random/single shuffling and provides new insights into its convergence properties for non-convex optimization.
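
The three data orderings compared in the abstract differ only in how sample indices are drawn in each epoch. The sketch below generates the index sequences for each scheme; the function names and the use of Python's `random.Random` are illustrative choices, not code from the paper.

```python
import random


def with_replacement(n, epochs, rng):
    """Classical SGD: every index is drawn independently and uniformly."""
    return [[rng.randrange(n) for _ in range(n)] for _ in range(epochs)]


def random_reshuffling(n, epochs, rng):
    """RR: a fresh random permutation of the data in every epoch."""
    orders = []
    for _ in range(epochs):
        perm = list(range(n))
        rng.shuffle(perm)
        orders.append(perm)
    return orders


def single_shuffle(n, epochs, rng):
    """SS: one random permutation, reused unchanged in every epoch."""
    perm = list(range(n))
    rng.shuffle(perm)
    return [list(perm) for _ in range(epochs)]


rng = random.Random(0)
n, epochs = 8, 3
sgd_orders = with_replacement(n, epochs, rng)
rr_orders = random_reshuffling(n, epochs, rng)
ss_orders = single_shuffle(n, epochs, rng)
```

Under RR and SS every sample is visited exactly once per epoch, whereas sampling with replacement may repeat some samples and miss others within the same epoch.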

Provably Personalized and Robust Federated Learning

Jun 14, 2023
Mariel Werner, Lie He, Sai Praneeth Karimireddy, Michael Jordan, Martin Jaggi

Clustering clients with similar objectives and learning a model per cluster is an intuitive and interpretable approach to personalization in federated learning. However, doing so with provable and optimal guarantees has remained an open challenge. In this work, we formalize personalized federated learning as a stochastic optimization problem where the stochastic gradients on a client may correspond to one of $K$ distributions. In such a setting, we show that using i) a simple thresholding-based clustering algorithm, and ii) local client gradients obtains optimal convergence guarantees. In fact, our rates asymptotically match those obtained if we knew the true underlying clustering of the clients. Furthermore, our algorithms are provably robust in the Byzantine setting where some fraction of the gradients are corrupted.
