causalflows.distributions

Causal Normalizing Flow distribution.

Classes

CausalNormalizingFlow

Class that extends zuko.distributions.NormalizingFlow with methods to compute interventions and counterfactuals.

Descriptions

class causalflows.distributions.CausalNormalizingFlow(transform, base)

Class that extends zuko.distributions.NormalizingFlow with methods to compute interventions and counterfactuals.

Parameters:
  • transform (Transform) – A transformation \(f\).

  • base (Distribution) – A base distribution \(p(Z)\).

Example

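>>> from torch.distributions import ExpTransform, Gamma
>>> from causalflows.distributions import CausalNormalizingFlow
>>> d = CausalNormalizingFlow(ExpTransform(), Gamma(2.0, 1.0))
>>> d.sample()
tensor(1.5157)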
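
As a rough illustration of what transform and base mean here, the sketch below works out the density of the example above by hand. It assumes the convention of zuko's NormalizingFlow, in which the transform maps a data point x to z = f(x) and sampling applies the inverse. With f = exp and a Gamma(2, 1) base, this gives log p(x) = log p_Z(exp(x)) + x. The helper names base_log_prob, flow_log_prob and flow_sample are hypothetical and not part of the library.

```python
import math
import random


def base_log_prob(z: float) -> float:
    """Log-density of Gamma(concentration=2, rate=1), i.e. log(z) - z for z > 0."""
    return math.log(z) - z


def flow_log_prob(x: float) -> float:
    """Change of variables with f = exp: log p(x) = log p_Z(f(x)) + log|f'(x)|."""
    z = math.exp(x)  # f(x): data space -> base (latent) space
    ladj = x         # log|d exp(x) / dx| = x
    return base_log_prob(z) + ladj


def flow_sample(rng: random.Random) -> float:
    """Sample z from the base, then apply the inverse transform x = log(z)."""
    z = rng.gammavariate(2.0, 1.0)  # shape 2, scale 1 (same as rate 1)
    return math.log(z)


# Sanity check: the density should integrate to about 1 over the real line.
n_steps = 30000
step = 30.0 / n_steps
total = sum(math.exp(flow_log_prob(-20.0 + i * step)) * step for i in range(n_steps))

x = flow_sample(random.Random(0))
```

Because x = log(z), samples can be negative even though the Gamma base is supported on positive values only.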

References

A Family of Non-parametric Density Estimation Algorithms (Tabak et al., 2013)
Variational Inference with Normalizing Flows (Rezende et al., 2015)
Normalizing Flows for Probabilistic Modeling and Inference (Papamakarios et al., 2021)

intervene(index, value)

Context manager that yields an interventional distribution.

Parameters:
  • index (LongTensor) – Index tensor of the intervened variables.

  • value (Tensor) – Values of the intervened variables.

Returns:

A CausalNormalizingFlow representing the interventional distribution.

Return type:

NormalizingFlow

Warning

Nested interventions have not yet been tested.

Example

>>> nflow = CausalNormalizingFlow(ExpTransform(), Gamma(2.0, torch.ones((1,))))
>>> with nflow.intervene(index=0, value=0.5) as int_nflow:
...   x = int_nflow.sample((3,))
>>> x
tensor([[0.5000],
        [0.5000],
        [0.5000]])
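
The context-manager pattern can be sketched on a toy model, independently of the library. The example below is only an assumption about the general idea, not the actual implementation: ToyCausalModel is a hypothetical two-variable structural causal model, and its intervene method temporarily fixes one variable and restores the observational model on exit.

```python
import random
from contextlib import contextmanager


class ToyCausalModel:
    """Hypothetical SCM: x0 = z0, x1 = 2 * x0 + z1, with z0, z1 ~ N(0, 1)."""

    def __init__(self):
        self._fixed = {}  # maps variable index -> intervened value

    def sample(self, rng):
        z0, z1 = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
        x0 = self._fixed.get(0, z0)             # do(x0 = a) replaces the mechanism of x0
        x1 = self._fixed.get(1, 2.0 * x0 + z1)  # x1 sees the intervened parent value
        return [x0, x1]

    @contextmanager
    def intervene(self, index, value):
        previous = dict(self._fixed)
        self._fixed[index] = value
        try:
            yield self  # toy choice: yield the same object with the variable fixed
        finally:
            self._fixed = previous  # back to the observational model


rng = random.Random(0)
model = ToyCausalModel()
with model.intervene(index=0, value=0.5) as int_model:
    samples = [int_model.sample(rng) for _ in range(3)]
observational = model.sample(rng)
```

Inside the with block every sample has x0 equal to 0.5, while x1 still varies through its own noise term. After the block the model samples from the observational distribution again.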

sample_interventional(index, value, sample_shape=())

Helper method to sample from an interventional distribution.

Parameters:
  • index (LongTensor) – Index tensor of the intervened variables.

  • value (Tensor) – Values of the intervened variables.

  • sample_shape (Size) – Batch shape of the samples.

Returns:

The intervened samples.

Return type:

Tensor

Example

>>> nflow = CausalNormalizingFlow(ExpTransform(), Gamma(2.0, torch.ones((2,))))
>>> x = nflow.sample_interventional(index=1, value=0.5, sample_shape=(3,))
>>> x
tensor([[ 1.5157,  0.5000],
        [-0.4748,  0.5000],
        [-0.1055,  0.5000]])
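
A common use of interventional samples is estimating an average treatment effect as a difference of means under two interventions. The sketch below shows that computation on a hypothetical toy model rather than on a trained flow. The function toy_sample_interventional is a stand-in written for this example, and the structural equations x0 = z0, x1 = 2 * x0 + z1 are assumed.

```python
import random


def toy_sample_interventional(index, value, n, rng):
    """Draw n samples from the toy SCM under do(x_index = value)."""
    samples = []
    for _ in range(n):
        z0, z1 = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
        x0 = value if index == 0 else z0
        x1 = value if index == 1 else 2.0 * x0 + z1
        samples.append((x0, x1))
    return samples


rng = random.Random(0)
n = 20000
treated = toy_sample_interventional(0, 1.0, n, rng)  # do(x0 = 1)
control = toy_sample_interventional(0, 0.0, n, rng)  # do(x0 = 0)

# Average treatment effect of x0 on x1: E[x1 | do(x0 = 1)] - E[x1 | do(x0 = 0)].
ate = sum(x1 for _, x1 in treated) / n - sum(x1 for _, x1 in control) / n
```

In this toy model the true effect is 2, the coefficient of x0 in the equation for x1, so the Monte Carlo estimate should land close to that value.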

compute_counterfactual(factual, index, value)

Helper method to compute the counterfactuals of factual samples under an intervention.

Parameters:
  • factual (Tensor) – The factual sample.

  • index (LongTensor) – Index tensor of the intervened variables.

  • value (Tensor) – Values of the intervened variables.

Returns:

The counterfactual samples, with the same shape as factual.

Return type:

Tensor

Example

>>> nflow = CausalNormalizingFlow(ExpTransform(), Gamma(2.0, torch.ones((2,))))
>>> factual = nflow.sample((3,))
>>> factual
tensor([[ 1.5157,  0.2745],
        [-0.4748, -0.8333],
        [-0.1055,  0.1809]])
>>> cfactual = nflow.compute_counterfactual(factual, index=0, value=0.5)
>>> cfactual
tensor([[ 0.5000,  0.2745],
        [ 0.5000, -0.8333],
        [ 0.5000,  0.1809]])
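
Counterfactuals are usually described as three steps: abduction (recover the noise of the factual sample), action (apply the intervention), and prediction (regenerate the sample from the same noise). The sketch below walks through these steps on a hypothetical toy model with x0 = z0 and x1 = 2 * x0 + z1. The functions to_noise, to_data and toy_counterfactual are illustrative names, and this is not a description of how the library implements the method.

```python
def to_noise(x):
    """Abduction: invert the toy SCM x0 = z0, x1 = 2 * x0 + z1."""
    x0, x1 = x
    return [x0, x1 - 2.0 * x0]


def to_data(z, index=None, value=None):
    """Prediction: regenerate x from z, overriding variable `index` with `value`."""
    x0 = value if index == 0 else z[0]
    x1 = value if index == 1 else 2.0 * x0 + z[1]
    return [x0, x1]


def toy_counterfactual(factual, index, value):
    z = to_noise(factual)            # 1. abduction: noise consistent with the factual
    return to_data(z, index, value)  # 2. action and 3. prediction with the same noise


factual = [1.0, 2.3]                                  # noise is roughly [1.0, 0.3]
cf_parent = toy_counterfactual(factual, index=0, value=0.5)
cf_child = toy_counterfactual(factual, index=1, value=0.0)
```

Intervening on the parent x0 changes the child x1 (to about 2 * 0.5 + 0.3 = 1.3), whereas intervening on the child leaves the parent untouched. In the example above the second column is identical in factual and cfactual because ExpTransform acts elementwise, so the two variables do not depend on each other.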