Abstract: We present Clin-JEPA, a multi-phase co-training framework for joint-embedding predictive (JEPA) pretraining on EHR patient trajectories. JEPA architectures have enabled latent-space planning in robotics and high-quality representation learning in vision, but extending the paradigm to EHR data -- to obtain a single backbone that simultaneously forecasts patient trajectories and serves diverse downstream risk-prediction tasks without per-task fine-tuning -- remains an open challenge. Existing JEPA frameworks either discard the predictor after pretraining (I-JEPA, V-JEPA) or train it on a frozen pretrained encoder (V-JEPA 2-AC), leaving the encoder unaware of the rollout signal that the retained predictor must use at inference; co-training the encoder and predictor under a shared JEPA prediction objective would supply this grounding, but naïve co-training is unstable, with representation collapse and online/target drift causing autoregressive rollout to diverge. Clin-JEPA's five-phase pretraining curriculum -- predictor warmup, joint refinement, EMA target alignment, hard sync, and predictor finalization -- addresses each failure mode by phase, stably co-training a Qwen3-8B-based encoder and a 92M-parameter latent trajectory predictor. On MIMIC-IV ICU data, three independent evaluations support the framework: (1) latent $\ell_1$ rollout drift uniquely converges ($-$15.7%) over 48-hour horizons while baselines and ablations diverge (+3% to +4951%); (2) the encoder learns a clinically discriminative latent geometry (deteriorating-patient cohorts displace 4.83$\times$ further than stable patients in latent space, vs $\leq$2.62$\times$ for baseline encoders); (3) a single backbone outperforms strong tabular and sequence baselines on multi-task downstream evaluation. Clin-JEPA achieves mean AUROC 0.851 on ICareFM EEP and 0.883 on 8 binary risk tasks (+0.038 and +0.041 vs baseline average).
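To make two of the quantities named above concrete, the following is a minimal sketch, not the authors' implementation. It assumes that "EMA target alignment" refers to the standard exponential-moving-average update of target parameters toward online parameters, and that rollout drift can be summarized as the percent change in latent $\ell_1$ error between the first and last rollout step. The abstract does not give the exact definition of drift, and the function names `ema_update` and `rollout_drift_pct` are hypothetical.

```python
from typing import List, Sequence

Vector = Sequence[float]


def l1_distance(a: Vector, b: Vector) -> float:
    """Sum of absolute coordinate differences between two latent vectors."""
    return sum(abs(x - y) for x, y in zip(a, b))


def ema_update(target: Vector, online: Vector, decay: float) -> List[float]:
    """Standard EMA step: move target parameters a fraction (1 - decay) toward online."""
    return [decay * t + (1.0 - decay) * o for t, o in zip(target, online)]


def rollout_drift_pct(predicted: Sequence[Vector], reference: Sequence[Vector]) -> float:
    """Percent change in l1 error from the first to the last rollout step.

    Assumed definition for illustration only: negative values mean the
    rollout error shrinks over the horizon, positive values mean it grows.
    """
    errors = [l1_distance(p, r) for p, r in zip(predicted, reference)]
    if not errors or errors[0] == 0.0:
        raise ValueError("need at least one step with non-zero initial error")
    return 100.0 * (errors[-1] - errors[0]) / errors[0]


if __name__ == "__main__":
    # Toy two-step rollout whose error halves between the first and last step.
    predicted = [[1.0, 0.0], [0.5, 0.0]]
    reference = [[0.0, 0.0], [0.0, 0.0]]
    print(rollout_drift_pct(predicted, reference))
    print(ema_update([1.0, 1.0], [0.0, 2.0], decay=0.9))
```

Under this assumed definition, a negative drift corresponds to the converging behaviour reported for Clin-JEPA, and a positive drift to the diverging behaviour reported for baselines and ablations.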
Abstract: Machine learning (ML) models are increasingly pivotal in automating clinical decisions. Yet prior research has largely overlooked the proper processing of Electronic Medical Record (EMR) data for errors and outliers in the clinical context. To address this oversight, we introduce a projection-based method that integrates clinical expertise as domain constraints and generates metadata that can be used in ML workflows. In particular, we formulate high-dimensional mixed-integer programs that capture physiological and biological constraints on patient vitals and lab values, and use mathematical "projections" of the EMR data onto these constraints to correct patient data. We then measure the distance of the corrected data from the constraints defining a healthy range of patient data, which yields a predictive metric we term "trust-scores". These scores provide insight into the patient's health status and significantly boost the performance of ML classifiers in real-life clinical settings. We validate the impact of our framework in the context of early detection of sepsis using ML, reporting an AUROC of 0.865 and a precision of 0.922, which surpasses conventional ML models without such projections.
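The projection and trust-score idea can be illustrated with a deliberately simplified sketch. The abstract describes mixed-integer programs; the code below swaps in plain interval (box) constraints, for which the Euclidean projection reduces to clipping each value to its range. All names (`project_onto_box`, `distance_to_box`) and all numeric ranges are hypothetical placeholders chosen for illustration. They are not taken from the paper and are not clinical guidance.

```python
import math
from typing import Dict, Tuple

Box = Dict[str, Tuple[float, float]]
Record = Dict[str, float]

# Hypothetical illustrative ranges only (assumptions, not clinical values).
PLAUSIBLE: Box = {"heart_rate": (0.0, 300.0), "temp_c": (25.0, 45.0)}
HEALTHY: Box = {"heart_rate": (60.0, 100.0), "temp_c": (36.0, 38.0)}


def project_onto_box(record: Record, box: Box) -> Record:
    """Euclidean projection onto a box: clip each value to its [low, high] interval."""
    return {k: min(max(v, box[k][0]), box[k][1]) for k, v in record.items()}


def distance_to_box(record: Record, box: Box) -> float:
    """Euclidean distance from a record to the nearest point inside the box."""
    projected = project_onto_box(record, box)
    return math.sqrt(sum((record[k] - projected[k]) ** 2 for k in record))


if __name__ == "__main__":
    # A likely data-entry error is corrected by projecting onto plausibility constraints.
    raw = {"heart_rate": 1030.0, "temp_c": 37.0}
    print(project_onto_box(raw, PLAUSIBLE))

    # A plausible but abnormal record gets a non-zero distance to the healthy range,
    # which plays the role of a trust-score in this simplified setting.
    patient = {"heart_rate": 103.0, "temp_c": 42.0}
    corrected = project_onto_box(patient, PLAUSIBLE)
    print(distance_to_box(corrected, HEALTHY))
```

In this simplified setting the score is zero for records inside the healthy range and grows as values move away from it, so it can be appended to a feature vector before training a classifier. Constraints that couple several variables or involve integer decisions, as in the mixed-integer formulation described above, would need an optimization solver rather than per-coordinate clipping.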