FiP: A Fixed-Point Approach for Causal Generative Modeling
- M. Scetbon
- Joel Jennings
- Agrin Hilmkil
- Cheng Zhang
- Chao Ma
ICML 2024
Modeling real-world data-generating processes lies at the heart of empirical science. Structural Causal Models (SCMs) and their associated Directed Acyclic Graphs (DAGs) provide an increasingly popular answer to such problems by defining the causal generative process that transforms random noise into observations. However, learning them from observational data is, in general, an ill-posed and NP-hard inverse problem. In this work, we propose a new, equivalent formalism that describes SCMs without requiring DAGs, viewing them as fixed-point problems on the causally ordered variables, and we show three important cases where they can be uniquely recovered given the topological ordering (TO). To the best of our knowledge, we obtain the weakest conditions for their recovery when the TO is known. Based on this, we design a two-stage causal generative model that first infers the causal order from observations in a zero-shot manner, thus bypassing the search, and then learns the generative fixed-point SCM on the ordered variables. To infer TOs from observations, we propose to amortize the learning of TOs on generated datasets by sequentially predicting the leaves of graphs seen during training. To learn fixed-point SCMs, we design a transformer-based architecture that exploits a new attention mechanism enabling the modeling of causal structures, and show that this parameterization is consistent with our formalism. Finally, we conduct an extensive evaluation of each method individually, and show that, when combined, our model outperforms various baselines on generated out-of-distribution problems.
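To make the two ideas in the abstract concrete, here is a minimal NumPy sketch, not the paper's implementation. It assumes the graph is known, whereas the paper predicts leaves with a learned, amortized model. It also uses a hand-written tanh map in place of the transformer-based parameterization. The names `topological_order_by_leaves`, `make_H`, and `solve_fixed_point` are hypothetical and chosen only for illustration. The sketch shows (1) recovering a topological order by repeatedly removing a leaf, and (2) generating observations from noise by iterating a map `H` whose dependence on `x` is strictly lower-triangular in the causal order, so the iteration reaches the unique fixed point `x = H(x, n)` after at most `d` steps.

```python
import numpy as np


def topological_order_by_leaves(adj):
    """Return a topological order by repeatedly peeling off a leaf.

    adj[i, j] is truthy when there is an edge i -> j. This uses the true
    graph; it only illustrates the leaf-by-leaf idea, not a learned predictor.
    """
    adj = np.asarray(adj, dtype=bool)
    remaining = list(range(adj.shape[0]))
    reversed_order = []
    while remaining:
        sub = adj[np.ix_(remaining, remaining)]
        # A leaf has no outgoing edge to any node that is still remaining.
        leaf_positions = np.flatnonzero(~sub.any(axis=1))
        if leaf_positions.size == 0:
            raise ValueError("graph contains a cycle")
        leaf = remaining[leaf_positions[0]]
        reversed_order.append(leaf)
        remaining.remove(leaf)
    return reversed_order[::-1]


def make_H(W):
    """Build a toy map H(x, n) = tanh(L x) + n (an assumption for this sketch).

    Only the strictly lower-triangular part L of W is used, so variable i
    depends only on variables that precede it in the causal order.
    """
    L = np.tril(np.asarray(W, dtype=float), k=-1)

    def H(x, n):
        return np.tanh(x @ L.T) + n

    return H


def solve_fixed_point(H, noise):
    """Iterate x <- H(x, noise) d times, where d is the number of variables.

    With a strictly lower-triangular dependence on x, coordinate i is exact
    after i + 1 iterations, so d iterations give the fixed point.
    """
    x = np.zeros_like(noise, dtype=float)
    for _ in range(noise.shape[-1]):
        x = H(x, noise)
    return x


# Toy usage: variables are assumed to be already sorted in causal order.
W = np.array([[0.0, 0.0, 0.0],
              [1.0, 0.0, 0.0],
              [0.5, -2.0, 0.0]])
H = make_H(W)
noise = np.array([0.3, -0.1, 0.7])
x = solve_fixed_point(H, noise)
```

In this toy setting the result coincides with ordinary ancestral sampling, generating each variable in turn from its predecessors. Strict triangularity in the causal order is what makes the fixed point unique and reachable in finitely many steps.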