Abstract: Generative models are often trained with a next-token prediction objective, yet many downstream applications require the ability to estimate or control sequence-level properties. Next-token prediction can overfit local patterns and underfit global structure during training, and at inference time it requires significant downstream modifications or expensive sampling to guide or predict the global attributes of generated samples. Here, we introduce Conditional Attribute Transformers, a novel method for jointly estimating the next-token probability and the value of an attribute conditional on each potential next-token selection. This framework enables three critical capabilities within a single forward pass, without modification of the input sequence: (1) per-token credit assignment across an entire sequence, by identifying how each token in a sequence is associated with an attribute's value; (2) counterfactual analysis, by quantifying attribute differences conditional on alternative next-token choices; and (3) steerable generation, by decoding sequences based on a combination of next-token and attribute likelihoods. Our approach achieves state-of-the-art performance on sparse-reward tasks, improves next-token prediction at sufficient model sizes, estimates attribute probabilities orders of magnitude faster than sampling, and can guide decoding of autoregressive sequence models on a range of language tasks.
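
As an illustration of the steerable-generation and counterfactual ideas described in this abstract, the following is a minimal sketch, not the authors' implementation. It assumes a model has already produced, for one decoding step, a vector of next-token logits and a vector of attribute probabilities conditional on each candidate token. The function name `steered_next_token`, the `weight` parameter, and the additive log-space combination rule are all hypothetical choices made for this example; the abstract does not specify how the two likelihoods are combined.

```python
import numpy as np


def log_softmax(x):
    """Numerically stable log-softmax over a 1-D array."""
    x = np.asarray(x, dtype=float)
    shifted = x - x.max()
    return shifted - np.log(np.exp(shifted).sum())


def steered_next_token(token_logits, attr_probs, weight=1.0):
    """Choose the next token from a weighted sum of log-likelihoods.

    token_logits: shape (V,), next-token logits for one step.
    attr_probs:   shape (V,), assumed estimate of
                  P(attribute | context, candidate token) per candidate.
    weight:       hypothetical steering strength; 0.0 gives greedy decoding.
    """
    token_logp = log_softmax(token_logits)
    # Clip to avoid log(0) for candidates with zero attribute probability.
    attr_logp = np.log(np.clip(attr_probs, 1e-12, 1.0))
    scores = token_logp + weight * attr_logp
    return int(np.argmax(scores)), scores


# Toy vocabulary of three candidate tokens (illustrative numbers only).
token_logits = np.array([2.0, 1.0, 0.0])
attr_probs = np.array([0.05, 0.90, 0.50])

greedy_token, _ = steered_next_token(token_logits, attr_probs, weight=0.0)
steered_token, _ = steered_next_token(token_logits, attr_probs, weight=1.0)

# Counterfactual comparison: how much the attribute estimate changes if
# the steered token is chosen instead of the greedy one.
attr_difference = attr_probs[steered_token] - attr_probs[greedy_token]
```

In this toy setting the greedy choice is the token with the highest logit, while the steered choice trades some next-token likelihood for a higher estimated attribute probability. Because both vectors are assumed to come from the same forward pass, no additional sampling is needed to compare candidates.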
Abstract: Real-world processes often generate data that mix categorical and numeric values recorded at irregular and informative intervals. Discrete token-based approaches are limited in numeric representation capacity, while methods like neural ordinary differential equations are not well suited for categorical data or informative sampling and require augmentation to handle certain classes of trajectories. Here, we present multivariateGPT, a single architecture for modeling sequences of mixed categorical (including tokenized text) and numeric data. This is accomplished with an autoregressive sequence decomposition, embedding scheme, and loss function that extend the next-token prediction task to likelihood estimation of the joint distribution of next-token class and value. We demonstrate how this approach can efficiently learn to generalize patterns in simple physical systems and model complex time series, including electrocardiograms and multivariate electronic health record data. This work extends the utility of transformer-based models to additional classes of data.
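
To make the idea of a joint likelihood over next-token class and value concrete, the following is a minimal sketch under stated assumptions, not the loss used by multivariateGPT. It factorizes the joint as p(class) times p(value | class) and assumes a Gaussian density for the numeric value; the abstract does not name the value distribution, so the Gaussian choice, the function name `joint_nll`, and the per-class `mu` and `sigma` parameters are hypothetical.

```python
import math

import numpy as np


def log_softmax(x):
    """Numerically stable log-softmax over a 1-D array."""
    x = np.asarray(x, dtype=float)
    shifted = x - x.max()
    return shifted - np.log(np.exp(shifted).sum())


def joint_nll(class_logits, mu, sigma, target_class, target_value):
    """Negative log-likelihood of one (class, value) pair.

    class_logits: shape (C,), logits over next-token classes.
    mu, sigma:    shape (C,), assumed Gaussian parameters of the value
                  conditional on each class (sigma must be positive).
    """
    logp_class = log_softmax(class_logits)[target_class]
    m, s = float(mu[target_class]), float(sigma[target_class])
    # Gaussian log-density of the observed value given the true class.
    logp_value = (
        -0.5 * math.log(2.0 * math.pi * s * s)
        - (target_value - m) ** 2 / (2.0 * s * s)
    )
    return -(logp_class + logp_value)


# Toy example: two classes with equal logits; the value equals the mean.
nll = joint_nll(
    class_logits=np.array([0.0, 0.0]),
    mu=np.array([1.0, 0.0]),
    sigma=np.array([1.0, 1.0]),
    target_class=0,
    target_value=1.0,
)
```

With equal logits and an observed value at the predicted mean, the loss reduces to log 2 plus the Gaussian normalization term, which gives a quick sanity check. A purely categorical token could be handled in this sketch by dropping the value term.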