Visualising Causal Test Results

Overview

This tutorial provides users with minimal Python code to visualise the outputs of their causal test results in a simple but intuitive manner.

Note: this approach and the provided code depend strongly on the structure and size of the user’s directed acyclic graph (DAG); therefore, some fine-tuning will be required.

[1]:
import pydot # you may need to run: !pip install pydot if necessary
import networkx as nx # !pip install networkx --upgrade
import matplotlib.pyplot as plt
import numpy as np
import json
import os
from matplotlib.lines import Line2D
import os
import json
import networkx as nx
import pydot
from pathlib import Path
from causal_testing.specification.causal_dag import CausalDAG
[2]:
def read_json_file(file_path):
    """Load and return the parsed contents of a UTF-8 encoded JSON file."""
    with open(file_path, encoding="utf-8") as handle:
        parsed = json.load(handle)
    return parsed

# Resolve everything relative to the current working directory, which keeps
# this cell Jupyter-friendly, then locate the example's output folder.
base_dir = Path().resolve()
directory = base_dir.joinpath("..", "vaccinating_elderly")

# Load the recorded test results first, then the DAG they were tested against.
causal_tests_results = read_json_file(directory.joinpath("causal_test_results.json"))
dag = CausalDAG(str(directory.joinpath("dag.dot")))
[3]:
def draw_dag(dag, treatment_node, outcome_node, confounder_node, title, edge_colours, connection_style, ax):
    """
    Draw a Directed Acyclic Graph (DAG) and overlay a single causal test on it.

    Nodes are placed in layers by topological generation. The treatment,
    outcome and confounder nodes are filled with distinct colours. The tested
    treatment -> outcome edge is drawn on top in ``edge_colours``. It is dashed
    when that edge is absent from the DAG and solid when it is present.

    Side effect: a ``"layer"`` attribute is written onto every node of ``dag``.

    Parameters:
        dag (networkx.DiGraph): The DAG to be drawn.
        treatment_node (str): The treatment node in the DAG.
        outcome_node (str): The outcome node in the DAG.
        confounder_node (str or None): The confounder node in the DAG, if any.
        title (str): Title of the plot.
        edge_colours (list): Colour(s) for the tested treatment -> outcome edge.
        connection_style (str): Matplotlib connection style for the tested edge,
            e.g. "arc3,rad=0.0" (straight) or "arc3,rad=-0.5" (curved).
        ax (matplotlib.axes.Axes): Axes to draw the plot on.
    """
    # Layer = topological generation, so every edge goes from a lower layer to
    # a higher one. Generations must be enumerated in the order networkx yields
    # them. Sorting the list of generations would reorder the layers
    # themselves (lexicographically by node name) and break this.
    for layer, nodes in enumerate(nx.topological_generations(dag)):
        for node in nodes:
            dag.nodes[node]["layer"] = layer

    # Compute multipartite layout using "layer" node attribute
    pos = nx.multipartite_layout(dag, subset_key="layer", align='horizontal')

    # Shift the layer-0 (root) nodes horizontally by a hand-tuned amount.
    # Adjust as needed for your DAG.
    offset_amount = -0.5
    for node, coords in pos.items():
        if dag.nodes[node]["layer"] == 0:
            pos[node] = (coords[0] + offset_amount, coords[1])

    # One definition of each role's colour, shared by the node fill and the
    # legend so the two cannot disagree.
    treatment_colour, outcome_colour, confounder_colour = 'C1', 'C0', 'C4'
    color_map = [
        treatment_colour if node == treatment_node
        else outcome_colour if node == outcome_node
        else confounder_colour if node == confounder_node
        else 'white'
        for node in dag.nodes()
    ]

    # The tested edge is drawn dashed only when it is not already a DAG edge.
    dag_edges = set(dag.edges())
    test_edges = {(treatment_node, outcome_node)}
    diff = test_edges - dag_edges

    dag_options = {
        "font_size": 10,
        "node_size": 10000,
        "node_color": color_map,
        "edgelist": dag.edges(),
        "edge_color": 'black',
        "edgecolors": 'black',
        "linewidths": 1.5,
        "width": 1.5,
        "style": '-',
    }

    test_options = {
        # node_size must match dag_options so arrow heads stop at node borders.
        "node_size": 10000,
        "edgelist": diff if diff else test_edges,
        "edge_color": edge_colours,
        "width": 1.5,
        "style": '--' if diff else '-',
        "arrowsize": 12,
        "arrows": True,
        "connectionstyle": connection_style,
    }

    nx.draw_networkx(dag, pos=pos, ax=ax, **dag_options, with_labels=True)
    nx.draw_networkx_edges(dag, pos=pos, ax=ax, **test_options)

    green_line = Line2D([], [], color='C2', ls='--', label='Test Passed')
    red_line = Line2D([], [], color='C3', ls='--', label='Test Failed')

    # Legend labels and colors (same colours as the node fill above)
    legend_labels = ['Treatment', 'Outcome', 'Confounder']
    legend_colors = [treatment_colour, outcome_colour, confounder_colour]

    # Create legend handles
    legend_handle_1 = [
        Line2D([0], [0], marker='o', color='w', markerfacecolor=color, markersize=10, label=label)
        for color, label in zip(legend_colors, legend_labels)
    ]
    legend_handle_2 = [red_line, green_line]

    # Add Line2D objects to the legend
    ax.legend(handles=legend_handle_1 + legend_handle_2, loc='best', bbox_to_anchor=(1.1, 1))
    ax.margins(0.40)
    ax.set_title(title)
    ax.axis("off")

