Adversarial Label Invariant Graph Data Augmentations for Out-of-Distribution Generalization
2026-04-09 • Machine Learning
AI summary
The authors address out-of-distribution (OoD) generalization, the problem of making a model trained on data from some environments perform well on different but related data. They focus on covariate shift, a kind of distribution shift in which only the input data changes while the underlying concepts (the relationship between inputs and labels) stay the same. They propose a method named RIA that uses adversarial training to create new, challenging training environments through label-invariant data augmentations, which helps the model learn to handle such shifts. Experiments on graph classification under synthetic and natural distribution shifts show that the approach improves accuracy compared with OoD baselines.
Out-of-distribution generalization • Covariate shift • Representation learning • Adversarial training • Data augmentation • Q-learning • Constrained optimization • Gradient descent-ascent • Graph classification
Authors
Simon Zhang, Ryan P. DeMilt, Kun Jin, Cathy H. Xia
Abstract
The problem of out-of-distribution (OoD) generalization arises when representation learning encounters a distribution shift. This happens frequently in practice when training and testing data come from different environments. Covariate shift is a type of distribution shift that occurs only in the input data, while the concept distribution stays invariant. We propose RIA (Regularization for Invariance with Adversarial training), a new method for OoD generalization under covariate shift. Motivated by an analogy to $Q$-learning, it performs an adversarial exploration for training data environments. These new environments are induced by adversarial label-invariant data augmentations that prevent a collapse to an in-distribution trained learner. RIA works with many existing OoD generalization methods for covariate shift that can be formulated as constrained optimization problems. We develop an alternating gradient descent-ascent algorithm to solve the resulting problem, and perform extensive experiments on OoD graph classification under various kinds of synthetic and natural distribution shifts. We demonstrate that our method can achieve high accuracy compared with OoD baselines.
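The abstract states that the constrained formulation is solved with an alternating gradient descent-ascent algorithm. As a minimal, hypothetical illustration of that generic pattern only, the sketch below runs alternating gradient descent-ascent on the Lagrangian of a toy one-dimensional constrained problem. It is not RIA: there are no graphs, no learned augmentations, and no $Q$-learning analogy. The parameter takes a descent step, then the Lagrange multiplier takes a projected ascent step using the updated parameter. The objective, constraint, function names, and step sizes are all assumptions chosen for the example.

```python
"""Toy alternating gradient descent-ascent (GDA) on a Lagrangian.

Hypothetical problem chosen for illustration (NOT the RIA objective):
    minimize   f(theta) = (theta - 3)^2
    subject to g(theta) = theta - 1 <= 0
Lagrangian: L(theta, lam) = f(theta) + lam * g(theta), with lam >= 0.
"""


def f_grad(theta: float) -> float:
    # Gradient of the hypothetical objective f(theta) = (theta - 3)^2.
    return 2.0 * (theta - 3.0)


def g(theta: float) -> float:
    # Hypothetical constraint value; theta is feasible when g(theta) <= 0.
    return theta - 1.0


def alternating_gda(steps: int = 2000, lr_theta: float = 0.1, lr_lam: float = 0.1):
    theta, lam = 0.0, 0.0
    for _ in range(steps):
        # Descent step on theta: dL/dtheta = f'(theta) + lam * g'(theta), and g'(theta) = 1.
        theta -= lr_theta * (f_grad(theta) + lam * 1.0)
        # Ascent step on lam using the *updated* theta (alternating, not simultaneous),
        # followed by projection back onto the feasible set lam >= 0.
        lam = max(0.0, lam + lr_lam * g(theta))
    return theta, lam


if __name__ == "__main__":
    theta, lam = alternating_gda()
    # theta should approach the constraint boundary 1, lam the KKT multiplier 4.
    print(f"theta = {theta:.4f}, lambda = {lam:.4f}")
```

Because the unconstrained minimizer $\theta = 3$ is infeasible, the constraint is active at the optimum: $\theta$ settles at $1$ and $\lambda$ at the KKT multiplier $4$ (from $2(1 - 3) + \lambda = 0$). Per the abstract, RIA's adversarial side additionally explores training environments induced by label-invariant augmentations; this toy omits that part entirely.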