Learning What Makes a Difference from Counterfactual Examples and Gradient Supervision
One of the primary challenges limiting the practical application of deep learning is its susceptibility to learning spurious correlations in the data, rather than capturing the data-generating mechanisms of the task of interest. The resulting failure to generalize cannot be addressed by simply using more data from the same distribution. We propose an auxiliary training objective that improves the generalization capabilities of neural networks by leveraging an overlooked supervisory signal found in existing datasets. We demonstrate that pairs of minimally-different examples with different labels, a.k.a. counterfactual examples, provide a signal indicative of the underlying causal structure of the task. We show that such pairs can be identified in a number of existing datasets for a range of tasks in vision (visual question answering, multi-label image classification) and natural language processing (sentiment analysis, natural language inference). The proposed objective orients the gradient of a model's decision boundary so that it aligns with these pairwise relations in the input domain. Models trained with this technique demonstrate improved performance on out-of-distribution test sets.
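The idea of aligning a model's gradient with a counterfactual pair can be sketched as a cosine-distance penalty. This is a minimal illustration, not the authors' implementation, and it rests on several assumptions: (i) the model exposes a scalar score for the label of the counterfactual example x_j; (ii) the gradient of that score is taken with respect to the input at x_i and is estimated by central finite differences instead of automatic differentiation; and (iii) the penalty is one minus the cosine similarity between that gradient and the difference vector x_j - x_i. The names `input_gradient`, `cosine_similarity`, `gradient_supervision_loss` and `linear` are hypothetical.

```python
import math
from typing import Callable, List, Sequence

Vector = Sequence[float]
ScoreFn = Callable[[Vector], float]


def input_gradient(f: ScoreFn, x: Vector, eps: float = 1e-5) -> List[float]:
    """Estimate the gradient of the scalar function f at x by central differences."""
    grad = []
    for k in range(len(x)):
        x_plus = list(x)
        x_minus = list(x)
        x_plus[k] += eps
        x_minus[k] -= eps
        grad.append((f(x_plus) - f(x_minus)) / (2 * eps))
    return grad


def cosine_similarity(a: Vector, b: Vector) -> float:
    """Cosine of the angle between a and b (0.0 if either vector is zero)."""
    dot = sum(ai * bi for ai, bi in zip(a, b))
    norm_a = math.sqrt(sum(ai * ai for ai in a))
    norm_b = math.sqrt(sum(bi * bi for bi in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def gradient_supervision_loss(score_for_label_j: ScoreFn, x_i: Vector, x_j: Vector) -> float:
    """Hypothetical penalty in [0, 2]; it is 0 when the gradient at x_i points towards x_j."""
    g_i = input_gradient(score_for_label_j, x_i)
    # The pairwise relation in input space: how x_i must change to become x_j.
    delta = [b - a for a, b in zip(x_i, x_j)]
    return 1.0 - cosine_similarity(g_i, delta)


def linear(w: Vector) -> ScoreFn:
    """Toy score function x -> w . x, used only for the demonstration below."""
    return lambda x: sum(wk * xk for wk, xk in zip(w, x))


# A hypothetical counterfactual pair that differs only in its first feature.
x_i, x_j = [0.0, 0.0], [1.0, 0.0]
print(round(gradient_supervision_loss(linear([1.0, 0.0]), x_i, x_j), 6))   # 0.0 (aligned)
print(round(gradient_supervision_loss(linear([0.0, 1.0]), x_i, x_j), 6))   # 1.0 (orthogonal)
print(round(gradient_supervision_loss(linear([-1.0, 0.0]), x_i, x_j), 6))  # 2.0 (opposed)
```

In a real training loop such a penalty would presumably be added, with some weighting, to the usual task loss and computed with a framework's automatic differentiation; the finite-difference estimate is used here only to keep the sketch free of dependencies.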