University of Edinburgh
Abstract:State-space language models such as Mamba and gated linear attention (GLA) offer efficient alternatives to transformers due to their linear complexity and parallel training, but often lack the expressivity and robust state-tracking needed for complex reasoning. We address these limitations by reframing sequence modelling through a probabilistic lens, using Bayesian filters as a core primitive. While classical filters such as Kalman filters provide principled state estimation and uncertainty tracking, they are typically viewed as inherently sequential. We show that reparameterising the Kalman filter in information form enables its updates to be computed via an associative scan, allowing efficient parallel training. Building on this insight, we introduce the Kalman Linear Attention (KLA) layer, a neural sequence-modelling primitive that performs time-parallel probabilistic inference while maintaining explicit belief-state uncertainty. KLA offers strictly more expressive nonlinear updates and gating than GLA variants while retaining their computational advantages. On language modelling tasks, KLA matches or outperforms modern SSMs and GLAs across representative discrete token-manipulation and state-tracking benchmarks.
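As a minimal sketch of why the information form helps, consider the simplest special case: a static scalar state with no dynamics and no process noise. In information form each observation simply adds to the precision and to the information vector, so the filter reduces to a prefix sum, and addition is associative. The function names, the scalar setting, and the static-state assumption are illustrative choices here, not the KLA layer described in the abstract.

```python
from itertools import accumulate


def kalman_sequential(ys, hs, rs, m0, p0):
    """Covariance-form Kalman filter for a static scalar state (assumed: no dynamics)."""
    m, p, means = m0, p0, []
    for y, h, r in zip(ys, hs, rs):
        k = p * h / (h * h * p + r)   # Kalman gain
        m = m + k * (y - h * m)       # posterior mean
        p = (1.0 - k * h) * p         # posterior variance
        means.append(m)
    return means


def kalman_information_scan(ys, hs, rs, m0, p0):
    """Same filter in information form: precision and information vector are prefix sums."""
    d_lam = [h * h / r for h, r in zip(hs, rs)]          # precision added per step
    d_eta = [h * y / r for y, h, r in zip(ys, hs, rs)]   # information added per step
    # accumulate() runs sequentially here; because addition is associative,
    # the same prefix sums could be computed by a parallel scan instead.
    lam = list(accumulate(d_lam, initial=1.0 / p0))[1:]
    eta = list(accumulate(d_eta, initial=m0 / p0))[1:]
    return [e / l for e, l in zip(eta, lam)]


ys = [1.2, 0.8, 1.1, 0.9]
hs = [1.0, 0.5, 2.0, 1.0]
rs = [0.5, 1.0, 0.2, 0.4]
seq = kalman_sequential(ys, hs, rs, m0=0.0, p0=10.0)
scan = kalman_information_scan(ys, hs, rs, m0=0.0, p0=10.0)
```

Both routines should return the same posterior means up to floating-point error; the general case with dynamics needs a richer associative operator than plain addition.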
Abstract:Probabilistic forecasting is increasingly critical across high-stakes domains, from finance and epidemiology to climate science. However, current evaluation frameworks lack a consensus metric and suffer from two critical flaws: they often assume independence across time steps or variables, and they demonstrably lack sensitivity to tail events, the very occurrences that are most pivotal in real-world decision-making. To address these limitations, we propose two kernel-based metrics: the signature maximum mean discrepancy (Sig-MMD) and our novel censored Sig-MMD (CSig-MMD). By leveraging the signature kernel, these metrics capture complex inter-variate and inter-temporal dependencies and remain robust to missing data. Furthermore, CSig-MMD introduces a censoring scheme that prioritizes a forecaster's capability to predict tail events while strictly maintaining properness, a vital property for a good scoring rule. These metrics enable a more reliable evaluation of direct multi-step forecasting, facilitating the development of more robust probabilistic algorithms.
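To make the MMD part of these metrics concrete, the sketch below computes a biased (V-statistic) estimate of the squared maximum mean discrepancy between two sets of univariate paths. It swaps in a Gaussian (RBF) kernel on the raw path values as a stand-in for the signature kernel, and it does not implement the censoring scheme; `rbf_kernel`, `mmd2_biased`, and the bandwidth are assumptions made for illustration.

```python
import math


def rbf_kernel(path_a, path_b, bandwidth=1.0):
    """Gaussian kernel on equal-length paths (a stand-in for the signature kernel)."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(path_a, path_b))
    return math.exp(-sq_dist / (2.0 * bandwidth ** 2))


def mmd2_biased(xs, ys, kernel=rbf_kernel):
    """Biased estimate of squared MMD between two samples of paths."""
    k_xx = sum(kernel(a, b) for a in xs for b in xs) / len(xs) ** 2
    k_yy = sum(kernel(a, b) for a in ys for b in ys) / len(ys) ** 2
    k_xy = sum(kernel(a, b) for a in xs for b in ys) / (len(xs) * len(ys))
    return k_xx + k_yy - 2.0 * k_xy


forecast_paths = [[0.0, 0.1, 0.2], [0.0, 0.2, 0.4]]
observed_paths = [[1.0, 1.5, 2.0], [1.0, 2.0, 3.0]]
same = mmd2_biased(forecast_paths, forecast_paths)
different = mmd2_biased(forecast_paths, observed_paths)
```

The estimate is zero when the two samples coincide and grows as the sampled paths move apart under the chosen kernel.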
Abstract:This paper proposes a suite of measures, with associated theory, for the rationality of reinforcement learning agents, a property that is increasingly critical yet rarely explored. We define an action in deployment to be perfectly rational if it maximises the hidden true value function in the steepest direction. The expected value discrepancy of a policy's actions against their rational counterparts, accumulated over the trajectory in deployment, is defined as the expected rational risk; an empirical average version in training is also defined. Their difference, termed the rational risk gap, is decomposed into (1) an extrinsic component caused by environment shifts between training and deployment, and (2) an intrinsic one due to the algorithm's generalisability in a dynamic environment. These components are upper-bounded by, respectively, (1) the $1$-Wasserstein distance between transition kernels and initial state distributions in training and deployment, and (2) the empirical Rademacher complexity of the value function class. Our theory suggests hypotheses on the benefits of regularisers (including layer normalisation, $\ell_2$ regularisation, and weight normalisation) and domain randomisation, as well as the harm from environment shifts. Experiments are in full agreement with these hypotheses. The code is available at https://github.com/EVIEHub/Rationality.
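For intuition about the $1$-Wasserstein distance in the extrinsic bound, the sketch below computes it for the easiest case: two one-dimensional empirical samples of equal size, where the distance is the mean absolute difference between sorted values. The function name and the toy "training" and "deployment" samples are assumptions for illustration; this is not the bound or the code from the paper's repository.

```python
def wasserstein1_1d(xs, ys):
    """1-Wasserstein distance between two equal-size 1-D empirical samples."""
    if len(xs) != len(ys):
        raise ValueError("this sketch assumes equal sample sizes")
    # In one dimension the optimal coupling matches sorted values in order.
    return sum(abs(a - b) for a, b in zip(sorted(xs), sorted(ys))) / len(xs)


train_next_states = [0.0, 1.0, 2.0]    # hypothetical samples under training dynamics
deploy_next_states = [3.0, 1.0, 2.0]   # hypothetical samples under shifted dynamics
shift = wasserstein1_1d(train_next_states, deploy_next_states)
```

A larger value indicates a larger shift between the two sampled distributions.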
Abstract:The ecosystem behind foundation model development today is highly centralized and limited to large-scale cloud data center operators: training foundation models is costly, needing immense compute resources. Decentralized foundation model training across edge devices, leveraging their spare compute, promises a democratized alternative. However, existing edge-training approaches fall short: they struggle to match cloud-based training performance, exhibit limited scalability with model size, exceed device memory capacity, and have prohibitive communication overhead. They also fail to satisfactorily handle device heterogeneity and dynamism. We introduce a new paradigm, Cleave, which finely partitions training operations through a novel selective hybrid tensor parallelism method. Together with a parameter-server-centric training framework, Cleave copes with device memory limits and avoids communication bottlenecks, thereby enabling efficient training of large models on par with the cloud. Further, with a cost optimization model to guide device selection and training workload distribution, Cleave effectively accounts for device heterogeneity and churn. Our evaluations show that Cleave matches cloud-based GPU training by scaling efficiently to larger models and thousands of devices, supporting up to 8x more devices than baseline edge-training approaches. It outperforms state-of-the-art edge-training methods by up to a factor of 10 in per-batch training time and efficiently handles device failures, achieving at least 100x faster recovery than prior methods.
Abstract:A fundamental challenge in developing general learning algorithms is their tendency to forget past knowledge when adapting to new data. Addressing this problem requires a principled understanding of forgetting; yet, despite decades of study, no unified definition has emerged that provides insights into the underlying dynamics of learning. We propose an algorithm- and task-agnostic theory that characterises forgetting as a lack of self-consistency in a learner's predictive distribution over future experiences, manifesting as a loss of predictive information. Our theory naturally yields a general measure of an algorithm's propensity to forget. To validate the theory, we design a comprehensive set of experiments that span classification, regression, generative modelling, and reinforcement learning. We empirically demonstrate how forgetting is present across all learning settings and plays a significant role in determining learning efficiency. Together, these results establish a principled understanding of forgetting and lay the foundation for analysing and improving the information retention capabilities of general learning algorithms.
Abstract:World models are a powerful paradigm in AI and robotics, enabling agents to reason about the future by predicting visual observations or compact latent states. The 1X World Model Challenge introduces an open-source benchmark of real-world humanoid interaction, with two complementary tracks: sampling, focused on forecasting future image frames, and compression, focused on predicting future discrete latent codes. For the sampling track, we adapt the video generation foundation model Wan-2.2 TI2V-5B to video-state-conditioned future frame prediction. We condition the video generation on robot states using AdaLN-Zero, and further post-train the model using LoRA. For the compression track, we train a Spatio-Temporal Transformer model from scratch. Our models achieve 23.0 dB PSNR in the sampling task and a Top-500 CE of 6.6386 in the compression task, securing 1st place in both challenges.
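The sampling-track score is reported as PSNR, which follows the standard definition $10 \log_{10}(\mathrm{MAX}^2 / \mathrm{MSE})$. The sketch below computes it for two frames given as flat lists of pixel values; the function name, the flat-list representation, and the default peak value of 255 are assumptions for illustration and may differ from the challenge's own evaluation code.

```python
import math


def psnr(frame_a, frame_b, max_value=255.0):
    """Peak signal-to-noise ratio in dB between two equal-size frames."""
    if len(frame_a) != len(frame_b):
        raise ValueError("frames must have the same number of pixels")
    mse = sum((a - b) ** 2 for a, b in zip(frame_a, frame_b)) / len(frame_a)
    if mse == 0:
        return math.inf  # identical frames
    return 10.0 * math.log10(max_value ** 2 / mse)


predicted = [0.0, 255.0]
target = [0.0, 0.0]
score = psnr(predicted, target)  # MSE is 255^2 / 2, so the ratio inside the log is 2
```

Higher values mean the predicted frame is closer to the target in mean squared error.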
Abstract:Large Multimodal Models (LMMs) often rely on in-context learning (ICL) to perform new tasks with minimal supervision. However, ICL performance, especially in smaller LMMs, is inconsistent and does not always improve monotonically with increasing examples. We hypothesize that this occurs due to the LMM being overwhelmed by additional information present in the image embeddings, which is not required for the downstream task. To address this, we propose a meta-learning approach that provides an alternative for inducing few-shot capabilities in LMMs, using a fixed set of soft prompts that are distilled from task-relevant image features and can be adapted at test time using a few examples. To facilitate this distillation, we introduce an attention-mapper module that can be easily integrated with the popular LLaVA v1.5 architecture and is jointly learned with soft prompts, enabling task adaptation in LMMs under low-data regimes with just a few gradient steps. Evaluation on the VL-ICL Bench shows that our method consistently outperforms ICL and related prompt-tuning approaches, even under image perturbations, improving task induction and reasoning across visual question answering tasks.
Abstract:Deep reinforcement learning has achieved remarkable success in learning control policies from pixels across a wide range of tasks, yet its application remains hindered by low sample efficiency, requiring significantly more environment interactions than humans to reach comparable performance. Model-based reinforcement learning (MBRL) offers a solution by leveraging learnt world models to generate simulated experience, thereby improving sample efficiency. In visually complex environments, however, small or dynamic elements can be critical for decision-making, and traditional MBRL methods in pixel-based environments typically rely on auto-encoding with an $L_2$ loss, which is dominated by large areas and often fails to capture decision-relevant details. To address these limitations, we propose an object-centric MBRL pipeline, which integrates recent advances in computer vision to allow agents to focus on key decision-related elements. Our approach consists of four main steps: (1) annotating key objects related to rewards and goals with segmentation masks, (2) extracting object features using a pre-trained, frozen foundation vision model, (3) incorporating these object features with the raw observations to predict environmental dynamics, and (4) training the policy using imagined trajectories generated by this object-centric world model. Building on the efficient MBRL algorithm STORM, we call this pipeline OC-STORM. We demonstrate OC-STORM's practical value in overcoming the limitations of conventional MBRL approaches on both Atari games and the visually complex game Hollow Knight.
Abstract:Training neural network classifiers on datasets contaminated with noisy labels significantly increases the risk of overfitting. Thus, effectively implementing Early Stopping in noisy label environments is crucial. Under ideal circumstances, Early Stopping utilises a validation set uncorrupted by label noise to effectively monitor generalisation during training. However, obtaining a noise-free validation dataset can be costly and challenging. This study establishes that, in many typical learning environments, a noise-free validation set is not necessary for effective Early Stopping. Instead, near-optimal results can be achieved by monitoring accuracy on a noisy dataset drawn from the same distribution as the noisy training set. Referred to as 'Noisy Early Stopping' (NES), this method simplifies and reduces the cost of implementing Early Stopping. We provide theoretical insights into the conditions under which this method is effective and empirically demonstrate its robust performance across standard benchmarks using common loss functions.
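The monitoring loop behind this idea can be sketched as ordinary patience-based early stopping, with the one difference that accuracy is measured against noisy held-out labels. The function name, the patience rule, and the per-epoch accuracy values are assumptions for illustration; the abstract does not specify the exact stopping criterion.

```python
def noisy_early_stopping(noisy_val_acc, patience=2):
    """Return (best_epoch, best_accuracy) using accuracy on a noisy validation set.

    noisy_val_acc: per-epoch accuracy measured against noisy held-out labels.
    patience: consecutive non-improving epochs tolerated before stopping (assumed rule).
    """
    best_epoch, best_acc, waited = -1, float("-inf"), 0
    for epoch, acc in enumerate(noisy_val_acc):
        if acc > best_acc:
            best_epoch, best_acc, waited = epoch, acc, 0
        else:
            waited += 1
            if waited >= patience:
                break  # stop training; keep the checkpoint from best_epoch
    return best_epoch, best_acc


# Hypothetical accuracies; training stops at epoch 4, before the later value is seen.
accuracies = [0.50, 0.60, 0.65, 0.64, 0.63, 0.70]
best_epoch, best_acc = noisy_early_stopping(accuracies, patience=2)
```

In this run the checkpoint from epoch 2 is kept, since epochs 3 and 4 fail to improve on it.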
Abstract:As a potential non-invasive biomarker for ischaemic stroke, intracranial arterial calcification (IAC) could be used for stroke risk assessment on CT head scans routinely acquired for other reasons (e.g. trauma, confusion). Artificial intelligence methods can support IAC scoring, but they have not yet been developed for clinical imaging. Large heterogeneous clinical CT datasets are necessary for the training of such methods, but they exhibit expected and unexpected data anomalies. Using CTs from a large clinical trial, the third International Stroke Trial (IST-3), we propose a pipeline that takes non-enhanced CT scans as input and outputs regions of interest capturing selected large intracranial arteries for IAC scoring. Our method uses co-registration with templates. We focus on quality control, grouping image series by the presence of information along the z-axis and applying a similarity measure (the structural similarity index measure) to triage the assessment of individual image series. Additionally, we propose superimposing thresholded binary masks of the series to inspect large quantities of data in parallel. We identify and exclude unrecoverable samples and registration failures. In total, our pipeline processes 10,659 CT series, rejecting 4,322 (41%) over the entire process, of which 1,450 (14% of the total) are rejected during quality control, and outputting 6,337 series. Our pipeline enables effective and efficient region of interest localisation for targeted IAC segmentation.
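The mask-superimposition step can be illustrated with a toy example: each image is thresholded to a binary mask, and the masks are summed pixel-wise so that outlying series show up as low-count regions in a single overlay. The function names, the threshold value, and the 2x2 "slices" are assumptions for illustration and are not taken from the pipeline itself.

```python
def threshold_mask(image, threshold):
    """Binary mask: 1 where the pixel value exceeds the threshold, else 0."""
    return [[1 if value > threshold else 0 for value in row] for row in image]


def superimpose(masks):
    """Pixel-wise sum of equally sized binary masks."""
    rows, cols = len(masks[0]), len(masks[0][0])
    overlay = [[0] * cols for _ in range(rows)]
    for mask in masks:
        for i in range(rows):
            for j in range(cols):
                overlay[i][j] += mask[i][j]
    return overlay


# Three hypothetical 2x2 slices with arbitrary intensity values.
slices = [
    [[0, 100], [200, 50]],
    [[0, 120], [10, 90]],
    [[300, 0], [250, 70]],
]
overlay = superimpose([threshold_mask(s, threshold=80) for s in slices])
```

Each overlay entry counts how many series exceed the threshold at that pixel, so a whole batch can be inspected in one image.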