2. A basic example of CNFs

This notebook walks you through a simple example that

  • defines its own synthetic SCM dataset,

  • define a causal normalizing flow,

  • trains the model,

  • evaluates the quality of the trained model.

For a tutorial on the basics of Pytorch distributions and zuko, see the previous tutorial.

import lightning as L
import torch
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

2.1. Create a dataset

We are going to manually create a simple linear SCM, \(x = Ax + u\) with \(u \sim \mathcal{N}(0, I)\) and \(A = \begin{pmatrix} 0 & 0 & 0 \\ 1 & 0 & 0 \\ 0.25 & 0.75 & 0 \end{pmatrix}\). Note that there exist predefined synthetic datasets in causalflows.scms.

from causalflows.scms import SCM, CausalEquations
class MyTriangleEqs(CausalEquations):
    def __init__(self):
        functions = [
            lambda u1: u1 + 0.0,
            lambda x1, u2: x1 + u2,
            lambda x1, x2, u3: 0.25 * x1 + 0.75 * x2 + u3,
        ]
        inverses = [
            lambda x1: x1 - 0.0,
            lambda x1, x2: x2 - x1,
            lambda x1, x2, x3: x3 - 0.75 * x2 - 0.25 * x1,
        ]

        super().__init__(functions, inverses)

    @property
    def adjacency(self):
        return torch.tensor((
            (1, 0, 0),
            (1, 1, 0),
            (1, 1, 1),
        )).bool()
scm = SCM(equations=MyTriangleEqs(), base='std-gaussian')


def generate_data(size: int):
    return scm.sample((size,))
adjacency = scm.adjacency
adjacency
tensor([[ True, False, False],
        [ True,  True, False],
        [ True,  True,  True]])

Generate datasets

from torch.utils.data import DataLoader, TensorDataset
batch_size = 2048

train_size = 10000
train_data = generate_data(train_size)

test_size = int(train_size * 0.1)
test_data = generate_data(test_size)

# Create the dataloaders
train_dataset = TensorDataset(train_data)
train_loader = DataLoader(train_dataset, batch_size, shuffle=True)

test_dataset = TensorDataset(test_data)
test_loader = DataLoader(test_dataset, batch_size=test_size)
def create_df(tensors_dict):
    df_list = []
    for i, (k, x) in enumerate(tensors_dict.items()):
        df = pd.DataFrame(x.numpy(), columns=[f'$x_{j}$' for j in range(x.shape[1])])
        df['name'] = k
        df_list.append(df)

    df = pd.concat(df_list)
    return df


def plot_data(tensors_dict):
    df = create_df(tensors_dict)
    g = sns.PairGrid(df, diag_sharey=False, hue='name')
    g.map_upper(sns.scatterplot, s=15, alpha=0.5)
    g.map_lower(sns.kdeplot, common_norm=False)
    g.map_diag(sns.kdeplot, lw=2, common_norm=False)
    g.add_legend()
    plt.show()


plot_data({'Data': test_data})
../_images/d0db7035351d781ff574e441fa1c2aac5b6a65d4e8f20a7b94e95c3e9496fc28.png

2.2. Using a causal flow

from causalflows.flows import CausalMAF

features = 3
context = 0

flow = CausalMAF(features, context, adjacency=adjacency)
flow
CausalMAF(
  (transform): MaskedAutoregressiveTransform(
    (base): MonotonicAffineTransform()
    (passes): 3
    (hyper): MaskedMLP(
      (0): MaskedLinear(in_features=3, out_features=64, bias=True)
      (1): ReLU()
      (2): MaskedLinear(in_features=64, out_features=64, bias=True)
      (3): ReLU()
      (4): MaskedLinear(in_features=64, out_features=6, bias=True)
    )
  )
  (base): UnconditionalDistribution(DiagNormal(loc: torch.Size([3]), scale: torch.Size([3])))
)

Lightning boilerplate for logging

from lightning.pytorch.loggers.logger import Logger
from lightning.pytorch.utilities import rank_zero_only


class MyLogger(Logger):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.values = {}

    @property
    def name(self):
        return "MyLogger"

    @property
    def version(self):
        # Return the experiment version, int or str.
        return "0.1"

    @rank_zero_only
    def log_hyperparams(self, params):
        # params is an argparse.Namespace
        # your code to record hyperparameters goes here
        pass

    @rank_zero_only
    def log_metrics(self, metrics, step):
        # metrics is a dictionary of metric names and values
        # your code to record metrics goes here
        for k, v in metrics.items():
            if k in self.values:
                self.values[k].append(v)
            else:
                self.values[k] = [v]

    @rank_zero_only
    def save(self):
        # Optional. Any code necessary to save logger data goes here
        pass

    @rank_zero_only
    def finalize(self, status):
        # Optional. Any code that needs to be run after training
        # finishes goes here
        pass

Define the LightningModule to train the model

class LitFlow(L.LightningModule):
    def __init__(self, flow):
        super().__init__()
        self.flow = flow

    def training_step(self, batch, batch_idx):
        x = batch[0]
        loss = -self.flow().log_prob(x).mean()
        self.log('train_loss', loss.detach())
        self.log('log_prob', -loss.detach())
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


# init the autoencoder
model = LitFlow(flow)
logger = MyLogger()
trainer = L.Trainer(
    max_epochs=200, logger=logger, enable_checkpointing=False, log_every_n_steps=len(train_loader)
)
trainer.fit(model=model, train_dataloaders=train_loader)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name | Type      | Params | Mode 
-------------------------------------------
0 | flow | CausalMAF | 4.8 K  | train
-------------------------------------------
4.8 K     Trainable params
0         Non-trainable params
4.8 K     Total params
0.019     Total estimated model params size (MB)
9         Modules in train mode
0         Modules in eval mode
/home/ajavaloy/Workspace/libraries/causal-flows/.venv/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
`Trainer.fit` stopped: `max_epochs=200` reached.
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
sns.lineplot(x=logger.values['epoch'], y=logger.values['train_loss'], ax=axs[0])
sns.lineplot(x=logger.values['epoch'], y=logger.values['log_prob'], ax=axs[1])
plt.show()
../_images/47be278d89ab96a6bf00de7c907d8e7631953d6719177ac5d4e570bf76f155bc.png
gen_data = flow().sample([test_size])
gen_data.shape
torch.Size([1000, 3])
plot_data({'og': test_data, 'flow': gen_data})
../_images/ee31b1fad721bf5f9e9a5a0e21db0a1c6584e51ddce91da1f9b0e8e75c8d2eab.png

2.3. Generate interventional data

def generate_int_data(size: int, index, value):
    return scm.sample_interventional(index, value, (size,))
true_int_data = generate_int_data(test_size, 1, 2.5)

with flow().intervene(index=1, value=2.5) as int_flow:
    int_data = int_flow.sample((test_size,))

# We could also sample with the helper method
# int_data = flow().sample_interventional(index=1, value=2.5, sample_shape=(test_size,))

int_data += torch.randn((test_size, 3)) * 0.01
true_int_data += torch.randn((test_size, 3)) * 0.01

plot_data({
    r'$\text{flow}_\text{obs}$': gen_data,
    r'$\text{flow}_{\mathrm{do}(x_1 = 2.5)}$': int_data,
    r'$\text{SCM}_{\mathrm{do}(x_1 = 2.5)}$': true_int_data,
})
tensor([1])
../_images/31919461b436c98e4bcec297935c23e4f646046efed7f9eb8d2285a28bee77af.png