Causal discovery, the task of inferring causal structure from data, promises to accelerate scientific research, inform policy making, and more. However, the per-dataset nature of existing causal discovery algorithms renders them slow, data hungry, and brittle. Inspired by foundation models, we propose a causal discovery framework where a deep learning model is pretrained to resolve predictions from classical discovery algorithms run over smaller subsets of variables. This method is enabled by the observations that the outputs from classical algorithms are fast to compute for small problems, are informative of (marginal) data structure, and yield structure outputs that, as objects, remain comparable across datasets. Our method achieves state-of-the-art performance on synthetic and realistic datasets, generalizes to data generating mechanisms not seen during training, and offers inference speeds that are orders of magnitude faster than existing models.
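As a rough illustration of running a classical algorithm over small variable subsets, the sketch below samples subsets and records, for each ordered variable pair, what each run predicted. All names here (`sample_subsets`, `collect_pair_votes`, `toy_chain_algorithm`) are hypothetical. The toy stand-in takes the place of a real discovery algorithm, and the pretrained model that resolves these per-subset predictions is not shown.

```python
import itertools
import random


def sample_subsets(num_vars, subset_size, num_subsets, rng):
    """Draw random variable subsets, each small enough for a classical algorithm."""
    return [
        tuple(sorted(rng.sample(range(num_vars), subset_size)))
        for _ in range(num_subsets)
    ]


def collect_pair_votes(subsets, run_classical):
    """For each ordered pair (i, j), list the 0/1 edge predictions of every run that saw both."""
    votes = {}
    for subset in subsets:
        edges = set(run_classical(subset))
        for i, j in itertools.permutations(subset, 2):
            votes.setdefault((i, j), []).append(1 if (i, j) in edges else 0)
    return votes


def toy_chain_algorithm(subset):
    """Hypothetical stand-in: pretends the true graph is the chain 0 -> 1 -> 2 -> ..."""
    return [(i, i + 1) for i in subset if i + 1 in subset]
```

The per-pair vote lists are the kind of cross-subset evidence a learned model could aggregate; how the actual framework encodes and resolves them is not specified in the abstract.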
The Vision Transformer (ViT) has emerged as a powerful architecture in the realm of modern computer vision. However, its application in certain imaging fields, such as microscopy and satellite imaging, presents unique challenges. In these domains, images often contain multiple channels, each carrying semantically distinct and independent information. Furthermore, the model must demonstrate robustness to sparsity in input channels, as they may not be densely available during training or testing. In this paper, we propose a modification to the ViT architecture that enhances reasoning across the input channels and introduce Hierarchical Channel Sampling (HCS) as an additional regularization technique to ensure robustness when only a subset of channels is available at test time. Our proposed model, ChannelViT, constructs patch tokens independently from each input channel and utilizes a learnable channel embedding that is added to the patch tokens, similar to positional embeddings. We evaluate the performance of ChannelViT on ImageNet, JUMP-CP (microscopy cell imaging), and So2Sat (satellite imaging). Our results show that ChannelViT outperforms ViT on classification tasks and generalizes well, even when a subset of input channels is used during testing. Across our experiments, HCS proves to be a powerful regularizer, independent of the architecture employed, suggesting itself as a straightforward technique for robust ViT training. Lastly, we find that ChannelViT generalizes effectively even when there is limited access to all channels during training, highlighting its potential for multi-channel imaging under real-world conditions with sparse sensors. Our code is available at https://github.com/insitro/ChannelViT.
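The two ideas above, per-channel patch tokens with an added channel embedding and sampling of channel subsets during training, can be sketched in plain Python. This is a minimal sketch under assumptions: the two-stage sampling scheme (first the number of channels, then which ones) is one reading of "hierarchical" and is not spelled out in the abstract, and the function names and list-based tensors are hypothetical.

```python
import random


def hierarchical_channel_sample(num_channels, rng):
    """Assumed two-stage scheme: sample how many channels to keep, then which ones."""
    m = rng.randint(1, num_channels)  # stage 1: count, uniform on 1..num_channels
    return sorted(rng.sample(range(num_channels), m))  # stage 2: the channel subset


def channel_tokens(patches, channel_embed, keep):
    """Build one token per (channel, patch) pair for the kept channels.

    patches[c][p] is the feature vector of patch p in channel c;
    channel_embed[c] is that channel's embedding, added elementwise
    in the same way a positional embedding would be.
    """
    tokens = []
    for c in keep:
        for patch in patches[c]:
            tokens.append([x + e for x, e in zip(patch, channel_embed[c])])
    return tokens
```

With this layout the sequence length grows with the number of kept channels, so dropping channels at training time also shortens the input.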
We present Contextual Vision Transformers (ContextViT), a method for producing robust feature representations for images exhibiting grouped structure such as covariates. ContextViT introduces an extra context token to encode group-specific information, allowing the model to explain away group-specific covariate structures while keeping core visual features shared across groups. Specifically, ContextViT maps images that share the same covariate into this context token, which is appended to the input image tokens to capture the effects of conditioning the model on group membership. We furthermore introduce a context inference network to predict such tokens on the fly given a few samples from a group distribution, enabling ContextViT to generalize to new testing distributions at inference time. We illustrate the performance of ContextViT through a diverse range of applications. In supervised fine-tuning, we demonstrate that augmenting pre-trained ViTs with additional context conditioning leads to significant improvements in out-of-distribution generalization on iWildCam and FMoW. We also explore self-supervised representation learning with ContextViT. Our experiments on the Camelyon17 pathology imaging benchmark and the cpg-0000 microscopy imaging benchmark demonstrate that ContextViT excels in learning stable image featurizations amidst covariate shift, consistently outperforming its ViT counterpart.
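One simple way to picture a context token inferred from a few group samples is mean pooling followed by a linear map. The sketch below assumes exactly that; the pooling choice, the `weight` matrix, and all function names are hypothetical stand-ins rather than the architecture of the context inference network.

```python
def mean_pool(vectors):
    """Average a list of equal-length vectors elementwise."""
    n = len(vectors)
    return [sum(v[d] for v in vectors) / n for d in range(len(vectors[0]))]


def infer_context_token(group_patch_embeddings, weight):
    """Summarize a group of images into one token.

    group_patch_embeddings: list of images, each a list of patch vectors.
    weight: hypothetical linear map, given as a list of rows.
    """
    pooled = mean_pool([mean_pool(image) for image in group_patch_embeddings])
    return [sum(w * x for w, x in zip(row, pooled)) for row in weight]


def add_context(image_tokens, context_token):
    """Append the group's context token to one image's token sequence."""
    return image_tokens + [context_token]
```

Because the token depends only on samples from the group, it can be recomputed for a group never seen in training, which is the property the abstract relies on for new test distributions.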
Classifiers are biased when trained on biased datasets. As a remedy, we propose Learning to Split (ls), an algorithm for automatic bias detection. Given a dataset with input-label pairs, ls learns to split this dataset so that predictors trained on the training split generalize poorly to the testing split. This performance gap provides a proxy for measuring the degree of bias in the learned features and can therefore be used to reduce biases. Identifying non-generalizable splits is challenging because we have no explicit annotations about how to split. In this work, we show that the prediction correctness of the testing examples can be used as a source of weak supervision: generalization performance will drop if we move examples that are predicted correctly away from the testing split, leaving only those that are mispredicted. We evaluate our approach on Beer Review, Waterbirds, CelebA and MNLI. Empirical results show that ls is able to generate astonishingly challenging splits that correlate with human-identified biases. Moreover, we demonstrate that combining robust learning algorithms (such as group DRO) with splits identified by ls enables automatic de-biasing. Compared with the previous state of the art, we substantially improve the worst-group performance (23.4% on average) when the source of biases is unknown during training and validation.
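The weak-supervision signal described above can be illustrated with a single hand-written refinement step: examples the predictor gets right leave the testing split, and the mispredicted ones stay. This is only a sketch of the intuition. The actual method learns a splitter rather than moving examples by rule, and `refine_split` and `predict` are hypothetical names.

```python
def refine_split(train, test, predict):
    """Move correctly predicted test examples into the training split.

    train, test: lists of (input, label) pairs.
    predict: any callable mapping an input to a predicted label.
    """
    moved = [(x, y) for x, y in test if predict(x) == y]
    remaining = [(x, y) for x, y in test if predict(x) != y]
    return train + moved, remaining
```

After the step, the predictor's accuracy on the remaining testing split is zero by construction, which is why the generalization gap widens.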
We study transfer learning in the presence of spurious correlations. We experimentally demonstrate that directly transferring the stable feature extractor learned on the source task may not eliminate these biases for the target task. However, we hypothesize that the unstable features in the source task and those in the target task are directly related. By explicitly informing the target classifier of the source task's unstable features, we can regularize the biases in the target task. Specifically, we derive a representation that encodes the unstable features by contrasting different data environments in the source task. On the target task, we cluster data using this representation, and achieve robustness by minimizing the worst-case risk across all clusters. We evaluate our method on both text and image classification. Empirical results demonstrate that our algorithm is able to maintain robustness on the target task, outperforming the best baseline by 22.9% in absolute accuracy across 12 transfer settings. Our code is available at https://github.com/YujiaBao/Tofu.
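The worst-case objective over clusters reduces to a small computation once per-example losses and cluster assignments are available. The sketch below assumes both are given as plain lists; the clustering step and the unstable-feature representation are not shown, and `worst_cluster_risk` is a hypothetical name.

```python
from collections import defaultdict


def worst_cluster_risk(losses, cluster_ids):
    """Return the largest mean loss over clusters (the quantity to minimize)."""
    by_cluster = defaultdict(list)
    for loss, cluster in zip(losses, cluster_ids):
        by_cluster[cluster].append(loss)
    return max(sum(values) / len(values) for values in by_cluster.values())
```

Minimizing this maximum, rather than the overall average, keeps a small but hard cluster from being ignored during training.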
We propose Predict then Interpolate (PI), a simple algorithm for learning correlations that are stable across environments. The algorithm follows from the intuition that when using a classifier trained on one environment to make predictions on examples from another environment, its mistakes are informative as to which correlations are unstable. In this work, we prove that by interpolating the distributions of the correct predictions and the wrong predictions, we can uncover an oracle distribution where the unstable correlation vanishes. Since the oracle interpolation coefficients are not accessible, we use group distributionally robust optimization to minimize the worst-case risk across all such interpolations. We evaluate our method on both text classification and image classification. Empirical results demonstrate that our algorithm is able to learn robust classifiers, outperforming IRM by 23.85% on synthetic environments and 12.41% on natural environments. Our code and data are available at https://github.com/YujiaBao/Predict-then-Interpolate.
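The two steps, partitioning by prediction correctness and taking the worst case over interpolations, can be sketched as follows. The sketch assumes interpolation coefficients restricted to convex combinations on a coarse grid, which is a simplification made here; function names are hypothetical.

```python
def partition_by_prediction(examples, predict):
    """Split one environment by whether another environment's classifier is right."""
    correct = [(x, y) for x, y in examples if predict(x) == y]
    wrong = [(x, y) for x, y in examples if predict(x) != y]
    return correct, wrong


def interpolated_risk(risk_correct, risk_wrong, lam):
    """Risk under a mixture with weight lam on the correct group."""
    return lam * risk_correct + (1.0 - lam) * risk_wrong


def worst_case_risk(risk_correct, risk_wrong, grid=11):
    """Largest interpolated risk over lam in {0, 0.1, ..., 1} (assumed grid)."""
    return max(
        interpolated_risk(risk_correct, risk_wrong, k / (grid - 1))
        for k in range(grid)
    )
```

Since the interpolated risk is linear in `lam`, its maximum over convex combinations sits at an endpoint, so the worst case equals the larger of the two group risks. That is why optimizing the worst group covers every such interpolation without knowing the oracle coefficients.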
In this paper, we explore meta-learning for few-shot text classification. Meta-learning has shown strong performance in computer vision, where low-level patterns are transferable across learning tasks. However, directly applying this approach to text is challenging: words highly informative for one task may have little significance for another. Thus, rather than learning solely from words, our model also leverages their distributional signatures, which encode pertinent word occurrence patterns. Our model is trained within a meta-learning framework to map these signatures into attention scores, which are then used to weight the lexical representations of words. We demonstrate that our model consistently outperforms prototypical networks in both few-shot text classification and relation classification by a significant margin across six benchmark datasets (19.96% on average in 1-shot classification). Our code is available at https://github.com/YujiaBao/Distributional-Signatures.
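To make the signature-to-attention pipeline concrete, the sketch below uses an inverse-frequency statistic as the signature and a fixed softmax in place of the learned mapping. Both choices are assumptions for illustration: the abstract says only that signatures encode word occurrence patterns and that the mapping is meta-learned. The names and the `eps` smoothing constant are hypothetical.

```python
import math


def attention_from_signatures(tokens, corpus_counts, total, eps=1e-3):
    """Turn per-word frequency signatures into attention weights.

    The signature eps / (eps + P(word)) is high for rare words; a plain
    softmax stands in for the learned signature-to-attention mapping.
    """
    signatures = [eps / (eps + corpus_counts.get(t, 0) / total) for t in tokens]
    exps = [math.exp(s) for s in signatures]
    z = sum(exps)
    return [e / z for e in exps]


def weighted_representation(vectors, attention):
    """Attention-weighted sum of word vectors."""
    dim = len(vectors[0])
    return [sum(a * v[d] for a, v in zip(attention, vectors)) for d in range(dim)]
```

The point of routing through signatures is that frequency statistics behave similarly across tasks even when the informative words themselves change.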
PURPOSE: The medical literature relevant to germline genetics is growing exponentially. Clinicians need tools that monitor and prioritize the literature to understand the clinical implications of pathogenic genetic variants. We developed and evaluated two machine learning models to classify abstracts as relevant to the penetrance (risk of cancer for germline mutation carriers) or prevalence of germline genetic mutations. METHODS: We conducted literature searches in PubMed and retrieved paper titles and abstracts to create an annotated dataset for training and evaluating the two machine learning classification models. Our first model is a support vector machine (SVM) which learns a linear decision rule based on the bag-of-ngrams representation of each title and abstract. Our second model is a convolutional neural network (CNN) which learns a complex nonlinear decision rule based on the raw title and abstract. We evaluated the performance of the two models on the classification of papers as relevant to penetrance or prevalence. RESULTS: For penetrance classification, we annotated 3740 paper titles and abstracts and used 60% for training the model, 20% for tuning the model, and 20% for evaluating the model. The SVM model achieves 89.53% accuracy (percentage of papers that were correctly classified) while the CNN model achieves 88.95% accuracy. For prevalence classification, we annotated 3753 paper titles and abstracts. The SVM model achieves 89.14% accuracy while the CNN model achieves 89.13% accuracy. CONCLUSION: Our models achieve high accuracy in classifying abstracts as relevant to penetrance or prevalence. By facilitating literature review, this tool could help clinicians and researchers keep abreast of the burgeoning knowledge of gene-cancer associations and keep the knowledge bases for clinical decision support tools up to date.
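The bag-of-ngrams features, the 60/20/20 split, and the accuracy metric can each be written in a few lines. This is a sketch under assumptions: the n-gram order (up to bigrams), the whitespace tokenization, and the integer rounding of split sizes are choices made here and are not stated in the abstract; the SVM and CNN themselves are not shown.

```python
from collections import Counter


def bag_of_ngrams(text, max_n=2):
    """Count word n-grams up to length max_n (assumed here to be 2)."""
    words = text.lower().split()
    bag = Counter()
    for n in range(1, max_n + 1):
        for i in range(len(words) - n + 1):
            bag[" ".join(words[i:i + n])] += 1
    return bag


def split_sizes(num_examples):
    """Sizes of a 60/20/20 train/tune/evaluate split, using integer arithmetic."""
    train = num_examples * 60 // 100
    tune = num_examples * 20 // 100
    return train, tune, num_examples - train - tune


def accuracy(predicted, gold):
    """Percentage of papers that were correctly classified."""
    correct = sum(p == g for p, g in zip(predicted, gold))
    return 100.0 * correct / len(gold)
```

For the 3740 annotated penetrance papers, `split_sizes(3740)` gives 2244 for training and 748 each for tuning and evaluation.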
Attention-based models are successful when trained on large amounts of data. In this paper, we demonstrate that even in the low-resource scenario, attention can be learned effectively. To this end, we start with discrete human-annotated rationales and map them into continuous attention. Our central hypothesis is that this mapping is general across domains, and thus can be transferred from resource-rich domains to low-resource ones. Our model jointly learns a domain-invariant representation and induces the desired mapping between rationales and attention. Our empirical results validate this hypothesis and show that our approach delivers significant gains over state-of-the-art baselines, yielding over 15% average error reduction on benchmark datasets.