From d7af946b93d00d356752ec8e613e1ecee92760bc Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Thu, 11 Dec 2025 23:13:36 +0100 Subject: [PATCH 1/2] Add CausalPy skills and command documentation Introduced new skill files for designing experiments, loading datasets, performing causal analysis, and working with marimo notebooks under .claude/skills/. Added detailed reference documentation for DiD, ITS, and Synthetic Control methods. Added .cursor/commands/ files covering CausalPy demos, estimators, extras (including PlaceboAnalysis), marimo usage, methods, and research guidance. Updated .gitignore and pyproject.toml to support these additions. --- .claude/skills/designing-experiments/SKILL.md | 31 ++ .claude/skills/loading-datasets/SKILL.md | 30 ++ .../performing-causal-analysis/SKILL.md | 28 ++ .../reference/diff_in_diff.md | 50 ++ .../reference/interrupted_time_series.md | 51 ++ .../reference/synthetic_control.md | 49 ++ .../skills/running-placebo-analysis/SKILL.md | 25 + .../reference/placebo_in_time.md | 292 +++++++++++ .claude/skills/working-with-marimo/SKILL.md | 34 ++ .../reference/best_practices.md | 37 ++ .cursor/commands/causalpy_demos.md | 49 ++ .cursor/commands/causalpy_estimators.md | 40 ++ .cursor/commands/causalpy_extras.md | 472 ++++++++++++++++++ .cursor/commands/causalpy_marimo.md | 308 ++++++++++++ .cursor/commands/causalpy_methods.md | 189 +++++++ .cursor/commands/causalpy_research.md | 47 ++ .cursor/commands/commit.md | 40 ++ .cursor/commands/implement.md | 6 + .cursor/commands/make_plan.md | 12 + .cursor/commands/research.md | 34 ++ .cursor/rules/basic.mdc | 8 + .gitignore | 3 +- environment.yml | 1 + pyproject.toml | 3 + 24 files changed, 1838 insertions(+), 1 deletion(-) create mode 100644 .claude/skills/designing-experiments/SKILL.md create mode 100644 .claude/skills/loading-datasets/SKILL.md create mode 100644 .claude/skills/performing-causal-analysis/SKILL.md create mode 100644 .claude/skills/performing-causal-analysis/reference/diff_in_diff.md create mode 100644 .claude/skills/performing-causal-analysis/reference/interrupted_time_series.md create mode 100644 .claude/skills/performing-causal-analysis/reference/synthetic_control.md create mode 100644 .claude/skills/running-placebo-analysis/SKILL.md create mode 100644 .claude/skills/running-placebo-analysis/reference/placebo_in_time.md create mode 100644 .claude/skills/working-with-marimo/SKILL.md create mode 100644 .claude/skills/working-with-marimo/reference/best_practices.md create mode 100644 .cursor/commands/causalpy_demos.md create mode 100644 .cursor/commands/causalpy_estimators.md create mode 100644 .cursor/commands/causalpy_extras.md create mode 100644 .cursor/commands/causalpy_marimo.md create mode 100644 .cursor/commands/causalpy_methods.md create mode 100644 .cursor/commands/causalpy_research.md create mode 100644 .cursor/commands/commit.md create mode 100644 .cursor/commands/implement.md create mode 100644 .cursor/commands/make_plan.md create mode 100644 .cursor/commands/research.md create mode 100644 .cursor/rules/basic.mdc diff --git a/.claude/skills/designing-experiments/SKILL.md b/.claude/skills/designing-experiments/SKILL.md new file mode 100644 index 00000000..4d69ae1f --- /dev/null +++ b/.claude/skills/designing-experiments/SKILL.md @@ -0,0 +1,31 @@ +--- +name: designing-experiments +description: Selects the appropriate quasi-experimental method (DiD, ITS, SC) based on data structure and research questions. Use when the user is unsure which method to apply. +--- + +# Designing Experiments + +Helps select the appropriate causal inference method. + +## Decision Framework + +1. **Control Group?** + * **Yes**: Go to Step 2. + * **No**: Consider **Interrupted Time Series (ITS)**. + +2. **Unit Structure?** + * **Single Treated Unit**: + * With multiple controls: **Synthetic Control (SC)**. + * No controls: **ITS**. + * **Multiple Treated Units**: + * With control group: **Difference-in-Differences (DiD)**. + +3. **Time Structure?** + * **Panel Data** (Multiple units over time): Required for DiD and SC. + * **Time Series** (Single unit over time): Required for ITS. + +## Method Quick Reference + +* **Difference-in-Differences (DiD)**: Compares trend changes between treated and control groups. Assumes **Parallel Trends**. +* **Interrupted Time Series (ITS)**: Analyzes trend/level change for a single unit after intervention. Assumes **Trend Continuity**. +* **Synthetic Control (SC)**: Constructs a synthetic counterfactual from weighted control units. Assumes **Convex Hull** (treated unit within range of controls). diff --git a/.claude/skills/loading-datasets/SKILL.md b/.claude/skills/loading-datasets/SKILL.md new file mode 100644 index 00000000..02eb33d7 --- /dev/null +++ b/.claude/skills/loading-datasets/SKILL.md @@ -0,0 +1,30 @@ +--- +name: loading-datasets +description: Loads internal CausalPy example datasets. Use when the user needs example data or asks about available demos. +--- + +# Loading Datasets + +Loads example datasets provided with CausalPy. + +## Usage + +```python +import causalpy as cp +df = cp.load_data("dataset_name") +``` + +## Available Datasets + +| Key | Description | +| :--- | :--- | +| `did` | Generic Difference-in-Differences | +| `its` | Generic Interrupted Time Series | +| `sc` | Generic Synthetic Control | +| `banks` | DiD (Banks) | +| `brexit` | Synthetic Control (Brexit) | +| `covid` | ITS (Covid) | +| `drinking` | Regression Discontinuity (Drinking Age) | +| `rd` | Generic Regression Discontinuity | +| `geolift1` | GeoLift (Single cell) | +| `geolift_multi_cell` | GeoLift (Multi cell) | diff --git a/.claude/skills/performing-causal-analysis/SKILL.md b/.claude/skills/performing-causal-analysis/SKILL.md new file mode 100644 index 00000000..c81c776b --- /dev/null +++ b/.claude/skills/performing-causal-analysis/SKILL.md @@ -0,0 +1,28 @@ +--- +name: performing-causal-analysis +description: Fits causal models, estimates impacts, and plots results using CausalPy. Use when performing analysis with DiD, ITS, SC, or RD. +--- + +# Performing Causal Analysis + +Executes causal analysis using CausalPy experiment classes. + +## Workflow + +1. **Load Data**: Ensure data is in a Pandas DataFrame. +2. **Initialize Experiment**: Use the appropriate class (see References). +3. **Fit & Model**: Models are fitted automatically upon initialization if arguments are provided. +4. **Analyze Results**: Use `summary()`, `print_coefficients()`, and `plot()`. + +## Core Methods + +* `experiment.summary()`: Prints model summary and main results. +* `experiment.plot()`: Visualizes observed vs. counterfactual. +* `experiment.print_coefficients()`: Shows model coefficients. + +## References + +Detailed usage for specific methods: +* [Difference-in-Differences](reference/diff_in_diff.md) +* [Interrupted Time Series](reference/interrupted_time_series.md) +* [Synthetic Control](reference/synthetic_control.md) diff --git a/.claude/skills/performing-causal-analysis/reference/diff_in_diff.md b/.claude/skills/performing-causal-analysis/reference/diff_in_diff.md new file mode 100644 index 00000000..a5721355 --- /dev/null +++ b/.claude/skills/performing-causal-analysis/reference/diff_in_diff.md @@ -0,0 +1,50 @@ +# Causal Difference-in-Differences (DiD) + +Difference-in-Differences (DiD) estimates the causal effect of a treatment by comparing the changes in outcomes over time between a treatment group and a control group. + +## Class: `DifferenceInDifferences` + +```python +causalpy.experiments.DifferenceInDifferences( + data, + formula, + time_variable_name, + group_variable_name, + post_treatment_variable_name="post_treatment", + model=None, + **kwargs +) +``` + +### Parameters +* **`data`** (`pd.DataFrame`): Input dataframe containing panel data. +* **`formula`** (`str`): Statistical formula (e.g., `"y ~ 1 + group * post_treatment"`). +* **`time_variable_name`** (`str`): Column name for the time variable. +* **`group_variable_name`** (`str`): Column name for the group indicator (0=Control, 1=Treated). **Must be dummy coded**. +* **`post_treatment_variable_name`** (`str`): Column name indicating the post-treatment period (0=Pre, 1=Post). Default is `"post_treatment"`. +* **`model`**: A PyMC model (e.g., `cp.pymc_models.LinearRegression`) or a Scikit-Learn Regressor. + +### How it Works +1. **Fit**: The model fits all available data (pre/post, treatment/control). +2. **Counterfactual**: Predicted by setting the interaction term between `group` and `post_treatment` to 0. +3. **Impact**: The causal impact is the difference between observed and counterfactual. + +### Example + +```python +import causalpy as cp +import causalpy.pymc_models as cp_pymc + +df = cp.load_data("did") + +result = cp.DifferenceInDifferences( + df, + formula="y ~ 1 + group*post_treatment", + time_variable_name="t", + group_variable_name="group", + model=cp_pymc.LinearRegression(sample_kwargs={"target_accept": 0.9}) +) + +result.summary() +result.plot() +``` diff --git a/.claude/skills/performing-causal-analysis/reference/interrupted_time_series.md b/.claude/skills/performing-causal-analysis/reference/interrupted_time_series.md new file mode 100644 index 00000000..b2d95133 --- /dev/null +++ b/.claude/skills/performing-causal-analysis/reference/interrupted_time_series.md @@ -0,0 +1,51 @@ +# Causal Interrupted Time Series (ITS) + +Interrupted Time Series (ITS) analyzes the effect of an intervention on a single time series by comparing the trend before and after the intervention. + +## Class: `InterruptedTimeSeries` + +```python +causalpy.experiments.InterruptedTimeSeries( + data, + treatment_time, + formula, + model=None, + **kwargs +) +``` + +### Parameters +* **`data`** (`pd.DataFrame`): Input dataframe. Index should ideally be a `pd.DatetimeIndex`. +* **`treatment_time`** (`Union[int, float, pd.Timestamp]`): The point in time when the intervention occurred. +* **`formula`** (`str`): Statistical formula (e.g., `"y ~ 1 + t + C(month)"`). +* **`model`**: A PyMC model (e.g., `cp.pymc_models.LinearRegression`) or a Scikit-Learn Regressor. + +### How it Works +1. **Split**: Data is split into pre- and post-intervention. +2. **Fit**: Model is trained **only on pre-intervention data**. +3. **Predict**: Fitted model predicts the outcome for the post-intervention period. +4. **Impact**: Difference between observed post-intervention data and counterfactual predictions. + +### Example + +```python +import causalpy as cp +import causalpy.pymc_models as cp_pymc +import pandas as pd + +df = cp.load_data("its") +df["date"] = pd.to_datetime(df["date"]) +df.set_index("date", inplace=True) + +treatment_time = pd.to_datetime("2017-01-01") + +result = cp.InterruptedTimeSeries( + df, + treatment_time, + formula="y ~ 1 + t + C(month)", + model=cp_pymc.LinearRegression() +) + +result.summary() +result.plot() +``` diff --git a/.claude/skills/performing-causal-analysis/reference/synthetic_control.md b/.claude/skills/performing-causal-analysis/reference/synthetic_control.md new file mode 100644 index 00000000..daf819c4 --- /dev/null +++ b/.claude/skills/performing-causal-analysis/reference/synthetic_control.md @@ -0,0 +1,49 @@ +# Causal Synthetic Control (SCG) + +Synthetic Control constructs a "synthetic" counterfactual unit using a weighted combination of untreated control units. + +## Class: `SyntheticControl` + +```python +causalpy.experiments.SyntheticControl( + data, + treatment_time, + control_units, + treated_units, + model=None, + **kwargs +) +``` + +### Parameters +* **`data`** (`pd.DataFrame`): Input dataframe containing panel data. +* **`treatment_time`** (`Union[int, float, pd.Timestamp]`): The time of intervention. +* **`control_units`** (`List[str]`): List of column names representing the control units. +* **`treated_units`** (`List[str]`): List of column names representing the treated unit(s). +* **`model`**: A PyMC model (typically `cp.pymc_models.WeightedSumFitter`) or a Scikit-Learn Regressor. + +### How it Works +1. **Fit**: Model learns weights for `control_units` to approximate `treated_units` using **only pre-intervention data**. +2. **Predict**: Weights are applied to `control_units` in post-intervention period. +3. **Impact**: Difference between observed treated unit and synthetic counterfactual. + +### Example + +```python +import causalpy as cp +import causalpy.pymc_models as cp_pymc + +df = cp.load_data("sc") +treatment_time = 70 + +result = cp.SyntheticControl( + df, + treatment_time, + control_units=["a", "b", "c", "d", "e"], + treated_units=["actual"], + model=cp_pymc.WeightedSumFitter() +) + +result.summary() +result.plot() +``` diff --git a/.claude/skills/running-placebo-analysis/SKILL.md b/.claude/skills/running-placebo-analysis/SKILL.md new file mode 100644 index 00000000..0bc1a23d --- /dev/null +++ b/.claude/skills/running-placebo-analysis/SKILL.md @@ -0,0 +1,25 @@ +--- +name: running-placebo-analysis +description: Performs placebo-in-time sensitivity analysis to validate causal claims. Use when checking model robustness, verifying lack of pre-intervention effects, or ensuring observed effects are not spurious. +--- + +# Running Placebo Analysis + +Executes placebo-in-time sensitivity analysis to validate causal experiments. + +## Workflow + +1. **Define Experiment Factory**: Create a function that returns a fitted CausalPy experiment (e.g., ITS, DiD, SC) given a dataset and time boundaries. +2. **Configure Analysis**: Initialize `PlaceboAnalysis` with the factory, dataset, intervention dates, and number of folds (cuts). +3. **Run Analysis**: Execute `.run()` to fit models on pre-intervention data folds. +4. **Evaluate Results**: Compare placebo effects (which should be null) to the actual intervention effect. Use histograms and hierarchical models to quantify the "status quo" distribution. + +## Key Concepts + +* **Placebo-in-time**: Simulating an intervention at a time when none occurred to check if the model falsely detects an effect. +* **Fold**: A slice of pre-intervention data used to test a placebo period. +* **Factory Pattern**: Decouples the placebo logic from the specific CausalPy experiment type. + +## References + +* [Placebo-in-time Implementation](reference/placebo_in_time.md): Full code for the `PlaceboAnalysis` class, usage examples, and hierarchical status-quo modeling. diff --git a/.claude/skills/running-placebo-analysis/reference/placebo_in_time.md b/.claude/skills/running-placebo-analysis/reference/placebo_in_time.md new file mode 100644 index 00000000..faf88e78 --- /dev/null +++ b/.claude/skills/running-placebo-analysis/reference/placebo_in_time.md @@ -0,0 +1,292 @@ +# Placebo-in-time Analysis + +## Overview + +The `PlaceboAnalysis` class implements a placebo-in-time sensitivity analysis for causal inference experiments. This technique helps validate causal claims by testing whether the intervention effect appears in periods where no intervention actually occurred. + +## When to Use + +Use `PlaceboAnalysis` when you want to: + +1. **Validate causal claims**: Test if your model would detect spurious effects in pre-intervention periods where no treatment occurred +2. **Check model specification**: Verify that your model isn't picking up pre-existing trends or patterns that could be mistaken for treatment effects +3. **Assess robustness**: Demonstrate that the observed effect is specific to the actual intervention period and not a general pattern in the data +4. **Strengthen inference**: Provide additional evidence that the treatment effect is real by showing no effects in placebo periods + +## Implementation + +Since this class is not yet part of the core library, you must define it in your code: + +```python +from typing import Any, Callable +from pydantic import BaseModel +import logging +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt +import xarray as xr +import causalpy as cp +import pymc as pm +import arviz as az + +logger = logging.getLogger(__name__) + +class PlaceboAnalysis(BaseModel): + """ + Run sensitivity analysis for any causalpy experiment using a factory pattern. + + The factory function allows complete flexibility in choosing and configuring + any causalpy experiment type (SyntheticControl, InterruptedTimeSeries, + DifferenceInDifferences, RegressionDiscontinuity, etc.). + + Parameters + ---------- + experiment_factory : Callable + A function that creates and returns a fitted causalpy experiment. + Signature: (dataset: pd.DataFrame, treatment_time: pd.Timestamp, + treatment_time_end: pd.Timestamp) -> causalpy result + dataset : pd.DataFrame + The full dataset with datetime index + intervention_start_date : str + Start date of the intervention period (YYYY-MM-DD format) + intervention_end_date : str + End date of the intervention period (YYYY-MM-DD format) + n_cuts : int + Number of cuts for cross-validation (n_cuts - 1 folds will be created) + """ + + model_config = {"arbitrary_types_allowed": True} + + experiment_factory: Callable + dataset: pd.DataFrame + intervention_start_date: str + intervention_end_date: str + n_cuts: int = 2 + + def _validate_cuts(self, n_cuts: int) -> None: + """Validate that n_cuts is at least 2.""" + if n_cuts < 2: + raise ValueError("n_cuts must be >= 2 (n_cuts - 1 folds will be created).") + + def _prepare_pre_data(self, treatment_time: pd.Timestamp) -> pd.DataFrame: + """Extract pre-intervention data.""" + pre_df = self.dataset.loc[self.dataset.index < treatment_time].copy() + if pre_df.empty: + raise ValueError("No observations strictly before treatment_time in dataset.") + return pre_df + + def _calculate_intervention_length( + self, treatment_time: pd.Timestamp, treatment_time_end: pd.Timestamp + ) -> pd.Timedelta: + """Calculate the length of the intervention period.""" + treatment_time = pd.Timestamp(treatment_time) + treatment_time_end = pd.Timestamp(treatment_time_end) + intervention_length = treatment_time_end - treatment_time + if intervention_length <= pd.Timedelta(0): + raise ValueError("treatment_time_end must be after treatment_time to compute a positive intervention length.") + return intervention_length + + def _validate_sufficient_data( + self, + pre_df: pd.DataFrame, + treatment_time: pd.Timestamp, + intervention_length: pd.Timedelta, + n_cuts: int, + ) -> None: + """Validate that there's sufficient pre-intervention data for the requested folds.""" + pre_start = pre_df.index.min() + earliest_needed = treatment_time - (n_cuts - 1) * intervention_length + if pre_start > earliest_needed: + max_cuts = 1 + int((treatment_time - pre_start) // intervention_length) + raise ValueError( + "Not enough pre-period for requested folds. " + f"Earliest required: {earliest_needed.date()}, available starts: {pre_start.date()}. " + f"Try n_cuts <= {max_cuts}." + ) + + def _create_fold_data( + self, + pre_df: pd.DataFrame, + fold: int, + n_cuts: int, + treatment_time: pd.Timestamp, + intervention_length: pd.Timedelta, + ) -> tuple[pd.DataFrame, pd.Timestamp, pd.Timestamp]: + """Create data for a specific fold.""" + pseudo_start = treatment_time - (n_cuts - fold) * intervention_length + pseudo_end = pseudo_start + intervention_length + fold_df = pre_df.loc[pre_df.index < pseudo_end].sort_index() + + pre_mask = fold_df.index < pseudo_start + post_mask = (fold_df.index >= pseudo_start) & (fold_df.index < pseudo_end) + + if pre_mask.sum() == 0 or post_mask.sum() == 0: + raise ValueError( + f"Fold {fold}: insufficient data. pre_n={pre_mask.sum()}, post_n={post_mask.sum()} " + f"for window [{pseudo_start} .. {pseudo_end})." + ) + + return fold_df, pseudo_start, pseudo_end + + def _fit_model( + self, fold_df: pd.DataFrame, pseudo_start: pd.Timestamp, pseudo_end: pd.Timestamp + ) -> Any: + """ + Fit the experiment using the provided factory function. + """ + logger.info(f"Fitting model for fold with treatment_time={pseudo_start}, treatment_time_end={pseudo_end}") + return self.experiment_factory(fold_df, pseudo_start, pseudo_end) + + def run(self) -> list[dict[str, Any]]: + """ + Run the sensitivity analysis across all folds. + """ + n_cuts = self.n_cuts + treatment_time = pd.Timestamp(self.intervention_start_date) + treatment_time_end = pd.Timestamp(self.intervention_end_date) + + self._validate_cuts(n_cuts) + pre_df = self._prepare_pre_data(treatment_time) + intervention_length = self._calculate_intervention_length(treatment_time, treatment_time_end) + self._validate_sufficient_data(pre_df, treatment_time, intervention_length, n_cuts) + + results: list[dict[str, Any]] = [] + for fold in range(1, n_cuts): + fold_df, pseudo_start, pseudo_end = self._create_fold_data( + pre_df, fold, n_cuts, treatment_time, intervention_length + ) + + model_result = self._fit_model(fold_df, pseudo_start, pseudo_end) + + results.append( + { + "fold": fold, + "pseudo_start": pseudo_start, + "pseudo_end": pseudo_end, + "result": model_result, + } + ) + + return results +``` + +## Example Usage + +```python +# 1. Define a factory function +def its_factory(dataset, treatment_time, treatment_time_end): + formula = "target ~ 1 + feature" + return cp.InterruptedTimeSeries( + dataset, + treatment_time, + formula=formula, + model=cp.pymc_models.LinearRegression(sample_kwargs={"random_seed": 42}) + ) + +# 2. Run sensitivity analysis +sensitivity = PlaceboAnalysis( + experiment_factory=its_factory, + dataset=df.set_index("date"), + intervention_start_date="2024-01-01", + intervention_end_date="2024-01-30", + n_cuts=4 +) +results = sensitivity.run() + +# 3. Plot results +for r in results: + r["result"].plot() +plt.show() +``` + +## Visualization: Posterior Cumulative Distribution + +```python +# Extract and stack post-impact samples +sensitivity_post_impact = xr.concat( + [ + r["result"] + .post_impact.sum("obs_ind") # sum over days in pseudo window + .isel(treated_units=0) + .stack(sample=("chain", "draw")) + for r in results + ], + dim="fold", +) + +# Convert sample coordinate for plotting +sensitivity_post_impact_numeric = sensitivity_post_impact.assign_coords( + sample=np.arange(len(sensitivity_post_impact.sample)) +) + +# Plot histograms +n_folds = sensitivity_post_impact.sizes["fold"] +fold_means = [sensitivity_post_impact_numeric.isel(fold=i).mean().item() for i in range(n_folds)] + +fig, ax = plt.subplots(1, 1, figsize=(8, 6)) +for i in range(n_folds): + fold_data = sensitivity_post_impact_numeric.isel(fold=i) + fold_data.plot.hist(ax=ax, alpha=0.7, label=f'Fold {i} (mean: {fold_means[i]:.1f})') +ax.legend() +plt.show() +``` + +## Advanced: Hierarchical Status Quo Model + +Build a hierarchical Bayesian model to characterize the "status quo" distribution of placebo effects. + +```python +# 1. Extract summaries +n_folds = sensitivity_post_impact.sizes["fold"] +n_samples = sensitivity_post_impact.sizes["sample"] +fold_means = sensitivity_post_impact.mean(dim="sample").values +fold_sds = sensitivity_post_impact.std(dim="sample").values +fold_sds = np.where(fold_sds < 1e-6, 1e-6, fold_sds) # Guard against degenerate SDs + +coords_meta = {"fold": np.arange(n_folds)} +prior_mu_center = float(np.nanmean(fold_means)) +prior_mu_scale = float(np.nanstd(fold_means)) if np.nanstd(fold_means) > 0.0 else 1.0 + +# 2. Define and fit model +n_chains = 4 +draws_per_chain_meta = n_samples // n_chains + +with pm.Model(coords=coords_meta) as meta_status_quo_model: + observed_fold_means = pm.Data("observed_fold_means", fold_means, dims="fold") + observed_fold_sd = pm.Data("observed_fold_sd", fold_sds, dims="fold") + + mu_status_quo = pm.Normal("mu_status_quo", mu=prior_mu_center, sigma=5.0 * prior_mu_scale) + tau_status_quo = pm.HalfNormal("tau_status_quo", sigma=2.0 * prior_mu_scale) + fold_standard_normal = pm.Normal("fold_standard_normal", mu=0.0, sigma=1.0, dims="fold") + + fold_true_total_effect = pm.Deterministic( + "fold_true_total_effect", + mu_status_quo + tau_status_quo * fold_standard_normal, + dims="fold", + ) + + likelihood_fold_means = pm.Normal( + "likelihood_fold_means", + mu=fold_true_total_effect, + sigma=observed_fold_sd, + observed=observed_fold_means, + dims="fold", + ) + + idata_meta_status_quo = pm.sample( + draws=draws_per_chain_meta, + chains=n_chains, + target_accept=0.97, + ) + +# 3. Posterior predictive for new period +with meta_status_quo_model: + meta_status_quo_model.add_coords({"new_period": np.arange(1)}) + theta_new = pm.Normal("theta_new", mu=mu_status_quo, sigma=tau_status_quo, dims="new_period") + posterior_predictive_status_quo = pm.sample_posterior_predictive(idata_meta_status_quo, var_names=["theta_new"]) + +# Plot result +theta_new_samples = posterior_predictive_status_quo["posterior_predictive"]["theta_new"].stack(sample=("chain", "draw")).values.squeeze() +plt.hist(theta_new_samples, bins=40, density=True, alpha=0.6, label="θ_new ~ N(μ, τ)") +plt.show() +``` diff --git a/.claude/skills/working-with-marimo/SKILL.md b/.claude/skills/working-with-marimo/SKILL.md new file mode 100644 index 00000000..8ff6c8a7 --- /dev/null +++ b/.claude/skills/working-with-marimo/SKILL.md @@ -0,0 +1,34 @@ +--- +name: working-with-marimo +description: Interactive development in marimo notebooks with validation loops. Use for creating/editing marimo notebooks and verifying execution. +--- + +# Working with Marimo + +Follows a **Plan-Execute-Verify** loop to ensure notebook correctness. + +## Feedback Loop + +1. **Context & Plan**: + * **Sessions**: `mcp_marimo_get_active_notebooks` (Find session IDs). + * **Structure**: `mcp_marimo_get_lightweight_cell_map` (See cell IDs/content). + * **Data State**: `mcp_marimo_get_tables_and_variables` (Inspect DataFrames/Variables). + * **Cell Detail**: `mcp_marimo_get_cell_runtime_data` (Code, errors, local vars). + +2. **Execute**: + * Edit the `.py` file directly using `write` or `search_replace`. + * **Rule**: Follow [Best Practices](reference/best_practices.md) (e.g., `@app.cell`, no global state). + +3. **Verify (CRITICAL)**: + * **Lint**: `mcp_marimo_lint_notebook` (Static analysis). + * **Runtime Errors**: `mcp_marimo_get_notebook_errors` (Execution errors). + * **Outputs**: `mcp_marimo_get_cell_outputs` (Visuals/Console). + +## Common Commands + +* **Start/Sync**: Marimo automatically syncs file changes. +* **SQL**: Use `mo.sql` for DuckDB queries. +* **Plots**: Use `plt.gca()` or return figure. **No `plt.show()`**. + +## Reference +See [Best Practices](reference/best_practices.md) for code formatting, reactivity rules, and UI element usage. diff --git a/.claude/skills/working-with-marimo/reference/best_practices.md b/.claude/skills/working-with-marimo/reference/best_practices.md new file mode 100644 index 00000000..34cecb4f --- /dev/null +++ b/.claude/skills/working-with-marimo/reference/best_practices.md @@ -0,0 +1,37 @@ +# Marimo Best Practices + +## CausalPy Specifics +* **Data**: Use **Pandas** (standard for CausalPy), even though generic marimo docs suggest Polars. +* **Plotting**: Use **Matplotlib**, **Seaborn**, or **Arviz**. **Avoid Altair**. +* **Display**: Return the figure/axis object or use `plt.gca()` as the last expression. **Do NOT use `plt.show()`**. + +## Code Structure +* **Decorators**: Every cell must start with `@app.cell` and define a function `def _():`. +* **Imports**: Put all imports in one cell (usually the first). Always import `marimo as mo`. +* **No Globals**: Variables are local to cells unless returned; marimo handles state passing. +* **Reactivity**: Cells run automatically when inputs change. **Avoid cycles**. + +## Visualizations & Outputs +* **Separation**: Do NOT mix `mo.md` and plots in the same cell. Create a markdown cell, then a plot cell. +* **Last Expression**: The last line of a cell is automatically displayed. + +## Data & SQL +* **DuckDB**: Use `df = mo.sql(f"""SELECT ...""")`. +* **Comments**: Do NOT put comments inside `mo.sql` strings or Markdown cells. + +## UI Elements +* **Access**: Use `.value` (e.g., `slider.value`). +* **Definition**: Define UI element in one cell, access value in another to avoid cycles. + +## Example Cell +```python +@app.cell +def _(): + import marimo as mo + import matplotlib.pyplot as plt + + fig, ax = plt.subplots() + ax.plot([1, 2, 3]) + ax # Display as last expression + return fig, ax, mo +``` diff --git a/.cursor/commands/causalpy_demos.md b/.cursor/commands/causalpy_demos.md new file mode 100644 index 00000000..dc36d787 --- /dev/null +++ b/.cursor/commands/causalpy_demos.md @@ -0,0 +1,49 @@ +# Causal Demos + +This skill handles the retrieval and loading of example datasets provided within the `CausalPy` library. + +## Loading Data + +To load internal datasets, use the `load_data` function from the `causalpy` module. + +```python +import causalpy as cp + +# Load a specific dataset +df = cp.load_data("dataset_name") +``` + +## Available Datasets + +The following datasets are available for demonstration purposes. Use the key (left) to load the corresponding file. + +| Key | Description / Context | +| :--- | :--- | +| `did` | Generic Difference-in-Differences example | +| `its` | Generic Interrupted Time Series example | +| `sc` | Generic Synthetic Control example | +| `banks` | DiD example (Banks) | +| `brexit` | Synthetic Control example (GDP impact of Brexit) | +| `covid` | ITS example (Deaths and temps in England/Wales) | +| `drinking` | Regression Discontinuity example (Minimum Legal Drinking Age) | +| `rd` | Generic Regression Discontinuity example | +| `its simple` | Simple ITS example | +| `anova1` | Generated ANCOVA example | +| `geolift1` | GeoLift example (Single cell) | +| `geolift_multi_cell` | GeoLift example (Multi cell) | +| `risk` | AJR2001 dataset | +| `nhefs` | NHEFS dataset | +| `schoolReturns` | Schooling Returns dataset | +| `pisa18` | PISA 2018 Sample Scale | +| `nets` | Nets DataFrame | +| `lalonde` | Lalonde dataset | + +## Usage Example + +```python +import causalpy as cp + +# Load the 'did' dataset for a Difference-in-Differences demo +df = cp.load_data("did") +print(df.head()) +``` diff --git a/.cursor/commands/causalpy_estimators.md b/.cursor/commands/causalpy_estimators.md new file mode 100644 index 00000000..2f8a3c30 --- /dev/null +++ b/.cursor/commands/causalpy_estimators.md @@ -0,0 +1,40 @@ +# Causal Estimators + +This skill covers the functions and methods used to estimate and summarize causal effects after fitting a model. + +## Core Estimator Methods + +Most experiment classes in CausalPy (like `DifferenceInDifferences`, `InterruptedTimeSeries`, `SyntheticControl`) provide standard methods for retrieving impact estimates. + +### `summary(round_to=2)` +Prints a summary of the main results and model coefficients. +* `round_to`: Number of decimal places to round results to. + +### `print_coefficients(round_to=2)` +Prints the coefficients of the underlying model (e.g., regression coefficients). + +## Calculation Functions + +The `BaseExperiment` and specific experiment classes use these internal calculations: + +### `calculate_impact(y, y_hat)` +Calculates the causal impact (lift) by subtracting the counterfactual prediction (`y_hat`) from the observed outcome (`y`). +* **Formula**: `impact = y - y_hat` +* **Bayesian Models**: Returns a posterior distribution of the impact. +* **OLS Models**: Returns point estimates. + +### `calculate_cumulative_impact(impact)` +Calculates the cumulative sum of the impact over time. Useful for understanding the total effect of an intervention over a period. + +## Plotting Results + +Standard plotting methods are available on the experiment objects: + +* `plot()`: Dispatches to `_bayesian_plot` or `_ols_plot` depending on the model type. +* Returns a `matplotlib` figure and axes. + +## Scikit-Learn Compatibility + +CausalPy allows using `scikit-learn` estimators via the `SkLearnAdaptor` (or similar wrappers implied by `RegressorMixin` usage). +* **Pre-Post Fit**: Fits on pre-data, predicts on post-data. +* **Coefficients**: Standard sklearn `coef_` and `intercept_` are accessed. diff --git a/.cursor/commands/causalpy_extras.md b/.cursor/commands/causalpy_extras.md new file mode 100644 index 00000000..d4e89172 --- /dev/null +++ b/.cursor/commands/causalpy_extras.md @@ -0,0 +1,472 @@ +# Causal Extras + +This file is currently empty and reserved for additional custom functions, plotting utilities, or extended estimators that are not part of the core CausalPy library but are useful for analysis. + +## Placebo-in-time + +### Overview + +The `PlaceboAnalysis` class implements a placebo-in-time sensitivity analysis for causal inference experiments. This technique helps validate causal claims by testing whether the intervention effect appears in periods where no intervention actually occurred. + +### When to Use + +Use `PlaceboAnalysis` when you want to: + +1. **Validate causal claims**: Test if your model would detect spurious effects in pre-intervention periods where no treatment occurred +2. **Check model specification**: Verify that your model isn't picking up pre-existing trends or patterns that could be mistaken for treatment effects +3. **Assess robustness**: Demonstrate that the observed effect is specific to the actual intervention period and not a general pattern in the data +4. **Strengthen inference**: Provide additional evidence that the treatment effect is real by showing no effects in placebo periods + +### How It Works + +The analysis works by: +1. Taking only the pre-intervention data +2. Creating multiple "folds" by simulating fake intervention periods at different points in time +3. Running your causal experiment on each fold as if the fake intervention were real +4. Comparing the placebo effects to the actual treatment effect + +If placebo effects are consistently smaller or non-significant compared to the actual effect, this strengthens your causal claim. If placebo effects are similar to the actual effect, this suggests your model may be detecting spurious patterns rather than true causal effects. + +### Key Features + +- **Model-agnostic**: Works with any CausalPy experiment type (InterruptedTimeSeries, SyntheticControl, DifferenceInDifferences, etc.) +- **Factory pattern**: Accepts a factory function that creates your experiment, allowing full control over model configuration +- **Flexible**: Supports both PyMC and scikit-learn models through the factory function +- **Automated**: Handles data splitting, fold creation, and model fitting automatically + + +```python + +from typing import Any, Callable +from pydantic import BaseModel +import logging + +logger = logging.getLogger(__name__) + +class PlaceboAnalysis(BaseModel): + """ + Run sensitivity analysis for any causalpy experiment using a factory pattern. + + The factory function allows complete flexibility in choosing and configuring + any causalpy experiment type (SyntheticControl, InterruptedTimeSeries, + DifferenceInDifferences, RegressionDiscontinuity, etc.). + + Parameters + ---------- + experiment_factory : Callable + A function that creates and returns a fitted causalpy experiment. + Signature: (dataset: pd.DataFrame, treatment_time: pd.Timestamp, + treatment_time_end: pd.Timestamp) -> causalpy result + dataset : pd.DataFrame + The full dataset with datetime index + intervention_start_date : str + Start date of the intervention period (YYYY-MM-DD format) + intervention_end_date : str + End date of the intervention period (YYYY-MM-DD format) + n_cuts : int + Number of cuts for cross-validation (n_cuts - 1 folds will be created) + + Examples + -------- + # InterruptedTimeSeries factory + >>> def its_factory(dataset, treatment_time, treatment_time_end): + ... formula = "tucupita ~ 1 + krasnoyarsk + shared_trend" + ... return cp.InterruptedTimeSeries( + ... dataset, + ... treatment_time, + ... formula=formula, + ... model=cp.pymc_models.LinearRegression(sample_kwargs={"random_seed": 42}) + ... ) + + # SyntheticControl factory + >>> def sc_factory(dataset, treatment_time, treatment_time_end): + ... return cp.SyntheticControl( + ... dataset, + ... treatment_time, + ... control_units=["krasnoyarsk"], + ... treated_units=["tucupita"], + ... model=cp.pymc_models.WeightedSumFitter(sample_kwargs={"random_seed": 42}) + ... ) + + # Run sensitivity analysis + >>> sensitivity = PlaceboAnalysis( + ... experiment_factory=its_factory, + ... dataset=df.set_index("date"), + ... intervention_start_date="2024-01-01", + ... intervention_end_date="2024-01-30", + ... n_cuts=3 + ... ) + >>> results = sensitivity.run() + """ + + model_config = {"arbitrary_types_allowed": True} + + experiment_factory: Callable + dataset: pd.DataFrame + intervention_start_date: str + intervention_end_date: str + n_cuts: int = 2 + + def _validate_cuts(self, n_cuts: int) -> None: + """Validate that n_cuts is at least 2.""" + if n_cuts < 2: + raise ValueError("n_cuts must be >= 2 (n_cuts - 1 folds will be created).") + + def _prepare_pre_data(self, treatment_time: pd.Timestamp) -> pd.DataFrame: + """Extract pre-intervention data.""" + pre_df = self.dataset.loc[self.dataset.index < treatment_time].copy() + if pre_df.empty: + raise ValueError("No observations strictly before treatment_time in dataset.") + return pre_df + + def _calculate_intervention_length( + self, treatment_time: pd.Timestamp, treatment_time_end: pd.Timestamp + ) -> pd.Timedelta: + """Calculate the length of the intervention period.""" + treatment_time = pd.Timestamp(treatment_time) + treatment_time_end = pd.Timestamp(treatment_time_end) + intervention_length = treatment_time_end - treatment_time + if intervention_length <= pd.Timedelta(0): + raise ValueError("treatment_time_end must be after treatment_time to compute a positive intervention length.") + return intervention_length + + def _validate_sufficient_data( + self, + pre_df: pd.DataFrame, + treatment_time: pd.Timestamp, + intervention_length: pd.Timedelta, + n_cuts: int, + ) -> None: + """Validate that there's sufficient pre-intervention data for the requested folds.""" + pre_start = pre_df.index.min() + earliest_needed = treatment_time - (n_cuts - 1) * intervention_length + if pre_start > earliest_needed: + max_cuts = 1 + int((treatment_time - pre_start) // intervention_length) + raise ValueError( + "Not enough pre-period for requested folds. " + f"Earliest required: {earliest_needed.date()}, available starts: {pre_start.date()}. " + f"Try n_cuts <= {max_cuts}." + ) + + def _create_fold_data( + self, + pre_df: pd.DataFrame, + fold: int, + n_cuts: int, + treatment_time: pd.Timestamp, + intervention_length: pd.Timedelta, + ) -> tuple[pd.DataFrame, pd.Timestamp, pd.Timestamp]: + """Create data for a specific fold.""" + pseudo_start = treatment_time - (n_cuts - fold) * intervention_length + pseudo_end = pseudo_start + intervention_length + fold_df = pre_df.loc[pre_df.index < pseudo_end].sort_index() + + pre_mask = fold_df.index < pseudo_start + post_mask = (fold_df.index >= pseudo_start) & (fold_df.index < pseudo_end) + + if pre_mask.sum() == 0 or post_mask.sum() == 0: + raise ValueError( + f"Fold {fold}: insufficient data. pre_n={pre_mask.sum()}, post_n={post_mask.sum()} " + f"for window [{pseudo_start} .. {pseudo_end})." + ) + + return fold_df, pseudo_start, pseudo_end + + def _fit_model( + self, fold_df: pd.DataFrame, pseudo_start: pd.Timestamp, pseudo_end: pd.Timestamp + ) -> Any: + """ + Fit the experiment using the provided factory function. + + The factory receives the fold data and time boundaries, and returns + a fitted causalpy experiment result. + """ + logger.info(f"Fitting model for fold with treatment_time={pseudo_start}, treatment_time_end={pseudo_end}") + return self.experiment_factory(fold_df, pseudo_start, pseudo_end) + + def run(self) -> list[dict[str, Any]]: + """ + Run the sensitivity analysis across all folds. + + Returns + ------- + list[dict[str, Any]] + A list of dictionaries, one per fold, containing: + - fold: fold number + - pseudo_start: start of pseudo-intervention period + - pseudo_end: end of pseudo-intervention period + - result: the fitted causalpy experiment result + """ + n_cuts = self.n_cuts + treatment_time = pd.Timestamp(self.intervention_start_date) + treatment_time_end = pd.Timestamp(self.intervention_end_date) + + self._validate_cuts(n_cuts) + pre_df = self._prepare_pre_data(treatment_time) + intervention_length = self._calculate_intervention_length(treatment_time, treatment_time_end) + self._validate_sufficient_data(pre_df, treatment_time, intervention_length, n_cuts) + + results: list[dict[str, Any]] = [] + for fold in range(1, n_cuts): + fold_df, pseudo_start, pseudo_end = self._create_fold_data( + pre_df, fold, n_cuts, treatment_time, intervention_length + ) + + model_result = self._fit_model(fold_df, pseudo_start, pseudo_end) + + results.append( + { + "fold": fold, + "pseudo_start": pseudo_start, + "pseudo_end": pseudo_end, + "result": model_result, + } + ) + + return results +``` +Done. + +### Code example + +```python +# Define a factory function for InterruptedTimeSeries +def its_factory(dataset, treatment_time, treatment_time_end): + """ + Factory for InterruptedTimeSeries experiment. + + Note: treatment_time_end is available but not used by ITS. + It's provided for API consistency across all experiment types. + """ + formula = "tucupita ~ 1 + krasnoyarsk" + return cp.InterruptedTimeSeries( + dataset, + treatment_time, + formula=formula, + model=cp.pymc_models.LinearRegression(sample_kwargs={"random_seed": seed}) + ) + +# Run sensitivity analysis with ITS +sensitivity_its = PlaceboAnalysis( + experiment_factory=its_factory, + dataset=df.set_index("date"), + intervention_start_date=treatment_time.strftime("%Y-%m-%d"), + intervention_end_date=df["date"].iloc[-1].strftime("%Y-%m-%d"), + n_cuts=4 +) + +# Execute the sensitivity analysis +results_its = sensitivity_its.run() +``` + +### Visualize: in-time posterior variability +Here you will plot the results over each fold (placebo validation) + +```python +for r in results_its: + fig, ax = r["result"].plot(plot_predictors=True) +plt.show() +``` + +### Visualize: Posterior cumulative distribution per placebo-in-time + +```python +sensitivity_post_impact = xr.concat( + [ + r["result"] + .post_impact.sum("obs_ind") # sum over days in pseudo window + .isel(treated_units=0) + .stack(sample=("chain", "draw")) + for r in results_its + ], + dim="fold", +) + +# Summaries per fold +m_obs = sensitivity_post_impact.values # shape (fold, sample) +n_samples = sensitivity_post_impact.sizes["sample"] +n_folds = sensitivity_post_impact.sizes["fold"] + +# Convert sample coordinate to numeric for plotting +sensitivity_post_impact_numeric = sensitivity_post_impact.assign_coords( + sample=np.arange(len(sensitivity_post_impact.sample)) +) + +# Calculate and print means for each fold +fold_means = [] +for i in range(n_folds): + fold_mean = sensitivity_post_impact_numeric.isel(fold=i).mean().item() + fold_means.append(fold_mean) + # print(f'Fold {i} mean: {fold_mean:.3f}') + +# Plot histograms for all folds in the same plot +fig, ax = plt.subplots(1, 1, figsize=(8, 6)) + +for i in range(n_folds): + fold_data = sensitivity_post_impact_numeric.isel(fold=i) + fold_data.plot.hist(ax=ax, alpha=0.7, label=f'Fold {i} (mean: {fold_means[i]:.1f})') + +ax.set_title('Sensitivity Post Impact - All Folds') +ax.set_xlabel('Sensitivity Post Impact') +ax.legend() +plt.show() +``` + +### Status quo distribution model + +This section demonstrates how to build a hierarchical Bayesian model to characterize the "status quo" distribution of placebo effects. This is useful when you want to: + +1. **Quantify baseline variability**: Understand the natural variation in effects that occur even without a true intervention +2. **Create a reference distribution**: Establish a probabilistic baseline against which to compare your actual treatment effect +3. **Account for uncertainty**: Properly propagate uncertainty from individual placebo analyses into a meta-level distribution +4. **Make probabilistic statements**: Calculate probabilities like P(actual effect > status quo) to strengthen causal claims + +**Prerequisites:** +- You must have already run `PlaceboAnalysis` and stored results in a variable (e.g., `results_its`) +- You need `sensitivity_post_impact`, an xarray DataArray with dimensions `("fold", "sample")` containing posterior draws of cumulative effects for each placebo fold +- Required imports: `pymc as pm`, `arviz as az`, `numpy as np` + +**What the model does:** +The hierarchical model treats each placebo fold's posterior mean as a noisy measurement of a latent "true" placebo effect for that fold. It then estimates: +- `mu_status_quo`: The average effect under status quo (no real intervention) +- `tau_status_quo`: The variability in effects across different time periods under status quo +- `theta_new`: Predicted effect for a new hypothetical null period (posterior predictive) + +This allows you to compare your actual intervention effect against the full status quo distribution, not just individual placebo tests. + + +```python +# ------------------------------------------------------------------- +# 1. Extract per-fold summaries from sensitivity_post_impact +# sensitivity_post_impact: dims ("fold", "sample") +# Each entry is a posterior draw of total effect in that placebo window. +# ------------------------------------------------------------------- + +n_folds = sensitivity_post_impact.sizes["fold"] +n_samples = sensitivity_post_impact.sizes["sample"] + +# Per-fold posterior means and standard deviations of total effects +fold_means = sensitivity_post_impact.mean(dim="sample").values # shape (n_folds,) +fold_sds = sensitivity_post_impact.std(dim="sample").values # shape (n_folds,) + +# Guard against degenerate SDs +min_sd_value = 1e-6 +fold_sds = np.where(fold_sds < min_sd_value, min_sd_value, fold_sds) + +coords_meta = {"fold": np.arange(n_folds)} + +# Empirical-Bayes style hyperpriors (transparent: based on fold_means dispersion) +prior_mu_center = float(np.nanmean(fold_means)) +prior_mu_scale = float(np.nanstd(fold_means)) if np.nanstd(fold_means) > 0.0 else 1.0 + +# ------------------------------------------------------------------- +# 2. Hierarchical status-quo model over placebo totals +# +# fold_true_total_effect_j ~ Normal(mu_status_quo, tau_status_quo^2) +# observed_fold_means_j ~ Normal(fold_true_total_effect_j, fold_sds_j^2) +# +# Critically: we use fold_sds as measurement error (posterior uncertainty +# from the first-stage models), NOT divided by sqrt(ESS). This treats the +# placebo posteriors as noisy estimates of latent true effects. +# ------------------------------------------------------------------- + +n_chains = 4 +draws_per_chain_meta = n_samples // n_chains # keep similar total draws scale + +with pm.Model(coords=coords_meta) as meta_status_quo_model: + # Data containers + observed_fold_means = pm.Data( + "observed_fold_means", + fold_means, + dims="fold", + ) + observed_fold_sd = pm.Data( + "observed_fold_sd", + fold_sds, + dims="fold", + ) + + # Hyperpriors for status-quo distribution of true placebo totals + mu_status_quo = pm.Normal( + "mu_status_quo", + mu=prior_mu_center, + sigma=5.0 * prior_mu_scale, + ) + + tau_status_quo = pm.HalfNormal( + "tau_status_quo", + sigma=2.0 * prior_mu_scale, + ) + + # Random effects for each placebo fold + fold_standard_normal = pm.Normal( + "fold_standard_normal", + mu=0.0, + sigma=1.0, + dims="fold", + ) + + # Latent true total effect per fold (status-quo under no intended intervention) + fold_true_total_effect = pm.Deterministic( + "fold_true_total_effect", + mu_status_quo + tau_status_quo * fold_standard_normal, + dims="fold", + ) + + # Likelihood: observed posterior means as noisy measurements of latent fold_true_total_effect + likelihood_fold_means = pm.Normal( + "likelihood_fold_means", + mu=fold_true_total_effect, + sigma=observed_fold_sd, + observed=observed_fold_means, + dims="fold", + ) + + idata_meta_status_quo = pm.sample( + draws=draws_per_chain_meta, + tune=1000, + chains=n_chains, + target_accept=0.97, + nuts_sampler="nutpie", + ) + +# Optional: quick diagnostic +az.summary(idata_meta_status_quo, var_names=["mu_status_quo", "tau_status_quo"]) + +# ------------------------------------------------------------------- +# 3. Posterior predictive: status-quo distribution for a NEW null period +# +# We want draws from: +# theta_new ~ Normal(mu_status_quo, tau_status_quo^2) +# +# This represents the latent true total effect we would see in a new +# period with NO intended intervention, given the placebo evidence. +# ------------------------------------------------------------------- + +with meta_status_quo_model: + meta_status_quo_model.add_coords({"new_period": np.arange(1)}) + + theta_new = pm.Normal( + "theta_new", + mu=mu_status_quo, + sigma=tau_status_quo, + dims="new_period", + ) + + posterior_predictive_status_quo = pm.sample_posterior_predictive( + idata_meta_status_quo, + var_names=["theta_new"], + ) + +# Extract draws of theta_new into a 1D numpy array +theta_new_samples = ( + posterior_predictive_status_quo["posterior_predictive"]["theta_new"] + .stack(sample=("chain", "draw")) + .values + .squeeze() +) # shape: (n_draws_meta,) + +ax.hist( + theta_new_samples, bins=40, density=True, alpha=0.6, label="θ_new ~ N(μ, τ)" +) +``` diff --git a/.cursor/commands/causalpy_marimo.md b/.cursor/commands/causalpy_marimo.md new file mode 100644 index 00000000..9fac84e2 --- /dev/null +++ b/.cursor/commands/causalpy_marimo.md @@ -0,0 +1,308 @@ +I am currently editing a marimo notebook. +You can read or write to the notebook at any notebook under `/.marimo/**` + +If you make edits to the notebook, only edit the contents inside the function decorator with @app.cell. +marimo will automatically handle adding the parameters and return statement of the function. For example, +for each edit, just return: + +``` +@app.cell +def _(): + +return +``` + +## Marimo fundamentals + +Marimo is a reactive notebook that differs from traditional notebooks in key ways: + +- Cells execute automatically when their dependencies change +- Variables cannot be redeclared across cells +- The notebook forms a directed acyclic graph (DAG) +- The last expression in a cell is automatically displayed +- UI elements are reactive and update the notebook automatically + +## Code Requirements + +1. All code must be complete and runnable +2. Follow consistent coding style throughout +3. Include descriptive variable names and helpful comments +4. Import all modules in the first cell, always including `import marimo as mo` +5. Never redeclare variables across cells +6. Ensure no cycles in notebook dependency graph +7. The last expression in a cell is automatically displayed, just like in Jupyter notebooks. +8. Don't include comments in markdown cells +9. Don't include comments in SQL cells +10. Never define anything using `global`. +11. **IMPORTANT**: Do not mix markdown and plots in the same cell. Create separate cells: one for `mo.md(...)` and a subsequent one for the plot. +12. **IMPORTANT**: Avoid using Altair. Prefer Matplotlib, Seaborn, or Arviz for visualizations. + +## Reactivity + +Marimo's reactivity means: + +- When a variable changes, all cells that use that variable automatically re-execute +- UI elements trigger updates when their values change without explicit callbacks +- UI element values are accessed through `.value` attribute +- You cannot access a UI element's value in the same cell where it's defined +- Cells prefixed with an underscore (e.g. _my_var) are local to the cell and cannot be accessed by other cells + +## Best Practices + + +- Use pandas for data manipulation +- Implement proper data validation +- Handle missing values appropriately +- Use efficient data structures +- A variable in the last expression of a cell is automatically displayed as a table + + + +- **IMPORTANT**: Do not use Altair. Use Matplotlib, Seaborn, or Arviz. +- **IMPORTANT**: Do not mix markdown and plots. Always separate them into different cells. +- For matplotlib: use `plt.gca()` as the last expression instead of `plt.show()`, or return the figure/axes object directly. +- Include proper labels, titles, and color schemes +- Make visualizations interactive where appropriate (using Marimo UI elements to control plot parameters) + + + +- Access UI element values with .value attribute (e.g., slider.value) +- Create UI elements in one cell and reference them in later cells +- Create intuitive layouts with mo.hstack(), mo.vstack(), and mo.tabs() +- Prefer reactive updates over callbacks (marimo handles reactivity automatically) +- Group related UI elements for better organization + + + +- When writing duckdb, prefer using marimo's SQL cells, which start with df = mo.sql(f"""""") for DuckDB, or df = mo.sql(f"""""", engine=engine) for other SQL engines. +- See the SQL with duckdb example for an example on how to do this +- Don't add comments in cells that use mo.sql() +- Consider using `vega_datasets` for common example datasets + + +## Troubleshooting + +Common issues and solutions: +- Circular dependencies: Reorganize code to remove cycles in the dependency graph +- UI element value access: Move access to a separate cell from definition +- Visualization not showing: Ensure the visualization object is the last expression + +## Available UI elements + +- `mo.ui.button(value=None, kind='primary')` +- `mo.ui.run_button(label=None, tooltip=None, kind='primary')` +- `mo.ui.checkbox(label='', value=False)` +- `mo.ui.date(value=None, label=None, full_width=False)` +- `mo.ui.dropdown(options, value=None, label=None, full_width=False)` +- `mo.ui.file(label='', multiple=False, full_width=False)` +- `mo.ui.number(value=None, label=None, full_width=False)` +- `mo.ui.radio(options, value=None, label=None, full_width=False)` +- `mo.ui.refresh(options: List[str], default_interval: str)` +- `mo.ui.slider(start, stop, value=None, label=None, full_width=False, step=None)` +- `mo.ui.range_slider(start, stop, value=None, label=None, full_width=False, step=None)` +- `mo.ui.table(data, columns=None, on_select=None, sortable=True, filterable=True)` +- `mo.ui.text(value='', label=None, full_width=False)` +- `mo.ui.text_area(value='', label=None, full_width=False)` +- `mo.ui.data_explorer(df)` +- `mo.ui.dataframe(df)` +- `mo.ui.plotly(plotly_figure)` +- `mo.ui.tabs(elements: dict[str, mo.ui.Element])` +- `mo.ui.array(elements: list[mo.ui.Element])` +- `mo.ui.form(element: mo.ui.Element, label='', bordered=True)` + +## Layout and utility functions + +- `mo.md(text)` - display markdown +- `mo.stop(predicate, output=None)` - stop execution conditionally +- `mo.output.append(value)` - append to the output when it is not the last expression +- `mo.output.replace(value)` - replace the output when it is not the last expression +- `mo.Html(html)` - display HTML +- `mo.image(image)` - display an image +- `mo.hstack(elements)` - stack elements horizontally +- `mo.vstack(elements)` - stack elements vertically +- `mo.tabs(elements)` - create a tabbed interface + +## Examples + + + +@app.cell +def _(): + mo.md(""" + # Hello world + This is a _markdown_ **cell**. + """) + return + + + + + +@app.cell +def _(): + import marimo as mo + import matplotlib.pyplot as plt + import numpy as np + return + +@app.cell +def _(): + n_points = mo.ui.slider(10, 100, value=50, label="Number of points") + n_points + return + +@app.cell +def _(): + # Separate cell for plotting logic + x = np.random.rand(n_points.value) + y = np.random.rand(n_points.value) + + fig, ax = plt.subplots() + ax.scatter(x, y, alpha=0.7) + ax.set_title(f"Scatter plot with {n_points.value} points") + ax.set_xlabel('X axis') + ax.set_ylabel('Y axis') + + # Return the axes or figure to display + ax + return + + + + + +@app.cell +def _(): + import marimo as mo + import polars as pl + from vega_datasets import data + return + +@app.cell +def _(): + cars_df = pl.DataFrame(data.cars()) + mo.ui.data_explorer(cars_df) + return + + + + + +@app.cell +def _(): + import marimo as mo + import pandas as pd + import seaborn as sns + import matplotlib.pyplot as plt + return + +@app.cell +def _(): + # Load data (using pandas as seaborn works natively with it) + iris = pd.read_csv("https://raw.githubusercontent.com/mwaskom/seaborn-data/master/iris.csv") + return + +@app.cell +def _(): + species_selector = mo.ui.dropdown( + options=["All"] + iris["species"].unique().tolist(), + value="All", + label="Species", + ) + x_feature = mo.ui.dropdown( + options=iris.select_dtypes(include=['float', 'int']).columns.tolist(), + value="sepal_length", + label="X Feature", + ) + y_feature = mo.ui.dropdown( + options=iris.select_dtypes(include=['float', 'int']).columns.tolist(), + value="sepal_width", + label="Y Feature", + ) + mo.hstack([species_selector, x_feature, y_feature]) + return + +@app.cell +def _(): + # Reactive plot generation + filtered_data = iris if species_selector.value == "All" else iris[iris["species"] == species_selector.value] + + fig, ax = plt.subplots(figsize=(6, 4)) + sns.scatterplot( + data=filtered_data, + x=x_feature.value, + y=y_feature.value, + hue='species', + ax=ax + ) + ax.set_title(f"{y_feature.value} vs {x_feature.value}") + + # Display the plot (last expression) + ax + return + + + + + +@app.cell +def _(): + mo.stop(not data.value, mo.md("No data to display")) + + if mode.value == "scatter": + mo.output.replace(render_scatter(data.value)) + else: + mo.output.replace(render_bar_chart(data.value)) + return + + + + + +@app.cell +def _(): + import marimo as mo + return + +@app.cell +def _(): + first_button = mo.ui.run_button(label="Option 1") + second_button = mo.ui.run_button(label="Option 2") + [first_button, second_button] + return + +@app.cell +def _(): + if first_button.value: + print("You chose option 1!") + elif second_button.value: + print("You chose option 2!") + else: + print("Click a button!") + return + + + + + +@app.cell +def _(): + import marimo as mo + import polars as pl + return + +@app.cell +def _(): + weather = pl.read_csv('https://raw.githubusercontent.com/vega/vega-datasets/refs/heads/main/data/weather.csv') + return + +@app.cell +def _(): + seattle_weather_df = mo.sql( + f""" + SELECT * FROM weather WHERE location = 'Seattle'; + """ + ) + return + + diff --git a/.cursor/commands/causalpy_methods.md b/.cursor/commands/causalpy_methods.md new file mode 100644 index 00000000..01c0243d --- /dev/null +++ b/.cursor/commands/causalpy_methods.md @@ -0,0 +1,189 @@ +# Causal Methods + +This command provides information and usage examples for the core causal inference methods available in CausalPy: +1. Difference-in-Differences (DiD) +2. Interrupted Time Series (ITS) +3. Synthetic Control (SCG) + +For details on how to retrieve estimates, summarize results, and plot outputs after fitting these models, please refer to the [Causal Estimators](causalpy_estimators.md) command (`@causalpy_estimators`). + +--- + +## 1. Causal Difference-in-Differences (DiD) + +Difference-in-Differences (DiD) estimates the causal effect of a treatment by comparing the changes in outcomes over time between a treatment group and a control group. + +### Class: `DifferenceInDifferences` + +```python +causalpy.experiments.DifferenceInDifferences( + data, + formula, + time_variable_name, + group_variable_name, + post_treatment_variable_name="post_treatment", + model=None, + **kwargs +) +``` + +#### Parameters +* **`data`** (`pd.DataFrame`): Input dataframe containing panel data. +* **`formula`** (`str`): Statistical formula (e.g., `"y ~ 1 + group * post_treatment"`). +* **`time_variable_name`** (`str`): Column name for the time variable. +* **`group_variable_name`** (`str`): Column name for the group indicator (0=Control, 1=Treated). **Must be dummy coded**. +* **`post_treatment_variable_name`** (`str`): Column name indicating the post-treatment period (0=Pre, 1=Post). Default is `"post_treatment"`. +* **`model`**: A PyMC model (e.g., `cp.pymc_models.LinearRegression`) or a Scikit-Learn Regressor. + +#### How it Works +1. **Fit**: The model fits all available data (pre/post, treatment/control). +2. **Counterfactual**: The counterfactual is predicted by setting the interaction term between `group` and `post_treatment` to 0 (i.e., what would have happened if the treatment group had not been treated in the post-period). +3. **Impact**: The causal impact is the coefficient of the interaction term (in linear models) or the difference between observed and counterfactual. + +#### Example + +```python +import causalpy as cp +import causalpy.pymc_models as cp_pymc + +# Load data +df = cp.load_data("did") + +# Run DiD +result = cp.DifferenceInDifferences( + df, + formula="y ~ 1 + group*post_treatment", + time_variable_name="t", + group_variable_name="group", + model=cp_pymc.LinearRegression(sample_kwargs={"target_accept": 0.9}) +) + +# Summarize +result.summary() + +# Plot +result.plot() +``` + +#### Key Assumptions +* **Parallel Trends**: Trends in the outcome variable would be the same for both groups in the absence of treatment. + +--- + +## 2. Causal Interrupted Time Series (ITS) + +Interrupted Time Series (ITS) analyzes the effect of an intervention on a single time series by comparing the trend before and after the intervention. + +### Class: `InterruptedTimeSeries` + +```python +causalpy.experiments.InterruptedTimeSeries( + data, + treatment_time, + formula, + model=None, + **kwargs +) +``` + +#### Parameters +* **`data`** (`pd.DataFrame`): Input dataframe. Index should ideally be a `pd.DatetimeIndex`. +* **`treatment_time`** (`Union[int, float, pd.Timestamp]`): The point in time when the intervention occurred. +* **`formula`** (`str`): Statistical formula (e.g., `"y ~ 1 + t + C(month)"`). +* **`model`**: A PyMC model (e.g., `cp.pymc_models.LinearRegression`, `cp.pymc_models.BayesianBasisExpansionTimeSeries`) or a Scikit-Learn Regressor. + +#### How it Works +1. **Split**: Data is split into pre-intervention and post-intervention sets based on `treatment_time`. +2. **Fit**: The model is trained **only on the pre-intervention data**. +3. **Predict**: The fitted model predicts the outcome for the post-intervention period (the counterfactual). +4. **Impact**: The causal impact is the difference between the observed post-intervention data and the model's counterfactual predictions. + +#### Example + +```python +import causalpy as cp +import causalpy.pymc_models as cp_pymc +import pandas as pd + +# Load data +df = cp.load_data("its") +df["date"] = pd.to_datetime(df["date"]) +df.set_index("date", inplace=True) + +treatment_time = pd.to_datetime("2017-01-01") + +# Run ITS +result = cp.InterruptedTimeSeries( + df, + treatment_time, + formula="y ~ 1 + t + C(month)", + model=cp_pymc.LinearRegression() +) + +# Summary and Plot +result.summary() +result.plot() +``` + +#### Key Considerations +* **Seasonality**: Include seasonal components (e.g., `C(month)`) in the formula if the data exhibits seasonal patterns. +* **Trends**: Ensure the model captures the underlying trend (e.g., linear time trend `t`) to avoid attributing secular trends to the intervention. + +--- + +## 3. Causal Synthetic Control (SCG) + +Synthetic Control constructs a "synthetic" counterfactual unit using a weighted combination of untreated control units that best matches the treated unit's pre-intervention trajectory. + +### Class: `SyntheticControl` + +```python +causalpy.experiments.SyntheticControl( + data, + treatment_time, + control_units, + treated_units, + model=None, + **kwargs +) +``` + +#### Parameters +* **`data`** (`pd.DataFrame`): Input dataframe containing panel data. +* **`treatment_time`** (`Union[int, float, pd.Timestamp]`): The time of intervention. +* **`control_units`** (`List[str]`): List of column names representing the control units. +* **`treated_units`** (`List[str]`): List of column names representing the treated unit(s). +* **`model`**: A PyMC model (typically `cp.pymc_models.WeightedSumFitter`) or a Scikit-Learn Regressor. + +#### How it Works +1. **Fit**: The model learns weights for the `control_units` to approximate the `treated_units` using **only pre-intervention data**. +2. **Predict**: These weights are applied to the `control_units` in the post-intervention period to generate the synthetic counterfactual. +3. **Impact**: The difference between the observed treated unit and the synthetic counterfactual. + +#### Example + +```python +import causalpy as cp +import causalpy.pymc_models as cp_pymc + +# Load data +df = cp.load_data("sc") +treatment_time = 70 + +# Run Synthetic Control +result = cp.SyntheticControl( + df, + treatment_time, + control_units=["a", "b", "c", "d", "e"], + treated_units=["actual"], + model=cp_pymc.WeightedSumFitter() +) + +# Summary and Plot +result.summary() +result.plot() +``` + +#### Model Selection +* **`WeightedSumFitter`**: Enforces that weights sum to 1 and are non-negative (standard Synthetic Control constraints). +* **`LinearRegression`**: Can be used but allows negative weights and intercept, effectively relaxing the standard SC constraints (sometimes called "Geolift" style or unconstrained SC). diff --git a/.cursor/commands/causalpy_research.md b/.cursor/commands/causalpy_research.md new file mode 100644 index 00000000..b615f9b6 --- /dev/null +++ b/.cursor/commands/causalpy_research.md @@ -0,0 +1,47 @@ +# Causal Research + +This skill is designed to help users "think deep" about their causal inference problem and select the most appropriate method. + +## Decision Framework + +When deciding on a method, ask the following questions: + +1. **Do you have a control group?** + * **Yes**: Proceed to check the structure of the control/treatment units. + * **No**: Consider methods that rely on time-series projection (e.g., Interrupted Time Series). + +2. **What is the unit structure?** + * **Single Treated Unit**: + * With multiple control units: **Synthetic Control (SCG)** is often best. + * With no control units: **Interrupted Time Series (ITS)**. + * **Multiple Treated Units**: + * With a control group: **Difference-in-Differences (DiD)**. + +3. **What is the time structure?** + * **Panel Data**: Data observed over time for multiple units. Required for DiD and SCG. + * **Time Series**: Data observed over time for a single unit (or aggregated units). Required for ITS. + +## Method Selection Guide + +### Difference-in-Differences (DiD) +* **Best for**: Measuring the effect of a treatment by comparing the change in outcome over time between a treatment group and a control group. +* **Key Assumption**: **Parallel Trends**. In the absence of treatment, the difference between the treatment and control group is constant over time. +* **Data Requirement**: Panel data with a clear pre/post intervention period and a defined control group. + +### Interrupted Time Series (ITS) +* **Best for**: Evaluating the effect of an intervention on a single unit (or population) by analyzing the change in level and trend of the outcome after the intervention. +* **Key Assumption**: The pre-intervention trend would have continued unchanged in the absence of the intervention. +* **Data Requirement**: High-frequency time-series data (e.g., daily, monthly) with a clear intervention date. + +### Synthetic Control (SCG) +* **Best for**: Estimating the effect of an intervention on a single treated unit (e.g., a city, state, or country) using a weighted combination of control units. +* **Key Assumption**: The control units can accurately reconstruct the treated unit's pre-intervention trajectory. +* **Data Requirement**: Panel data with a long pre-intervention period and several potential control units that were not affected by the treatment. + +## Diagnostic Questions to Ask the User + +If the user is unsure, ask: +1. "Does your data include a group that was *never* treated?" +2. "Do you have data collected over time (e.g., days, months, years)?" +3. "Is the treatment applied to a single entity (e.g., one store, one country) or many?" +4. "Do you suspect other events happened at the same time as the treatment that could affect the outcome?" diff --git a/.cursor/commands/commit.md b/.cursor/commands/commit.md new file mode 100644 index 00000000..e801226b --- /dev/null +++ b/.cursor/commands/commit.md @@ -0,0 +1,40 @@ +# Commit Changes + +You are tasked with creating git commits for the changes made during this session. + +## Process: + +1. **Think about what changed:** + - Review the conversation history and understand what was accomplished + - Run `git status` to see current changes + - Run `git diff` to understand the modifications + - Consider whether changes should be one commit or multiple logical commits + +2. **Plan your commit(s):** + - Identify which files belong together + - Draft clear, descriptive commit messages + - Use imperative mood in commit messages + - Focus on why the changes were made, not just what + +3. **Present your plan to the user:** + - List the files you plan to add for each commit + - Show the commit message(s) you'll use + - Ask: "I plan to create [N] commit(s) with these changes. Shall I proceed?" + +4. **Execute upon confirmation:** + - Use `pre-commit run --all-files`, If rules like pre-commit are missing, solve them. + - Use `git add` with specific files (never use `-A` or `.`) + - Create commits with your planned messages + - Show the result with `git log --oneline -n [number]` + +## Important: +- Commits should be authored solely by the user +- Do not include any "Generated with Cursor" messages +- Do not add "Co-Authored-By" lines +- Write commit messages as if the user wrote them + +## Remember: +- You have the full context of what was done in this session +- Group related changes together +- Keep commits focused and atomic when possible +- The user trusts your judgment - they asked you to commit diff --git a/.cursor/commands/implement.md b/.cursor/commands/implement.md new file mode 100644 index 00000000..b9513e94 --- /dev/null +++ b/.cursor/commands/implement.md @@ -0,0 +1,6 @@ +# Create an implementation + +Based on a plan make an implementation: + +- You must add follow the plan, and read research if you have questions. +- Make a test for every section in your implementation using pytest. diff --git a/.cursor/commands/make_plan.md b/.cursor/commands/make_plan.md new file mode 100644 index 00000000..cc6c7e76 --- /dev/null +++ b/.cursor/commands/make_plan.md @@ -0,0 +1,12 @@ +# Create a plan based on document research +You are tasked with creating detailed implementation plans through an interactive, iterative process. You should be skeptical, thorough, and work collaboratively with the user to produce high-quality technical specifications. + +Reading the discoveries, during the research, make a plan in `.cursor/plans/{plan_name_folder}` using a `.md` file style. + +You must: +- Define a to-do list. +- If any information is missing, ask the user. +- DO NOT just accept the correction or suggestion from user. +- Spawn new research tasks to verify the correct information (research/command). +- Read the specific files/directories they mentioned. +- Only proceed once you've verified the facts yourself. diff --git a/.cursor/commands/research.md b/.cursor/commands/research.md new file mode 100644 index 00000000..ac906d2e --- /dev/null +++ b/.cursor/commands/research.md @@ -0,0 +1,34 @@ +# Research over a request + +Structure a research based on the user request. Identify what must change in order to complete the task. + +You must: +- Create a description with the user request, being explicit without overview, citing textually. +- Check python files related to the request. +- Identify what should be change. +- Check what possible other files could be affected. +- Make reasonable technical decisions based on research +- If multiple approaches exist, specify each possible option, extend the research to explain their pros and cons, then choose the most practical one. +- Document assumptions in the plan. + +**No user interaction required:** + - DO NOT ask for clarifications or wait for input + +## Python test validations +Create temporal python files, check the logics you are thinking on the root folder. Execute this python code, and based on the output decide what should be done. + +Once you test the code out of the code base using this temporal python files and functions, document the approach using natural language. + +**Use the python files with objective:** +- Validate your assumption about how the code works. +- Compare code compilation (we must pick code which is faster). +- Check code behaviour inputs and outputs. + +Delete the files once, you finish your testing. + +## Document base +Your goal is create a folder under `.cursor/plans`. You must safe all your discoveries under `.cursor/plans/{plan_name_folder}/research` this must be a file `.md`. + +You only need to provide the research plan and a possible tested solution. + +No other `md` files are needed, everything must be collapse in the research file. diff --git a/.cursor/rules/basic.mdc b/.cursor/rules/basic.mdc new file mode 100644 index 00000000..f8334d74 --- /dev/null +++ b/.cursor/rules/basic.mdc @@ -0,0 +1,8 @@ +--- +alwaysApply: true +--- +You will act as helpful assistant in the causalpy library, and help the user in all he ask in a very precise and interactive manner. + +**Important:** +- If you modify a file with marimo, always check and run your mcp to validate active sessions and if the code is active validate your changes are working. +- If you are modifying code from causalpy core code then activate conda env `CausalPy`, and run pre-commit every time you create or modify a file. diff --git a/.gitignore b/.gitignore index a3ce429e..f468c09a 100644 --- a/.gitignore +++ b/.gitignore @@ -15,4 +15,5 @@ docs/build/ docs/jupyter_execute/ docs/source/api/generated/ -.cursor/ +.cursor/plans +.marimo/ diff --git a/environment.yml b/environment.yml index bf0fcee4..cd0e2a3b 100644 --- a/environment.yml +++ b/environment.yml @@ -18,3 +18,4 @@ dependencies: - pymc-extras>=0.3.0 - pymc-bart - python>=3.11 + - "marimo[mcp]" diff --git a/pyproject.toml b/pyproject.toml index 833ab1ec..0523c3c1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -196,3 +196,6 @@ ignore_missing_imports = true [tool.mypy-scipy] ignore_missing_imports = true + +[tool.marimo.runtime] +watcher_on_save = "autorun" From f6a9d0b77136f510a763ecf0153cf7236ac10145 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Thu, 11 Dec 2025 23:24:00 +0100 Subject: [PATCH 2/2] Update causalpy_extras.md --- .cursor/commands/causalpy_extras.md | 1 - 1 file changed, 1 deletion(-) diff --git a/.cursor/commands/causalpy_extras.md b/.cursor/commands/causalpy_extras.md index d4e89172..00b605e7 100644 --- a/.cursor/commands/causalpy_extras.md +++ b/.cursor/commands/causalpy_extras.md @@ -225,7 +225,6 @@ class PlaceboAnalysis(BaseModel): return results ``` -Done. ### Code example