Double Machine Learning, Simplified: Part 2 - Targeting & the CATE

Learn how to utilize DML for estimating idiosyncratic treatment effects to enable personalized targeting.

CATE, Linear DML, Non-Parametric DML

Key Takeaways

1. Move beyond the ATE: estimate treatment effect heterogeneity via the CATE.
2. Linear DML uses interactions with residualized treatment to model interpretable heterogeneity.
3. Non-parametric DML learns flexible CATE functions via a weighted regression objective.

\[ \tilde{y}=\tau(X)\,\tilde{T}+\varepsilon \]

Imports and Settings
import numpy as np
import pandas as pd
import statsmodels.formula.api as smf
from sklearn.ensemble import GradientBoostingRegressor, GradientBoostingClassifier
from sklearn.model_selection import cross_val_predict
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
import graphviz
import matplotlib.pyplot as plt

np.random.seed(0)

## Helper Plots

COLORS = ["#00B0F0", "#FF0000", "#B0F000"]

def plot_effect(
    effect_true, effect_pred, save_path, figsize=(8, 5), ylim=(-10, 100)
):
    plt.figure(figsize=figsize)
    plt.scatter(effect_true, effect_pred, color=COLORS[0], s=10)
    plt.plot(
        np.sort(effect_true),
        np.sort(effect_true),
        color=COLORS[1],
        alpha=0.7,
        label="Perfect model",
    )
    plt.xlabel("True effect", fontsize=14)
    plt.ylabel("Predicted effect", fontsize=14)
    plt.legend()
    plt.savefig(save_path, format="webp", dpi=300, bbox_inches='tight')

def hist_effect(effect_true, effect_pred, save_path, figsize=(8, 5)):
    plt.figure(figsize=figsize)

    plt.hist(
        effect_pred,
        color="r",
        alpha=0.8,
        density=True,
        bins=50,
        label="Predicted CATE",
    )
    plt.hist(
        effect_true,
        color="b",
        alpha=0.4,
        density=True,
        bins=50,
        label="True CATE",
    )

    plt.legend()
    plt.savefig(save_path, format="webp", dpi=300, bbox_inches='tight')

Introduction

This article is the second in a two-part series on simplifying and democratizing Double Machine Learning. In the first part, we covered the fundamentals of Double Machine Learning, along with two basic causal inference applications. Now, in part 2, we will extend this knowledge to turn our causal inference problem into a prediction task, wherein we predict individual-level treatment effects to aid in decision making and data-driven targeting.

Double Machine Learning, as we learned in part 1 of this series, is a highly flexible, partially linear causal inference method for estimating the average treatment effect (ATE) of a treatment. Specifically, it can be used to model highly non-linear confounding relationships in observational data (especially when our set of controls/confounders is of extremely high dimensionality) and/or to reduce the variation in our key outcome in experimental settings. Estimating the ATE is particularly useful for understanding the average impact of a specific treatment, which can inform future decision making. However, extrapolating this treatment effect assumes a degree of homogeneity in the effect; that is, regardless of the population we roll the treatment out to, we anticipate the effect to be similar to the ATE. What if we are limited in the number of individuals we can target in a future rollout, and thus want to understand in which subpopulations the treatment was most effective?

The issue described above concerns estimating treatment effect heterogeneity. That is, how does our treatment affect different subsets of the population? Luckily for us, DML provides a powerful framework to do exactly this. Specifically, we can make use of DML to estimate the Conditional Average Treatment Effect (CATE). First, let’s revisit our definition of the ATE, in the binary and continuous cases, respectively:

\[ \begin{equation} \text{ATE}=\mathbb{E}_n[y(T=1)-y(T=0)] \tag{1} \end{equation} \]

\[ \begin{equation} \text{ATE}=\mathbb{E}_n\left[\frac{\partial y}{\partial T}\right] \tag{2} \end{equation} \]

Now with the CATE, we estimate the ATE conditional on a set of values for our covariates, \(\mathbf{X}\):

\[ \begin{equation} \text{CATE}=\mathbb{E}_n[y(T=1)-y(T=0)\mid\mathbf{X}=x] \tag{3} \end{equation} \]

\[ \begin{equation} \text{CATE}=\mathbb{E}_n\left[\frac{\partial y}{\partial T}\,\middle|\,\mathbf{X}=x\right] \tag{4} \end{equation} \]

For example, if we wanted to know the treatment effect for males versus females, we can estimate the CATE conditional on the covariate being equal to each subgroup of interest. Note that we can estimate highly aggregated CATEs (e.g., at a male vs. female level), also known as Group Average Treatment Effects (GATEs), or we can allow \(\mathbf{X}\) to take on an extremely high dimensionality and thus closely estimate each individual’s treatment effect. You may immediately notice the benefits of being able to do this: we can use this information to make highly informed decisions in future targeting of the treatment! Even more notably, we can create a CATE function to predict the treatment effect on previously unexposed individuals!
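
As a minimal illustration of these definitions, the toy example below computes the ATE alongside the GATE of each subgroup. The arrays `group`, `y0`, and `y1` are hypothetical: both potential outcomes are visible here only because the data is invented, whereas in practice we never observe both for the same individual.

```python
import numpy as np

# Hypothetical toy data: potential outcomes are known only because we made them up.
group = np.array(["F", "F", "F", "M", "M", "M"])
y0 = np.array([10.0, 12.0, 11.0, 20.0, 22.0, 21.0])  # outcome without treatment
y1 = np.array([14.0, 16.0, 15.0, 21.0, 23.0, 22.0])  # outcome with treatment

# Individual effects, their overall average (ATE), and per-group averages (GATEs).
individual_effects = y1 - y0
ate = individual_effects.mean()
gate = {str(g): individual_effects[group == g].mean() for g in np.unique(group)}

print(f"ATE: {ate}")
print(f"GATE by group: {gate}")
```

In this made-up data the ATE is 2.5, but it masks an effect of 4 in one group and 1 in the other, which is exactly the kind of information a targeting decision needs.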

Note that many models exist for estimating CATEs, which we’ll cover in a subsequent post. For now, we’ll cover two techniques within the partially linear DML formulation for estimating this CATE function: Linear DML and Non-Parametric DML. We will show how to estimate the CATE mathematically and then provide examples for each case.

Note: Unbiased estimation of the CATE still requires the exogeneity assumption (also called the conditional independence assumption, CIA, or ignorability) to hold, as covered in part 1.

Everything demonstrated below can and should be extended to the experimental setting (RCT or A/B Testing), where exogeneity is satisfied by construction, as covered in application 2 of part 1.

Linear DML for Estimating the CATE

Estimating the CATE in the linear DML framework is a simple extension of DML for estimating the ATE:

\[ \begin{equation} y-\mathcal{M}_y(\mathbf{X})=\beta_0+\beta_1(T-\mathcal{M}_T(\mathbf{X}))+\epsilon \tag{5} \end{equation} \]

where \(y\) is our outcome, \(T\) is our treatment, and \(\mathcal{M}_y\) and \(\mathcal{M}_T\) are flexible ML models (our nuisance functions) that predict \(y\) and \(T\), respectively, given confounders and/or controls \(\mathbf{X}\). To estimate the CATE function using Linear DML, we can simply include interaction terms of the treatment residuals with our covariates. Observe:

\[ \begin{equation} y-\mathcal{M}_y(\mathbf{X})=\beta_0+\beta_1(T-\mathcal{M}_T(\mathbf{X}))+(T-\mathcal{M}_T(\mathbf{X}))\mathbf{X}\mathbf{\Omega} + \epsilon \tag{6} \end{equation} \]

where \(\mathbf{\Omega}\) is the vector of coefficients for the interaction terms. Now our CATE function, call it \(\tau\), takes the form \(\tau(\mathbf{X}) = \beta_1 + \mathbf{X}\mathbf{\Omega}\), with which we can predict each individual’s CATE given \(\mathbf{X}\). If \(T\) is continuous, this CATE function is for a 1-unit increase in \(T\). Note that eq. (5) is the special case \(\tau(\mathbf{X}) = \beta_1\), where \(\tau(\mathbf{X})\) is assumed to be constant. Let’s take a look at this in action!
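
Before the full example, here is a minimal NumPy-only sketch of eq. (6) with a single covariate. Everything in it is an assumption made for illustration: the DGP (with true \(\tau(x) = 1.5 + 2x\)), the helper `cross_fit_residuals`, and the use of a cubic polynomial fit as a stand-in for the flexible ML nuisance models.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 20_000

# Hypothetical DGP with a known CATE function: tau(x) = 1.5 + 2.0 * x
x = rng.uniform(0, 1, size=n)
T = 2 + 3 * x**2 + rng.normal(size=n)
y = 1 + 2 * x + (1.5 + 2.0 * x) * T + rng.normal(size=n)


def cross_fit_residuals(target, x, n_folds=2, degree=3):
    """Out-of-fold residuals; a polynomial fit stands in for the ML nuisance model."""
    residuals = np.empty_like(target)
    folds = np.array_split(rng.permutation(len(x)), n_folds)
    for k, test_idx in enumerate(folds):
        train_idx = np.concatenate([f for j, f in enumerate(folds) if j != k])
        coefs = np.polyfit(x[train_idx], target[train_idx], degree)
        residuals[test_idx] = target[test_idx] - np.polyval(coefs, x[test_idx])
    return residuals


y_res = cross_fit_residuals(y, x)  # y - M_y(X)
T_res = cross_fit_residuals(T, x)  # T - M_T(X)

# Final stage of eq. (6): regress y_res on [1, T_res, T_res * x]
design = np.column_stack([np.ones(n), T_res, T_res * x])
beta_0, beta_1, omega = np.linalg.lstsq(design, y_res, rcond=None)[0]


def tau(x_new):
    """Estimated CATE function: tau(x) = beta_1 + omega * x."""
    return beta_1 + omega * np.asarray(x_new)


print(f"beta_1 = {beta_1:.3f}, omega = {omega:.3f}")
```

With this setup, `beta_1` and `omega` should land near 1.5 and 2.0, and `tau` can then be evaluated at any covariate value, including ones not seen during fitting.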

First, let’s use the same causal DAG from part 1, where we look at the effect of an individual’s time spent on the website on their purchase amount, or sales, in the past month (assuming we observe all confounders):

Code
# Create a directed graph
g = graphviz.Digraph(format="png")

# Add nodes
nodes = [
    "Age",
    "# Social Media Accounts",
    "Yrs Member",
    "Time on Website",
    "Sales",
    "Z",
]
for n in nodes:
    g.node(n)

g.edge("Age", "Time on Website")
g.edge("# Social Media Accounts", "Time on Website")
g.edge("Yrs Member", "Time on Website")
g.edge("Age", "Sales")
g.edge("# Social Media Accounts", "Sales")
g.edge("Yrs Member", "Sales")
g.edge("Time on Website", "Sales", color="red")
g.edge("Z", "Sales")

g.graph_attr["dpi"] = "200"

# Render for print
g.render("data/dag1", format="webp")

Let’s then simulate this DGP using a process similar to the one used in part 1 (note that all values and data are chosen and generated arbitrarily for demonstrative purposes). Observe that we now include interaction terms in the sales DGP to model the CATE, or treatment effect heterogeneity (note that the DGP in part 1 had no treatment effect heterogeneity by construction).

Code
# Sample Size
N = 100_000

# Confounders (X)
age = np.random.randint(low=18, high=75, size=N)
num_social_media_profiles = np.random.choice(
    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], size=N
)
yr_membership = np.random.choice([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], size=N)

# Arbitrary Covariates (Z)
Z = np.random.normal(loc=50, scale=25, size=N)

# Error Terms
ε1 = np.random.normal(loc=20, scale=5, size=N)
ε2 = np.random.normal(loc=40, scale=15, size=N)

# Treatment (T = g(X) + ε1)
def T(age, num_social_media_profiles, yr_membership, ε1):
    time_on_website = np.maximum(
        10
        - 0.01 * age
        - 0.001 * age**2
        + num_social_media_profiles
        - 0.01 * num_social_media_profiles**2
        - 0.01 * (age * num_social_media_profiles)
        + 0.2 * yr_membership
        + 0.001 * yr_membership**2
        - 0.01 * (age * yr_membership)
        + 0.2 * (num_social_media_profiles * yr_membership)
        + 0.01
        * (num_social_media_profiles * np.log(age) * age * yr_membership ** (1 / 2))
        + ε1,
        0,
    )
    return time_on_website

time_on_website = T(age, num_social_media_profiles, yr_membership, ε1)

# Outcome (y = f(T,X,Z) + ε2)
def y(time_on_website, age, num_social_media_profiles, yr_membership, Z, ε2):
    sales = np.maximum(
        25
        + 5 * time_on_website  # Baseline Treatment Effect
        - 0.2 * time_on_website * age  # Heterogeneity
        + 2 * time_on_website * num_social_media_profiles  # Heterogeneity
        + 2 * time_on_website * yr_membership  # Heterogeneity
        - 0.1 * age
        - 0.001 * age**2
        + 8 * num_social_media_profiles
        - 0.1 * num_social_media_profiles**2
        - 0.01 * (age * num_social_media_profiles)
        + 2 * yr_membership
        + 0.1 * yr_membership**2
        - 0.01 * (age * yr_membership)
        + 3 * (num_social_media_profiles * yr_membership)
        + 0.1
        * (num_social_media_profiles * np.log(age) * age * yr_membership ** (1 / 2))
        + 0.5 * Z
        + ε2,
        0,
    )
    return sales

sales = y(time_on_website, age, num_social_media_profiles, yr_membership, Z, ε2)

df = pd.DataFrame(
    np.array([sales, time_on_website, age, num_social_media_profiles, yr_membership, Z]).T,
    columns=["sales","time_on_website","age","num_social_media_profiles","yr_membership","Z"],
)
Code
df
       sales        time_on_website  age   num_social_media_profiles  yr_membership  Z
0      3322.810046  100.647771       62.0  7.0                        9.0            43.399307
1      4001.932521  112.458788       65.0  9.0                        8.0            60.097765
2      119.987080   23.456745        71.0  1.0                        3.0            71.465249
3      1100.177387  38.975211        18.0  8.0                        2.0            21.929513
4      2027.720875  56.508789        21.0  7.0                        7.0            63.506492
...    ...          ...              ...   ...                        ...            ...
99995  514.495036   31.927945        46.0  1.0                        7.0            7.327503
99996  2702.614252  96.241676        70.0  8.0                        6.0            56.238647
99997  482.198339   44.362601        61.0  2.0                        4.0            84.668278
99998  1802.891223  67.471281        69.0  5.0                        8.0            59.031407
99999  710.743135   33.415857        43.0  10.0                       0.0            58.393933

100000 rows × 6 columns

Now, to estimate our CATE function, as outlined in eq. (6), we can run:

Code
# DML Procedure for Estimating the CATE
M_sales = GradientBoostingRegressor()
M_time_on_website = GradientBoostingRegressor()

df["residualized_sales"] = df["sales"] - cross_val_predict(
    M_sales,
    df[["age", "num_social_media_profiles", "yr_membership"]],
    df["sales"],
    cv=3,
)

df["residualized_time_on_website"] = df["time_on_website"] - cross_val_predict(
    M_time_on_website,
    df[["age", "num_social_media_profiles", "yr_membership"]],
    df["time_on_website"],
    cv=3,
)

DML_model = smf.ols(
    formula="residualized_sales ~ 1 + residualized_time_on_website + residualized_time_on_website:age + residualized_time_on_website:num_social_media_profiles + residualized_time_on_website:yr_membership",
    data=df,
).fit()

print(DML_model.summary())
                            OLS Regression Results                            
==============================================================================
Dep. Variable:     residualized_sales   R-squared:                       0.742
Model:                            OLS   Adj. R-squared:                  0.742
Method:                 Least Squares   F-statistic:                 7.207e+04
Date:                Fri, 16 Jan 2026   Prob (F-statistic):               0.00
Time:                        14:53:50   Log-Likelihood:            -5.4142e+05
No. Observations:              100000   AIC:                         1.083e+06
Df Residuals:                   99995   BIC:                         1.083e+06
Df Model:                           4                                         
Covariance Type:            nonrobust                                         
==========================================================================================================================
                                                             coef    std err          t      P>|t|      [0.025      0.975]
--------------------------------------------------------------------------------------------------------------------------
Intercept                                                 -0.2065      0.172     -1.202      0.230      -0.543       0.130
residualized_time_on_website                               6.0280      0.124     48.784      0.000       5.786       6.270
residualized_time_on_website:age                          -0.1695      0.002    -85.257      0.000      -0.173      -0.166
residualized_time_on_website:num_social_media_profiles     1.7358      0.010    165.531      0.000       1.715       1.756
residualized_time_on_website:yr_membership                 1.7694      0.010    170.534      0.000       1.749       1.790
==============================================================================
Omnibus:                     3954.403   Durbin-Watson:                   1.996
Prob(Omnibus):                  0.000   Jarque-Bera (JB):            11608.184
Skew:                           0.124   Prob(JB):                         0.00
Kurtosis:                       4.651   Cond. No.                         257.
==============================================================================

Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.

Here we can see that linear DML closely recovered the true DGP for the CATE (compare the coefficients on the interaction terms with those in the sales DGP). Let’s evaluate the performance of our CATE function by comparing the linear DML predictions to the true CATE for a 1 hour increase in time spent on the website:

Code
# Predict CATE of 1 hour increase
linear_dml_cates = DML_model.predict(
    df.assign(
        residualized_time_on_website=lambda x: x.residualized_time_on_website + 1
    )
) - DML_model.predict(df)

# True CATE of 1 hour increase
X = [age, num_social_media_profiles, yr_membership, Z, ε2]
true_cates = y(time_on_website + 1, *X) - y(time_on_website, *X)

print(f"Mean Squared Error: {mean_squared_error(true_cates, linear_dml_cates)}")
print(f"Mean Absolute Error: {mean_absolute_error(true_cates, linear_dml_cates)}")
print(f"R-Squared: {r2_score(true_cates, linear_dml_cates)}")
Mean Squared Error: 1.5107484209942401
Mean Absolute Error: 0.9878858217850499
R-Squared: 0.9825949051151978
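
To make the comparison concrete for a single individual, the short sketch below evaluates \(\tau(\mathbf{X}) = \beta_1 + \mathbf{X}\mathbf{\Omega}\) by hand for row 0 of the simulated data. The estimated coefficients are copied from the OLS summary above and the true ones from the sales DGP; the dictionary names are illustrative, and the calculation assumes the clipping of sales at zero does not bind for this row.

```python
# Estimated coefficients, copied from the OLS summary above.
beta_1 = 6.0280
omega = {
    "age": -0.1695,
    "num_social_media_profiles": 1.7358,
    "yr_membership": 1.7694,
}

# Covariates of row 0 in the simulated DataFrame.
row = {"age": 62.0, "num_social_media_profiles": 7.0, "yr_membership": 9.0}

# Estimated CATE: tau(X) = beta_1 + X @ Omega
tau_hat = beta_1 + sum(omega[k] * row[k] for k in omega)

# True CATE from the interaction terms in the sales DGP
# (assumes the np.maximum(..., 0) clipping is not binding for this row).
true_coefs = {"age": -0.2, "num_social_media_profiles": 2.0, "yr_membership": 2.0}
tau_true = 5.0 + sum(true_coefs[k] * row[k] for k in true_coefs)

print(f"Estimated CATE: {tau_hat:.2f}")
print(f"True CATE:      {tau_true:.2f}")
```

For this individual the estimate is about 23.59 against a true effect of 24.60. Because the final-stage model is linear in the residualized treatment, this hand calculation is equivalent to taking the difference between the two `predict` calls above.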

Plotting the distributions of the predicted CATE and true CATE, we obtain:

Code
hist_effect(true_cates, linear_dml_cates, save_path="data/linear_dml_hist.webp")

Additionally, plotting the predicted values versus the true values we obtain:

Code
plot_effect(true_cates, linear_dml_cates, save_path="data/linear_dml_line.webp")

Overall, we have pretty impressive performance! However, the primary limitation of this approach is that we must manually specify the functional form of the CATE function; if we include only linear interaction terms, we may not capture the true CATE function. In our example, we simulated the DGP to have only these linear interaction terms, so the performance is strong by construction. Let’s see what happens when we tweak the DGP for the CATE to be arbitrarily non-linear:

Code
# Outcome (y = f(T,X,Z) + ε2)
def y_fn_nonlinear(
    time_on_website, age, num_social_media_profiles, yr_membership, Z, ε2
):
    sales = np.maximum(
        25
        + 5 * time_on_website  # Baseline Treatment Effect
        - 0.2 * time_on_website * age  # Heterogeneity
        - 0.0005 * time_on_website * age**2  # Heterogeneity
        + 0.8 * time_on_website * num_social_media_profiles  # Heterogeneity
        + 0.001 * time_on_website * num_social_media_profiles**2  # Heterogeneity
        + 0.8 * time_on_website * yr_membership  # Heterogeneity
        + 0.001 * time_on_website * yr_membership**2  # Heterogeneity
        + 0.005
        * time_on_website
        * yr_membership
        * num_social_media_profiles
        * age  # Heterogeneity
        + 0.005
        * time_on_website
        * (yr_membership**3 / (1 + num_social_media_profiles**2))
        * np.log(age) ** 2
        - 0.1 * age
        - 0.001 * age**2
        + 8 * num_social_media_profiles
        - 0.1 * num_social_media_profiles**2
        - 0.01 * (age * num_social_media_profiles)
        + 2 * yr_membership
        + 0.1 * yr_membership**2
        - 0.01 * (age * yr_membership)
        + 3 * (num_social_media_profiles * yr_membership)
        + 0.1
        * (num_social_media_profiles * np.log(age) * age * yr_membership ** (1 / 2))
        + 0.5 * Z
        + ε2,
        0,
    )
    return sales

sales_nonlinear = y_fn_nonlinear(
    time_on_website, age, num_social_media_profiles, yr_membership, Z, ε2
)

df_nonlinear = pd.DataFrame(
    np.array(
        [
            sales_nonlinear,
            time_on_website,
            age,
            num_social_media_profiles,
            yr_membership,
            Z,
        ]
    ).T,
    columns=[
        "sales",
        "time_on_website",
        "age",
        "num_social_media_profiles",
        "yr_membership",
        "Z",
    ],
)

Fitting our models:

Code
# DML Procedure
M_sales2 = GradientBoostingRegressor()
M_time_on_website2 = GradientBoostingRegressor()

df_nonlinear["residualized_sales"] = df_nonlinear["sales"] - cross_val_predict(
    M_sales2,
    df_nonlinear[["age", "num_social_media_profiles", "yr_membership"]],
    df_nonlinear["sales"],
    cv=3,
)

df_nonlinear["residualized_time_on_website"] = df_nonlinear[
    "time_on_website"
] - cross_val_predict(
    M_time_on_website2,
    df_nonlinear[["age", "num_social_media_profiles", "yr_membership"]],
    df_nonlinear["time_on_website"],
    cv=3,
)

DML_model_nonlinear = smf.ols(
    formula="residualized_sales ~ 1 + residualized_time_on_website + residualized_time_on_website:age + residualized_time_on_website:num_social_media_profiles + residualized_time_on_website:yr_membership",
    data=df_nonlinear,
).fit()

print(DML_model_nonlinear.summary())
                            OLS Regression Results                            
==============================================================================
Dep. Variable:     residualized_sales   R-squared:                       0.496
Model:                            OLS   Adj. R-squared:                  0.496
Method:                 Least Squares   F-statistic:                 2.457e+04
Date:                Fri, 16 Jan 2026   Prob (F-statistic):               0.00
Time:                        14:53:58   Log-Likelihood:            -5.8008e+05
No. Observations:              100000   AIC:                         1.160e+06
Df Residuals:                   99995   BIC:                         1.160e+06
Df Model:                           4                                         
Covariance Type:            nonrobust                                         
==========================================================================================================================
                                                             coef    std err          t      P>|t|      [0.025      0.975]
--------------------------------------------------------------------------------------------------------------------------
Intercept                                                 -0.3735      0.253     -1.477      0.140      -0.869       0.122
residualized_time_on_website                              -2.9807      0.182    -16.389      0.000      -3.337      -2.624
residualized_time_on_website:age                          -0.0268      0.003     -9.162      0.000      -0.033      -0.021
residualized_time_on_website:num_social_media_profiles     0.5612      0.015     36.362      0.000       0.531       0.591
residualized_time_on_website:yr_membership                 2.7749      0.015    181.695      0.000       2.745       2.805
==============================================================================
Omnibus:                    10071.318   Durbin-Watson:                   2.003
Prob(Omnibus):                  0.000   Jarque-Bera (JB):            75935.442
Skew:                           0.165   Prob(JB):                         0.00
Kurtosis:                       7.256   Cond. No.                         257.
==============================================================================

Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.

And then evaluating performance:

Code
# Predict CATE of 1 hour increase
linear_dml_cates_nonlinear = DML_model_nonlinear.predict(
    df_nonlinear.assign(
        residualized_time_on_website=lambda x: x.residualized_time_on_website + 1
    )
) - DML_model_nonlinear.predict(df_nonlinear)

# True CATE of 1 hour increase
true_cates_nonlinear = y_fn_nonlinear(time_on_website + 1, *X) - y_fn_nonlinear(
    time_on_website, *X
)

print(
    f"Mean Squared Error: {mean_squared_error(true_cates_nonlinear, linear_dml_cates_nonlinear)}"
)
print(
    f"Mean Absolute Error: {mean_absolute_error(true_cates_nonlinear, linear_dml_cates_nonlinear)}"
)
print(f"R-Squared: {r2_score(true_cates_nonlinear, linear_dml_cates_nonlinear)}")
Mean Squared Error: 53.704983242765145
Mean Absolute Error: 4.314548896807403
R-Squared: 0.6287019468647995
Code
hist_effect(
    true_cates_nonlinear,
    linear_dml_cates_nonlinear,
    save_path="data/linear_dml_nonlinear_hist.webp",
)

Code
plot_effect(
    true_cates_nonlinear,
    linear_dml_cates_nonlinear,
    save_path="data/linear_dml_nonlinear_line.webp",
)

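Here we see a substantial degradation in performance: the R-squared falls from roughly 0.98 to 0.63, and the mean squared error rises from about 1.5 to about 53.7. This non-linearity in the CATE function is precisely where Non-Parametric DML can shine!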

Non-Parametric DML for Estimating the CATE

Non-Parametric DML goes one step further and allows another flexible non-parametric ML model to be used for learning the CATE function! Let’s take a look at how we can, mathematically, do exactly this. Let \(\tau(\mathbf{X})\) continue to denote our CATE function. Let’s start by defining our error term relative to eq. (5), with the constant \(\beta_1\) replaced by \(\tau(\mathbf{X})\) (note that we drop the intercept \(\beta_0\), as this parameter is partialled out in the residualization step; we could similarly drop it in the linear DML formulation, but for the sake of simplicity and consistency with part 1, we do not):

\[ \begin{align*} y-\mathcal{M}_y(\mathbf{X})&=\tau(\mathbf{X})(T-\mathcal{M}_T(\mathbf{X}))+\epsilon \\ \tilde{y} &=\tau(\mathbf{X})\tilde{T}+\epsilon \\ \epsilon&=\tilde{y}-\tau(\mathbf{X})\tilde{T} \end{align*} \]

Then define the causal loss function as such (note this is just the MSE!):

\[ \begin{align*} \mathscr{L}(\tau(\mathbf{X})) &= \frac{1}{N}\sum_{i=1}^N\bigl(\tilde{y}_i - \tau(\mathbf{X}_i)\tilde{T}_i\bigr)^2 \\ &= \frac{1}{N}\sum_{i=1}^N\tilde{T}_i^2\bigl(\frac{\tilde{y}_i}{\tilde{T}_i} - \tau(\mathbf{X}_i)\bigr)^2 \end{align*} \]

What does this mean? We can directly learn \(\tau(\mathbf{X})\) with any flexible ML model via minimizing our causal loss function! This amounts to a weighted regression problem with our target and weights, respectively, as:

\[ \begin{align*} \text{Target}&=\frac{\tilde{y}_i}{\tilde{T}_i} \\ \text{Weights}&=\tilde{T}_i^2 \\ \end{align*} \]
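
The equivalence between the two forms of the loss (which holds whenever \(\tilde{T}_i \neq 0\)) can be checked numerically. The sketch below uses hypothetical simulated residuals with a constant effect of 3; in that special case, minimizing the weighted loss gives the weighted mean of the target, which coincides with the no-intercept OLS slope of \(\tilde{y}\) on \(\tilde{T}\). The function names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 1_000

# Hypothetical residuals with a constant true effect of 3.0
T_res = rng.normal(size=n)
y_res = 3.0 * T_res + rng.normal(scale=0.5, size=n)


def causal_loss(tau):
    """MSE form: mean of (y_res - tau * T_res)^2."""
    return np.mean((y_res - tau * T_res) ** 2)


def weighted_loss(tau):
    """Weighted form: mean of T_res^2 * (y_res / T_res - tau)^2."""
    target = y_res / T_res
    weights = T_res**2
    return np.mean(weights * (target - tau) ** 2)


# For a constant tau, the weighted-loss minimizer is the weighted mean of the target...
tau_weighted = np.average(y_res / T_res, weights=T_res**2)
# ...which equals the no-intercept OLS slope of y_res on T_res.
tau_ols = np.sum(T_res * y_res) / np.sum(T_res**2)

print(causal_loss(2.0), weighted_loss(2.0))
print(tau_weighted, tau_ols)
```

Both pairs of printed numbers agree up to floating-point error. Targets with a tiny \(\tilde{T}_i\) can be very large, but they receive correspondingly tiny weights, so they do not dominate the fit.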

Take a moment and soak in the elegance of this result… We can directly learn the CATE function and predict an individual’s CATE given our residualized outcome, \(\tilde{y}\), and residualized treatment, \(\tilde{T}\)!

Let’s take a look at this in action now. We will reuse the DGP for the non-linear CATE function from the example above where linear DML performs poorly. To construct our Non-Parametric DML model, we can run:

Code
# Define Target & Weights
df_nonlinear["target"] = (
    df_nonlinear["residualized_sales"]
    / df_nonlinear["residualized_time_on_website"]
)
df_nonlinear["weights"] = df_nonlinear["residualized_time_on_website"] ** 2

# Non-Parametric CATE Model
CATE_model = GradientBoostingRegressor()
CATE_model.fit(
    df_nonlinear[["age", "num_social_media_profiles", "yr_membership"]],
    df_nonlinear["target"],
    sample_weight=df_nonlinear["weights"],
)

And to make predictions + evaluate performance:

Code
# Predict CATE of 1 hour increase
nonparam_dml_cates_nonlinear = CATE_model.predict(
    df_nonlinear[["age", "num_social_media_profiles", "yr_membership"]]
)

print(
    f"Mean Squared Error: {mean_squared_error(true_cates_nonlinear, nonparam_dml_cates_nonlinear)}"
)
print(
    f"Mean Absolute Error: {mean_absolute_error(true_cates_nonlinear, nonparam_dml_cates_nonlinear)}"
)
print(f"R-Squared: {r2_score(true_cates_nonlinear, nonparam_dml_cates_nonlinear)}")
Mean Squared Error: 5.480215235810644
Mean Absolute Error: 1.6626763928716335
R-Squared: 0.9621116491439838
Code
hist_effect(
    true_cates_nonlinear,
    nonparam_dml_cates_nonlinear,
    save_path="data/nonparam_dml_nonlinear_hist.webp",
)

Code
plot_effect(
    true_cates_nonlinear,
    nonparam_dml_cates_nonlinear,
    save_path="data/nonparam_dml_nonlinear_line.webp",
)

Here we can see that, although not perfect, the non-parametric DML approach modeled the non-linearities in the CATE function much better than the linear DML approach (an R-squared of about 0.96 versus 0.63). We can of course further improve performance by tuning our model. Note that we can also use explainable AI tools, such as SHAP values, to understand the nature of our treatment effect heterogeneity!
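
The weighted-regression recipe is not tied to gradient boosting. As a rough sketch under stated assumptions, the example below swaps in about the crudest non-parametric learner available, a weighted mean within bins of a single covariate, and compares it with a constant-effect fit. The data, the sine-shaped CATE, and the assumption that the residuals are already available from a first stage are all hypothetical.

```python
import numpy as np

rng = np.random.default_rng(2)
n = 20_000

# Hypothetical setup: residuals are assumed to come from an earlier first stage.
x = rng.uniform(0, 1, size=n)
tau_true = np.sin(2 * np.pi * x)  # non-linear CATE function
T_res = rng.normal(size=n)
y_res = tau_true * T_res + rng.normal(size=n)

# Assign each observation to one of 20 equal-width bins of x.
n_bins = 20
bin_idx = np.minimum((x * n_bins).astype(int), n_bins - 1)

# Weighted mean of the target y_res / T_res with weights T_res^2 in each bin,
# which simplifies to sum(T_res * y_res) / sum(T_res^2).
numerator = np.bincount(bin_idx, weights=T_res * y_res, minlength=n_bins)
denominator = np.bincount(bin_idx, weights=T_res**2, minlength=n_bins)
tau_by_bin = numerator / denominator

tau_hat = tau_by_bin[bin_idx]
mae_binned = np.mean(np.abs(tau_hat - tau_true))

# Baseline: a single constant effect for everyone.
tau_constant = np.sum(T_res * y_res) / np.sum(T_res**2)
mae_constant = np.mean(np.abs(tau_constant - tau_true))

print(f"MAE, binned weighted mean: {mae_binned:.3f}")
print(f"MAE, constant effect:      {mae_constant:.3f}")
```

Even this simple learner tracks the non-linear CATE far better than a constant effect, which suggests the gain comes from the target and weights rather than from any one model class. A tuned tree ensemble or similar model would normally be the more practical choice with several covariates.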

Conclusion

And there you have it! Thank you for taking the time to read through my article. I hope it has taught you how to go beyond estimating only the ATE and use DML to estimate the CATE, both to further understand heterogeneity in treatment effects and to drive more causal-inference- and data-driven targeting schemes.

As always, I hope you have enjoyed reading this as much as I enjoyed writing it!


❖❖❖

Access all the code via this GitHub Repo.

I appreciate you reading my post! My posts primarily explore real-world and theoretical applications of econometric and statistical/machine learning techniques, but also whatever I am currently interested in or learning 😁. At the end of the day, I write to learn! I hope to make complex topics slightly more accessible to all.