causalflows.distributions
Causal Normalizing Flow distribution.
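Classes
CausalNormalizingFlow: Class that extends zuko.distributions.NormalizingFlow with methods to compute interventions and counterfactuals.
Descriptions
- class causalflows.distributions.CausalNormalizingFlow(transform, base)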
Class that extends zuko.distributions.NormalizingFlow with methods to compute interventions and counterfactuals.
- Parameters:
transform (Transform) – A transformation \(f\).
base (Distribution) – A base distribution \(p(Z)\).
See also
NormalizingFlow: The equivalent non-causal counterpart from Zuko.
Example
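>>> from torch.distributions import Gamma
>>> from torch.distributions.transforms import ExpTransform
>>> from causalflows.distributions import CausalNormalizingFlow
>>> d = CausalNormalizingFlow(ExpTransform(), Gamma(2.0, 1.0))
>>> d.sample()
tensor(1.5157)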
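The example follows Zuko's convention that the transformation \(f\) maps \(X\) to the base variable \(Z\): the density is \(p(x) = p(Z = f(x)) \left| \det \frac{\partial f(x)}{\partial x} \right|\), and sampling draws \(z \sim p(Z)\) and returns \(x = f^{-1}(z)\). A minimal one-dimensional sketch of that computation for ExpTransform with a Gamma(2.0, 1.0) base, written in plain Python rather than with Zuko or PyTorch; every function name here is hypothetical:

```python
import math
import random

# Hypothetical stand-in for Gamma(2.0, 1.0): concentration 2, rate 1
CONCENTRATION, RATE = 2.0, 1.0


def base_log_prob(z):
    """Log-density of the Gamma(concentration, rate) base distribution at z > 0."""
    return (
        CONCENTRATION * math.log(RATE)
        + (CONCENTRATION - 1.0) * math.log(z)
        - RATE * z
        - math.lgamma(CONCENTRATION)
    )


def f(x):
    """Transformation f: X -> Z, here the exponential (like ExpTransform)."""
    return math.exp(x)


def log_prob(x):
    """Change of variables: log p(x) = log p_Z(f(x)) + log|df/dx|, and log|d exp(x)/dx| = x."""
    return base_log_prob(f(x)) + x


def sample(rng):
    """Draw z ~ p(Z) and map it through the inverse transformation f^{-1} = log."""
    z = rng.gammavariate(CONCENTRATION, 1.0 / RATE)  # gammavariate takes a scale, i.e. 1 / rate
    return math.log(z)


rng = random.Random(0)
draws = [sample(rng) for _ in range(5)]
```

Because the derivative of exp(x) is exp(x), the log-Jacobian term reduces to x. Samples are logarithms of positive Gamma draws, so they can be negative, as in the outputs of the later examples.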
References
A Family of Non-parametric Density Estimation Algorithms (Tabak et al., 2013)
Variational Inference with Normalizing Flows (Rezende et al., 2015)
Normalizing Flows for Probabilistic Modeling and Inference (Papamakarios et al., 2021)
- intervene(index, value)
Context manager that yields an interventional distribution.
- Parameters:
index (LongTensor) – Index tensor of the intervened variables.
value (Tensor) – Values of the intervened variables.
- Returns:
A CausalNormalizingFlow representing the interventional distribution.
- Return type:
Warning
Nested interventions have not yet been tested.
Example
>>> import torch
>>> nflow = CausalNormalizingFlow(ExpTransform(), Gamma(2.0, torch.ones((1,))))
>>> with nflow.intervene(index=0, value=0.5) as int_nflow:
...     x = int_nflow.sample((3,))
>>> x
tensor([[0.5000],
        [0.5000],
        [0.5000]])
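An intervention do(X_i = v) can be read as replacing the structural equation of X_i with the constant v, while every other variable is still generated from fresh noise in causal order. The sketch below illustrates that idea, together with the `with ... as` pattern used by intervene, on a hypothetical three-variable chain X0 -> X1 -> X2 with additive Gaussian noise. ToyChain, its coefficients and its methods are assumptions made for illustration, not the causalflows implementation.

```python
import random
from contextlib import contextmanager


class ToyChain:
    """Hypothetical additive-noise chain X0 -> X1 -> X2 (not part of causalflows).

    Generation runs in causal order, like an inverse triangular map:
        X0 = Z0,  X1 = a * X0 + Z1,  X2 = b * X1 + Z2,  with Z ~ N(0, I).
    """

    def __init__(self, a=2.0, b=-1.0, fixed=None):
        self.a, self.b = a, b
        self.fixed = dict(fixed or {})  # {index: value} for intervened variables

    def generate(self, z):
        x = []
        for i, zi in enumerate(z):
            if i in self.fixed:
                x.append(self.fixed[i])  # do(X_i = v): equation replaced by a constant
            elif i == 0:
                x.append(zi)
            elif i == 1:
                x.append(self.a * x[0] + zi)
            else:
                x.append(self.b * x[1] + zi)
        return x

    def sample(self, n, rng):
        return [self.generate([rng.gauss(0.0, 1.0) for _ in range(3)]) for _ in range(n)]

    @contextmanager
    def intervene(self, index, value):
        # Yield an intervened copy (merging any earlier interventions);
        # the original model is never modified.
        yield ToyChain(self.a, self.b, {**self.fixed, index: value})


rng = random.Random(0)
model = ToyChain()
with model.intervene(index=1, value=0.5) as int_model:
    xs = int_model.sample(10_000, rng)
```

Inside the block every sample has X1 = 0.5, the ancestor X0 keeps its observational mean of about 0, and the descendant X2 is centred near -0.5 (that is, b * 0.5). Yielding a new object keeps the original model unchanged once the block exits.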
- sample_interventional(index, value, sample_shape=())
Helper method to sample from an interventional distribution.
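- Parameters:
index (LongTensor) – Index tensor of the intervened variables.
value (Tensor) – Values of the intervened variables.
sample_shape – The shape of the samples to draw.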
- Returns:
The intervened samples.
- Return type:
Example
>>> nflow = CausalNormalizingFlow(ExpTransform(), Gamma(2.0, torch.ones((2,))))
>>> x = nflow.sample_interventional(index=1, value=0.5, sample_shape=(3,))
>>> x
tensor([[ 1.5157,  0.5000],
        [-0.4748,  0.5000],
        [-0.1055,  0.5000]])
- compute_counterfactual(factual, index, value)
Helper method to sample from a counterfactual distribution.
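- Parameters:
factual (Tensor) – The factual samples.
index (LongTensor) – Index tensor of the intervened variables.
value (Tensor) – Values of the intervened variables.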
- Returns:
The counterfactual samples, with the same shape as factual.
- Return type:
Example
>>> nflow = CausalNormalizingFlow(ExpTransform(), Gamma(2.0, torch.ones((2,))))
>>> factual = nflow.sample((3,))
>>> factual
tensor([[ 1.5157,  0.2745],
        [-0.4748, -0.8333],
        [-0.1055,  0.1809]])
>>> cfactual = nflow.compute_counterfactual(factual, index=0, value=0.5)
>>> cfactual
tensor([[ 0.5000,  0.2745],
        [ 0.5000, -0.8333],
        [ 0.5000,  0.1809]])
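Counterfactuals are commonly computed in three steps: abduction (recover the exogenous noise z = f(x) of the factual sample), action (fix the intervened variable), and prediction (regenerate the remaining variables from the same z in causal order). A minimal sketch of that recipe, assuming a hypothetical linear chain X0 -> X1 -> X2 with additive noise; f, f_inv, counterfactual and the coefficients A and B are illustrative names, not the causalflows API or its implementation:

```python
# Hypothetical coefficients of the chain X0 -> X1 -> X2:
#   X0 = Z0,  X1 = A * X0 + Z1,  X2 = B * X1 + Z2
A, B = 2.0, -1.0


def f(x):
    """Abduction: recover the exogenous noise z = f(x) from an observed x."""
    return [x[0], x[1] - A * x[0], x[2] - B * x[1]]


def f_inv(z, index=None, value=None):
    """Prediction: regenerate x from z in causal order, enforcing do(X[index] = value)."""
    x = []
    for i, zi in enumerate(z):
        if i == index:
            x.append(value)  # action: the intervened variable ignores its equation
        elif i == 0:
            x.append(zi)
        elif i == 1:
            x.append(A * x[0] + zi)
        else:
            x.append(B * x[1] + zi)
    return x


def counterfactual(factual, index, value):
    z = f(factual)  # 1. abduction: keep the factual noise
    return f_inv(z, index, value)  # 2. action + 3. prediction


factual = [1.0, 2.5, -1.0]
cf = counterfactual(factual, index=0, value=0.5)  # -> [0.5, 1.5, 0.0]
```

Intervening on X0 moves its descendants X1 and X2, whereas intervening on the leaf X2 would leave X0 and X1 as observed. Because ExpTransform acts elementwise, the variables in the documented example have no causal links, which is why the non-intervened column of factual is copied unchanged into cfactual.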