Abstract: The cost of simulator evaluations is a key practical bottleneck for Simulation-Based Inference (SBI). In hierarchical settings with shared global parameters and exchangeable site-level parameters and observations, this structure can be exploited to improve simulation efficiency. Existing hierarchical SBI approaches factorise the posterior yet still simulate across multiple sites per training sample; we instead explore likelihood factorisation (LF) to train from single-site simulations. In LF sampling, we learn a per-site neural surrogate of the simulator and then assemble synthetic multi-site observations to amortise inference for the full hierarchical posterior. Building on this, we propose Tokenised Flow Matching for Posterior Estimation (TFMPE), a tokenised flow matching approach that supports function-valued observations through likelihood factorisation. To enable systematic evaluation, we introduce a benchmark for hierarchical SBI. We validate TFMPE on this benchmark and on realistic infectious disease and computational fluid dynamics models, finding well-calibrated posteriors while reducing computational cost.
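The factorisation behind the abstract above can be written out explicitly. The notation here is an assumption, since the abstract fixes none: $\theta$ denotes the shared global parameters, and $\phi_j$ and $y_j$ denote the parameters and observations of site $j = 1, \dots, J$. Sites are assumed conditionally independent given $\theta$, as is standard for exchangeable hierarchical models.

$$
p(\theta, \phi_{1:J} \mid y_{1:J}) \;\propto\; p(\theta) \prod_{j=1}^{J} p(\phi_j \mid \theta)\, p(y_j \mid \theta, \phi_j)
$$

Under this assumption the likelihood is a product of per-site terms $p(y_j \mid \theta, \phi_j)$, so a surrogate for a single term can in principle be fit from single-site simulations and then sampled $J$ times to assemble a synthetic multi-site observation.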
Abstract: State-space model releases are typically coupled to fused CUDA and Triton kernels, inheriting a hard dependency on NVIDIA hardware. We show that Mamba-2's state space duality algorithm -- diagonal state structure, chunkable recurrence, and einsum-dominated compute with static control flow -- maps cleanly onto what XLA's fusion and tiling passes actually optimise, making custom kernels optional rather than required. We implement the full inference path (prefill, cached autoregressive decoding) as shaped standard primitives under XLA, without hand-written kernels, and realise the architecture's theoretical $O(1)$ state management as a compiled on-device cache requiring no host synchronisation during generation. The implementation runs unmodified on CPU, NVIDIA GPU, and Google Cloud TPU from a single JAX source. On TPU v6e across five model scales (130M--2.7B parameters), XLA-generated code reaches approximately 140 TFLOPS on single-stream prefill ($15\%$ MFU) and up to $64\%$ bandwidth utilisation on decode. Greedy decoding matches the PyTorch/CUDA reference token-for-token across 64 steps, with hidden-state agreement within float32 rounding tolerance. The pattern transfers to any SSM recurrence satisfying the same structural conditions, on any platform with a mature XLA backend. The implementation is publicly available at https://github.com/CosmoNaught/mamba2-jax and merged into the Bonsai JAX model library.
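To illustrate what "chunkable recurrence" means in the abstract above, the following is a minimal sketch in plain Python (not JAX, and not taken from the linked repository). It assumes a scalar state with recurrence h_t = a_t * h_{t-1} + b_t * x_t. The function names `sequential_scan` and `chunked_scan` are hypothetical. The chunked version computes each chunk from an intra-chunk weighted sum plus a single state carried in from the previous chunk, and should agree with the step-by-step recurrence up to floating-point rounding.

```python
def sequential_scan(a, b, x, h0=0.0):
    """Reference recurrence: h_t = a_t * h_{t-1} + b_t * x_t (scalar state)."""
    h, out = h0, []
    for a_t, b_t, x_t in zip(a, b, x):
        h = a_t * h + b_t * x_t
        out.append(h)
    return out


def chunked_scan(a, b, x, chunk, h0=0.0):
    """Same recurrence, evaluated chunk by chunk (illustrative sketch only).

    Within a chunk, position t is computed as
        h_t = (prod_{r<=t} a_r) * h_in + sum_{s<=t} (prod_{s<r<=t} a_r) * b_s * x_s,
    where h_in is the state carried over from the previous chunk.
    """
    h_in, out = h0, []
    for start in range(0, len(x), chunk):
        a_c = a[start:start + chunk]
        # Per-step inputs b_s * x_s for this chunk.
        u_c = [b_t * x_t for b_t, x_t in
               zip(b[start:start + chunk], x[start:start + chunk])]
        # Cumulative decay from the chunk boundary to each position.
        decay, p = [], 1.0
        for a_t in a_c:
            p *= a_t
            decay.append(p)
        for t in range(len(a_c)):
            h_t = decay[t] * h_in  # contribution of the carried-in state
            for s in range(t + 1):  # intra-chunk contribution
                w = 1.0
                for r in range(s + 1, t + 1):
                    w *= a_c[r]
                h_t += w * u_c[s]
            out.append(h_t)
        h_in = out[-1]  # only one state value crosses the chunk boundary
    return out


if __name__ == "__main__":
    a = [0.9, 0.8, 0.95, 0.7, 0.85]
    b = [1.0, 0.5, -0.3, 0.2, 1.1]
    x = [1.0, 2.0, -1.0, 0.5, 3.0]
    ref = sequential_scan(a, b, x)
    got = chunked_scan(a, b, x, chunk=2)
    print(max(abs(g - r) for g, r in zip(got, ref)))
```

In this sketch the intra-chunk sums are independent of the carried-in state, so they could be expressed as batched dense contractions, while only a fixed-size state passes between chunks. That is the structural property the abstract attributes to the state space duality algorithm; the real model uses a vector-valued state and tensor contractions rather than scalar loops.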
Abstract: As large language models engage in extended reasoning tasks, they accumulate significant state -- architectural mappings, trade-off decisions, codebase conventions -- within the context window. This understanding is lost when sessions reach context limits and undergo lossy compaction. We propose Contextual Memory Virtualisation (CMV), a system that treats accumulated LLM understanding as version-controlled state. Borrowing from operating system virtual memory, CMV models session history as a Directed Acyclic Graph (DAG) with formally defined snapshot, branch, and trim primitives that enable context reuse across independent parallel sessions. We introduce a three-pass structurally lossless trimming algorithm that preserves every user message and assistant response verbatim while stripping mechanical bloat such as raw tool outputs, base64 images, and metadata. Trimming reduces token counts by a mean of 20%, and by up to 86% for sessions with significant overhead. A single-user case-study evaluation across 76 real-world coding sessions demonstrates that trimming remains economically viable under prompt caching, with the strongest gains in mixed tool-use sessions, which average 39% reduction and reach break-even within 10 turns. A reference implementation is available at https://github.com/CosmoNaught/claude-code-cmv.
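The snapshot, branch, and trim primitives named in the abstract above can be sketched roughly as follows. This is a single-pass illustration under stated assumptions, not the paper's three-pass algorithm and not the API of the linked reference implementation. The `Message` and `Snapshot` types, the role names, and the stub text are all hypothetical.

```python
from dataclasses import dataclass

# Hypothetical role names; messages with these roles are kept verbatim.
KEEP_VERBATIM = ("user", "assistant")


@dataclass(frozen=True)
class Message:
    role: str      # e.g. "user", "assistant", "tool" (assumed labels)
    content: str


@dataclass
class Snapshot:
    """A node in the session DAG: an immutable message history plus parent ids."""
    id: str
    messages: tuple
    parents: tuple = ()


def trim(messages, stub="[tool output trimmed]"):
    """Replace bulky non-conversational content with a short stub.

    User and assistant messages pass through unchanged, so the conversational
    structure is preserved; only other content (here, tool output) shrinks.
    """
    trimmed = []
    for m in messages:
        if m.role in KEEP_VERBATIM or len(m.content) <= len(stub):
            trimmed.append(m)
        else:
            trimmed.append(Message(m.role, stub))
    return tuple(trimmed)


def branch(snapshot, new_id):
    """Start a new session from a trimmed copy of an existing snapshot."""
    return Snapshot(new_id, trim(snapshot.messages), parents=(snapshot.id,))


if __name__ == "__main__":
    root = Snapshot("root", (
        Message("user", "Refactor the parser."),
        Message("tool", "x" * 5000),  # stand-in for a raw tool dump
        Message("assistant", "Done; see parser.py."),
    ))
    child = branch(root, "child")
    before = sum(len(m.content) for m in root.messages)
    after = sum(len(m.content) for m in child.messages)
    print(before, after)
```

Because a branch records its parent id and leaves the parent's messages untouched, several branches can start from the same snapshot independently, which is the reuse pattern the abstract describes.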