AugMask: Training Diffusion Models on Incomplete Tabular Data via Stochastic Augmentation and Masking
2026-06-02 • Machine Learning
Machine Learning, Artificial Intelligence
AI summary
The authors address the problem of training score-based diffusion models on tabular data that contain missing values. They introduce AugMask, a training method that separates what the model conditions on from what it is trained to reconstruct. AugMask fills in missing numeric values with random draws from lightweight auxiliary models, but applies the denoising loss only to entries that were actually observed, so the filled-in values act as uncertain context rather than training targets. With this approach, standard diffusion models for tabular data outperform methods designed specifically for missing values.
score-based diffusion models, tabular data, missing data, conditional augmentation, denoising supervision, Rao-Blackwellization, variance-weighted sensitivity, deep generative models, missing data imputation, data augmentation
Authors
Jungkyu Kim, Taeyoung Park, Kibok Lee
Abstract
Score-based diffusion models have emerged as prominent deep generative models; however, their application to tabular data remains challenging because their backbones assume fully specified inputs, whereas real-world tabular data often contain missing values. We propose AugMask, a plug-and-play training framework that adapts missing-unaware backbones to incomplete data by separating conditioning from supervision. AugMask (1) constructs numeric inputs via conditional stochastic augmentation using lightweight auxiliary models, and (2) applies denoising supervision only to observed coordinates. In effect, augmented missing entries serve as uncertain conditioning context rather than training targets. We connect this training rule to a Rao-Blackwellized objective and show that marginalizing missing entries yields a variance-weighted sensitivity penalty, discouraging over-reliance on uncertain completions. Across diverse datasets and missingness regimes, AugMask enables standard diffusion-based tabular generators to outperform specialized missing-aware baselines.
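The two steps named in the abstract (stochastic augmentation of missing entries, then denoising supervision restricted to observed coordinates) can be illustrated with a minimal NumPy sketch. This is not the authors' implementation. The function names `augment_missing` and `masked_denoising_loss`, the per-column Gaussian used as a stand-in for the auxiliary models, the single noise level `sigma_t`, and the noise-prediction form of the loss are all assumptions made for illustration.

```python
import numpy as np


def augment_missing(x, mask, rng):
    """Fill missing entries (mask == 0) with random draws.

    Assumption: a per-column Gaussian fitted to the observed values stands in
    for the paper's lightweight auxiliary models, which are not specified here.
    Each column is assumed to have at least one observed value.
    """
    x_filled = x.copy()
    for j in range(x.shape[1]):
        observed = mask[:, j] == 1
        mu = x[observed, j].mean()
        sigma = x[observed, j].std()
        n_missing = int((~observed).sum())
        x_filled[~observed, j] = rng.normal(mu, sigma, size=n_missing)
    return x_filled


def masked_denoising_loss(eps_pred, eps, mask):
    """Squared error averaged over observed coordinates only."""
    squared_error = (eps_pred - eps) ** 2
    return float((squared_error * mask).sum() / max(mask.sum(), 1.0))


rng = np.random.default_rng(0)

# Toy table with NaN marking missing values (hypothetical data).
x = np.array([[1.0, np.nan],
              [2.0, 5.0],
              [np.nan, 7.0]])
mask = (~np.isnan(x)).astype(float)  # 1 = observed, 0 = missing

# Step 1: stochastic augmentation produces a fully specified input.
x_aug = augment_missing(x, mask, rng)

# Forward noising at one assumed noise level.
sigma_t = 0.5
eps = rng.normal(size=x.shape)
x_noisy = x_aug + sigma_t * eps

# Placeholder for the backbone's noise prediction given x_noisy.
eps_pred = np.zeros_like(x_noisy)

# Step 2: supervision is applied to observed coordinates only.
loss = masked_denoising_loss(eps_pred, eps, mask)
```

Because the mask zeroes out the error on augmented coordinates, whatever the network predicts there does not change the loss. The augmented values influence training only through the input `x_noisy`, which matches the abstract's description of them as conditioning context rather than targets.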