def get_confounder_node(value):
    """
    Pull the confounder name out of a causal test's name string.

    The string must contain ' | ' together with both '[' and ']'. Otherwise
    there is nothing to extract and None is returned.

    Parameters:
        value (str): Test name that may carry a bracketed variable list.

    Returns:
        str or None: The text between the last '[' and the first ']' after it,
        with single quotes stripped. For a multi-variable list this is the
        comma-separated text (e.g. "Z1, Z2"), not a single node name.
    """
    has_bracketed_list = ' | ' in value and '[' in value and ']' in value
    if not has_bracketed_list:
        return None
    _, _, after_last_open = value.rpartition('[')
    inside_brackets, _, _ = after_last_open.partition(']')
    return inside_brackets.replace("'", "")

# Hand-tuned for this example DAG: zero-based indices of tests whose
# treatment -> outcome edge is drawn as a curve instead of a straight line.
# Adjust for your own DAG.
CURVED_TEST_INDICES = {7}

if not causal_tests_results:
    # plt.subplots rejects nrows=0; fail with a clearer message instead.
    raise ValueError("No causal test results to visualise.")

# One row per test. figsize is hand-tuned for this example and may need
# scaling with the number of tests.
fig, ax = plt.subplots(nrows=len(causal_tests_results), ncols=1, figsize=[10, 100])
# With a single row plt.subplots returns a bare Axes rather than an array,
# so normalise to a 1-D array to make axes[i] work for any number of tests.
axes = np.atleast_1d(ax)

for i, tests in enumerate(causal_tests_results):
    treatment_node = tests['result']['treatment']
    outcome_node = tests['result']['outcome']
    confounder_node = get_confounder_node(tests['name'])
    title = tests['name']

    # C2 = passed, C3 = failed, C9 = no verdict recorded. The original
    # not/elif chain made the C9 case unreachable.
    passed = tests.get('passed')
    if passed is None:
        edge_colours = ["C9"]
    elif passed:
        edge_colours = ["C2"]
    else:
        edge_colours = ["C3"]

    connection_style = "arc3,rad=-0.5" if i in CURVED_TEST_INDICES else "arc3,rad=0.0"
    draw_dag(dag, treatment_node, outcome_node, confounder_node,
             f"Causal test {i + 1}: {title}", edge_colours, connection_style, axes[i])

../../_images/tutorials_visualising_causal_test_results_visualise_causal_test_results_5_0.png

Additional Resources