CausalVAE as a Plug-in for World Models: Towards Reliable Counterfactual Dynamics

2026-04-09 · Machine Learning

AI summary

The authors introduce CausalVAE, a plug-in module that can be attached to existing latent world models to help them capture cause-and-effect structure. With the module added, models keep competitive accuracy on ordinary (factual) prediction and become better at 'what if' scenarios, where conditions are changed by an intervention. The improvement is largest on a physics benchmark, where the counterfactual retrieval score (CF-H@1) roughly doubles on average relative to the paired baselines (+102.5% across 8 pairs). The authors also show that the learned structure reflects meaningful physical interaction trends, which makes the model easier to interpret.

CausalVAE, latent world models, encoder-transition backbones, counterfactual retrieval, distribution shift, Physics benchmark, CF-H@1, causal analysis, graph neural networks (GNN), interpretability
Authors
Ziyi Ding, Xianxin Lai, Weiyu Chen, Xiao-Ping Zhang, Jiayu Chen
Abstract
This work introduces CausalVAE, a plug-in structural module for latent world models that attaches to diverse encoder-transition backbones. Across the reported benchmarks, adding the plug-in preserves competitive factual prediction and improves intervention-aware counterfactual retrieval, suggesting stronger robustness under distribution shift and interventions. The largest gains appear on the Physics benchmark: averaged over 8 paired baselines, CF-H@1 improves by 102.5%. In a representative GNN-NLL setting on Physics, CF-H@1 rises from 11.0 to 41.0 (+272.7%). Causal analysis shows that the learned structural dependencies recover meaningful first-order physical interaction trends, supporting the interpretability of the learned latent causal structure.
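
To make the reported figures easier to check, here is a minimal Python sketch of a relative-gain calculation and of a Hits@1-style retrieval score. The abstract does not define CF-H@1, so `hits_at_k` assumes a common reading (the share of queries whose true target is the nearest candidate in latent space). The function names and the Euclidean distance are illustrative assumptions, not the authors' implementation.

```python
import numpy as np


def hits_at_k(pred, candidates, target_idx, k=1):
    """Percentage of queries whose true target is among the k nearest candidates.

    Assumed (hypothetical) reading of a Hits@k retrieval metric:
    pred:       (n_queries, d) predicted latent states
    candidates: (n_candidates, d) candidate latent states
    target_idx: (n_queries,) index of the correct candidate for each query
    """
    # Pairwise squared Euclidean distances, shape (n_queries, n_candidates).
    dists = ((pred[:, None, :] - candidates[None, :, :]) ** 2).sum(axis=-1)
    # Indices of the k closest candidates for each query.
    topk = np.argsort(dists, axis=1)[:, :k]
    hits = (topk == target_idx[:, None]).any(axis=1)
    return 100.0 * hits.mean()


def relative_gain(before, after):
    """Relative improvement of `after` over `before`, in percent."""
    return 100.0 * (after - before) / before


# The GNN-NLL example from the abstract: CF-H@1 goes from 11.0 to 41.0.
print(round(relative_gain(11.0, 41.0), 1))  # 272.7
```

The 102.5% average over 8 paired baselines cannot be re-derived this way from the abstract alone, since the per-pair scores are not listed there.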