econml.orf.DMLOrthoForest
- class econml.orf.DMLOrthoForest(*, n_trees=500, min_leaf_size=10, max_depth=10, subsample_ratio=0.7, bootstrap=False, lambda_reg=0.01, model_T='auto', model_Y=<econml.sklearn_extensions.linear_model.WeightedLassoCVWrapper object>, model_T_final=None, model_Y_final=None, global_residualization=False, global_res_cv=2, discrete_treatment=False, treatment_featurizer=None, categories='auto', n_jobs=-1, backend='loky', verbose=3, batch_size='auto', random_state=None, allow_missing=False)
Bases: econml.orf._ortho_forest.BaseOrthoForest
OrthoForest for continuous or discrete treatments using the DML residual-on-residual moment function.
A two-forest approach for learning heterogeneous treatment effects using kernel two-stage estimation.
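The residual-on-residual moment can be illustrated for a single continuous treatment and a single target point. This is a simplified sketch, not the estimator's implementation: it assumes the first-stage predictions of E[Y | X, W] and E[T | X, W] are already available and that the forest has already produced per-sample similarity weights for the target point. It then solves the weighted moment sum_i w_i (y_res_i - theta * t_res_i) * t_res_i = 0 for theta. The helper name `local_dml_theta` is hypothetical.

```python
from typing import Sequence


def local_dml_theta(y: Sequence[float], t: Sequence[float],
                    y_hat: Sequence[float], t_hat: Sequence[float],
                    weights: Sequence[float]) -> float:
    """Hypothetical helper: solve sum_i w_i * (y_res_i - theta * t_res_i) * t_res_i = 0."""
    # First stage: residualize outcome and treatment with the given nuisance predictions.
    y_res = [yi - yh for yi, yh in zip(y, y_hat)]
    t_res = [ti - th for ti, th in zip(t, t_hat)]
    # Second stage: weighted least squares of y_res on t_res (no intercept).
    num = sum(w * yr * tr for w, yr, tr in zip(weights, y_res, t_res))
    den = sum(w * tr * tr for w, tr in zip(weights, t_res))
    if den == 0:
        raise ValueError("treatment residuals have zero weighted variance")
    return num / den


# Toy data in which the outcome residual is exactly 2 times the treatment residual.
t = [1.0, 3.0, 0.0, 4.0]
t_hat = [0.0, 1.0, 2.0, 2.0]   # assumed first-stage predictions of T
y_hat = [5.0, 5.0, 5.0, 5.0]   # assumed first-stage predictions of Y
y = [yh + 2.0 * (ti - th) for yh, ti, th in zip(y_hat, t, t_hat)]
w = [0.4, 0.3, 0.2, 0.1]       # assumed forest similarity weights for one target x

print(round(local_dml_theta(y, t, y_hat, t_hat, w), 6))  # 2.0
```

Because the moment is formed on residuals, errors in the nuisance predictions enter only to second order, which is the motivation for the orthogonal (DML) construction.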
- Parameters
n_trees (int, default 500) – Number of causal estimators in the forest.
min_leaf_size (int, default 10) – The minimum number of samples in a leaf.
max_depth (int, default 10) – The maximum number of splits to be performed when expanding the tree.
subsample_ratio (float, default 0.7) – The ratio of the total sample to be used when training a causal tree. Values greater than 1.0 will be considered equal to 1.0. Parameter is ignored when bootstrap=True.
bootstrap (bool, default False) – Whether to use bootstrap subsampling.
lambda_reg (float, default 0.01) – The regularization coefficient in the \(\ell_2\) penalty imposed on the locally linear part of the second stage fit. This is not applied to the local intercept, only to the coefficient of the linear component.
model_T (estimator, default sklearn.linear_model.LassoCV(cv=3)) – The estimator for residualizing the continuous treatment at each leaf. Must implement fit and predict methods.
model_Y (estimator, default sklearn.linear_model.LassoCV(cv=3)) – The estimator for residualizing the outcome at each leaf. Must implement fit and predict methods.
model_T_final (estimator, optional) – The estimator for residualizing the treatment at prediction time. Must implement fit and predict methods. If this parameter is set to None, it defaults to the value of the model_T parameter.
model_Y_final (estimator, optional) – The estimator for residualizing the outcome at prediction time. Must implement fit and predict methods. If this parameter is set to None, it defaults to the value of the model_Y parameter.
global_residualization (bool, default False) – Whether to perform a prior residualization of Y and T using the model_Y_final and model_T_final estimators, or to perform locally weighted residualization at each target point. Global residualization is computationally less intensive, but could lose some statistical power, especially when W is not None.
global_res_cv (int, cross-validation generator or an iterable, default 2) – The specification of the CV splitter to be used for cross-fitting, when constructing the global residuals of Y and T.
discrete_treatment (bool, default False) – Whether the treatment should be treated as categorical. If True, then the treatment T is one-hot-encoded and the model_T is treated as a classifier that must have a predict_proba method.
treatment_featurizer (transformer, optional) – Must support fit_transform and transform. Used to create composite treatment in the final CATE regression. The final CATE will be trained on the outcome of featurizer.fit_transform(T). If treatment_featurizer=None, then CATE is trained on T.
categories (array_like or ‘auto’, default ‘auto’) – A list of pre-specified treatment categories. If ‘auto’ then categories are automatically recognized at fit time.
n_jobs (int, default -1) – The number of jobs to run in parallel for both fit() and effect(). -1 means using all processors. Since OrthoForest methods are computationally heavy, it is recommended to set n_jobs to -1.
backend (‘threading’ or ‘loky’, default ‘loky’) – What backend should be used for parallelization with the joblib library.
verbose (int, default 3) – Verbosity level.
batch_size (int or ‘auto’, default ‘auto’) – Batch size of jobs for parallelism.
random_state (int, RandomState instance, or None, default None) – If int, random_state is the seed used by the random number generator; if a RandomState instance, random_state is the random number generator; if None, the random number generator is the RandomState instance used by np.random.
allow_missing (bool, default False) – Whether to allow missing values in W. If True, the supplied nuisance models must be able to handle missing values.
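Two of the parameters above lend themselves to a small illustration. As the lambda_reg description states, the \(\ell_2\) penalty shrinks only the slope of a locally linear fit and never the local intercept. The first helper below minimizes sum_i w_i (r_i - a - b z_i)^2 + lam * b^2 for a generic response r and a single regressor z; centering at the weighted means leaves the intercept unpenalized. The second helper shows one plausible way subsample_ratio (capped at 1.0) and bootstrap could translate into per-tree sample indices. Both helper names are hypothetical and the library's actual solver and sampler may differ; this is a sketch under those assumptions.

```python
import random
from typing import List, Sequence, Tuple


def penalized_local_fit(z: Sequence[float], r: Sequence[float],
                        w: Sequence[float], lam: float) -> Tuple[float, float]:
    """Hypothetical helper: minimize sum_i w_i*(r_i - a - b*z_i)**2 + lam*b**2.

    Only the slope b is penalized; the intercept a is left free.
    """
    sw = sum(w)
    z_bar = sum(wi * zi for wi, zi in zip(w, z)) / sw   # weighted mean of z
    r_bar = sum(wi * ri for wi, ri in zip(w, r)) / sw   # weighted mean of r
    sxy = sum(wi * (zi - z_bar) * (ri - r_bar) for wi, zi, ri in zip(w, z, r))
    sxx = sum(wi * (zi - z_bar) ** 2 for wi, zi in zip(w, z))
    b = sxy / (sxx + lam)    # ridge shrinkage acts on the slope only
    a = r_bar - b * z_bar    # intercept is refit around the shrunken slope
    return a, b


def tree_sample_indices(n: int, subsample_ratio: float, bootstrap: bool,
                        rng: random.Random) -> List[int]:
    """Hypothetical helper: draw the training indices for one tree."""
    if bootstrap:
        # Bootstrap: n draws with replacement; subsample_ratio is ignored.
        return [rng.randrange(n) for _ in range(n)]
    ratio = min(subsample_ratio, 1.0)               # values above 1.0 act like 1.0
    return rng.sample(range(n), int(ratio * n))     # draws without replacement


z = [0.0, 1.0, 2.0, 3.0]
r = [1.0 + 3.0 * zi for zi in z]    # exact line r = 1 + 3 z
w = [1.0, 1.0, 1.0, 1.0]
print(penalized_local_fit(z, r, w, lam=0.0))   # (1.0, 3.0)
print(penalized_local_fit(z, r, w, lam=5.0))   # (3.25, 1.5)
```

With lam=5.0 the slope is halved while the intercept moves to keep the fitted line through the weighted mean point; it is not pulled towards zero.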
- __init__(*, n_trees=500, min_leaf_size=10, max_depth=10, subsample_ratio=0.7, bootstrap=False, lambda_reg=0.01, model_T='auto', model_Y=<econml.sklearn_extensions.linear_model.WeightedLassoCVWrapper object>, model_T_final=None, model_Y_final=None, global_residualization=False, global_res_cv=2, discrete_treatment=False, treatment_featurizer=None, categories='auto', n_jobs=-1, backend='loky', verbose=3, batch_size='auto', random_state=None, allow_missing=False)
Methods
__init__(*[, n_trees, min_leaf_size, ...])
ate([X, T0, T1]) – Calculate the average treatment effect \(E_X[\tau(X, T0, T1)]\).
ate_inference([X, T0, T1]) – Inference results for the quantity \(E_X[\tau(X, T0, T1)]\) produced by the model.
ate_interval([X, T0, T1, alpha]) – Confidence intervals for the quantity \(E_X[\tau(X, T0, T1)]\) produced by the model.
cate_feature_names([feature_names]) – Public interface for getting feature names.
cate_output_names([output_names]) – Public interface for getting output names.
cate_treatment_names([treatment_names]) – Get treatment names.
const_marginal_ate([X]) – Calculate the average constant marginal CATE \(E_X[\theta(X)]\).
const_marginal_ate_inference([X]) – Inference results for the quantities \(E_X[\theta(X)]\) produced by the model.
const_marginal_ate_interval([X, alpha]) – Confidence intervals for the quantities \(E_X[\theta(X)]\) produced by the model.
const_marginal_effect(X) – Calculate the constant marginal CATE θ(·) conditional on a vector of features X.
const_marginal_effect_inference([X]) – Inference results for the quantities \(\theta(X)\) produced by the model.
const_marginal_effect_interval([X, alpha]) – Confidence intervals for the quantities \(\theta(X)\) produced by the model.
effect([X, T0, T1]) – Calculate the heterogeneous treatment effect \(\tau(X, T0, T1)\).
effect_inference([X, T0, T1]) – Inference results for the quantities \(\tau(X, T0, T1)\) produced by the model.
effect_interval([X, T0, T1, alpha]) – Confidence intervals for the quantities \(\tau(X, T0, T1)\) produced by the model.
fit(Y, T, *, X[, W, inference]) – Build an orthogonal random forest from a training set (Y, T, X, W).
marginal_ate(T[, X]) – Calculate the average marginal effect \(E_{T, X}[\partial\tau(T, X)]\).
marginal_ate_inference(T[, X]) – Inference results for the quantities \(E_{T,X}[\partial \tau(T, X)]\) produced by the model.
marginal_ate_interval(T[, X, alpha]) – Confidence intervals for the quantities \(E_{T,X}[\partial \tau(T, X)]\) produced by the model.
marginal_effect(T[, X]) – Calculate the heterogeneous marginal effect \(\partial\tau(T, X)\).
marginal_effect_inference(T[, X]) – Inference results for the quantities \(\partial \tau(T, X)\) produced by the model.
marginal_effect_interval(T[, X, alpha]) – Confidence intervals for the quantities \(\partial \tau(T, X)\) produced by the model.
shap_values(X, *[, feature_names, ...]) – SHAP values for the final stage models (const_marginal_effect).
Attributes
dowhy – Get an instance of DoWhyWrapper to allow other functionalities from the dowhy package.
transformer
- ate(X=None, *, T0=0, T1=1)
Calculate the average treatment effect \(E_X[\tau(X, T0, T1)]\).
The effect is calculated between the two treatment points and is averaged over the population of X variables.
- Parameters
T0 ((m, d_t) matrix or vector of length m, default 0) – Base treatments for each sample
T1 ((m, d_t) matrix or vector of length m, default 1) – Target treatments for each sample
X ((m, d_x) matrix, optional) – Features for each sample
- Returns
τ – Average treatment effects on each outcome. Note that when Y is a vector rather than a 2-dimensional array, the result will be a scalar.
- Return type
float or (d_y,) array
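For a CATE that is linear in an unfeaturized scalar treatment, τ(X, T0, T1) = θ(X)(T1 − T0), so ate reduces to the sample mean of the per-row effects over the supplied X. The sketch below assumes a hypothetical θ(x) = 1 + x for one feature, one outcome and one treatment; it illustrates the averaging only, not the estimator's internals.

```python
from statistics import fmean
from typing import Callable, Sequence


def average_effect(theta: Callable[[float], float], X: Sequence[float],
                   T0: float = 0.0, T1: float = 1.0) -> float:
    """Mean over the rows of X of tau(x, T0, T1) = theta(x) * (T1 - T0)."""
    return fmean(theta(x) * (T1 - T0) for x in X)


def theta(x: float) -> float:
    return 1.0 + x    # hypothetical CATE for a single feature


X = [0.0, 1.0, 2.0, 3.0]
print(average_effect(theta, X))               # 2.5 (defaults T0=0, T1=1)
print(average_effect(theta, X, T0=1, T1=3))   # 5.0
```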
- ate_inference(X=None, *, T0=0, T1=1)
Inference results for the quantity \(E_X[\tau(X, T0, T1)]\) produced by the model. Available only when the fit method was called with inference not None.
- Parameters
X ((m, d_x) matrix, optional) – Features for each sample
T0 ((m, d_t) matrix or vector of length m, default 0) – Base treatments for each sample
T1 ((m, d_t) matrix or vector of length m, default 1) – Target treatments for each sample
- Returns
PopulationSummaryResults – The inference results instance contains prediction and prediction standard error and can on demand calculate confidence interval, z statistic and p value. It can also output a dataframe summary of these inference results.
- Return type
PopulationSummaryResults
- ate_interval(X=None, *, T0=0, T1=1, alpha=0.05)
Confidence intervals for the quantity \(E_X[\tau(X, T0, T1)]\) produced by the model. Available only when the fit method was called with inference not None.
- Parameters
X ((m, d_x) matrix, optional) – Features for each sample
T0 ((m, d_t) matrix or vector of length m, default 0) – Base treatments for each sample
T1 ((m, d_t) matrix or vector of length m, default 1) – Target treatments for each sample
alpha (float in [0, 1], default 0.05) – The significance level of the reported interval: the alpha/2, 1-alpha/2 interval, i.e. a 100(1-alpha)% confidence interval, is reported.
- Returns
lower, upper – The lower and the upper bounds of the confidence interval for each quantity.
- Return type
tuple(type of ate(X, T0, T1), type of ate(X, T0, T1))
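The inference objects described above carry a point estimate and a standard error, from which a z statistic, a p-value and an interval can be derived. The sketch below assumes a normal approximation purely for illustration (the construction actually used depends on the inference method chosen at fit time): the alpha/2, 1-alpha/2 interval is point ± z_{1-alpha/2} · stderr. The helper name `normal_summary` is hypothetical.

```python
from statistics import NormalDist
from typing import Dict


def normal_summary(point: float, stderr: float, alpha: float = 0.05) -> Dict[str, float]:
    """Hypothetical helper: z statistic, two-sided p-value and (1 - alpha) interval."""
    std_normal = NormalDist()                    # mean 0, standard deviation 1
    z = point / stderr
    p_value = 2 * (1 - std_normal.cdf(abs(z)))   # two-sided test of H0: effect = 0
    crit = std_normal.inv_cdf(1 - alpha / 2)     # about 1.96 when alpha = 0.05
    return {"z": z, "p_value": p_value,
            "lower": point - crit * stderr, "upper": point + crit * stderr}


summary = normal_summary(point=2.0, stderr=0.5)
print(summary["z"], summary["lower"], summary["upper"])
```

Lowering alpha widens the interval, since the critical value z_{1-alpha/2} grows.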
- cate_feature_names(feature_names=None)
Public interface for getting feature names.
To be overridden by estimators that apply transformations to the input features.
- Parameters
feature_names (list of str of length X.shape[1] or None) – The names of the input features. If None and X is a dataframe, it defaults to the column names from the dataframe.
- Returns
out_feature_names – Returns feature names.
- Return type
list of str or None
- cate_output_names(output_names=None)
Public interface for getting output names.
To be overridden by estimators that apply transformations to the outputs.
- Parameters
output_names (list of str of length Y.shape[1] or None) – The names of the outcomes. If None and the Y passed to fit was a dataframe, it defaults to the column names from the dataframe.
- Returns
output_names – Returns output names.
- Return type
list of str
- cate_treatment_names(treatment_names=None)
Get treatment names.
If the treatment is discrete or featurized, it will return expanded treatment names.
- Parameters
treatment_names (list of str of length T.shape[1], optional) – The names of the treatments. If None and the T passed to fit was a dataframe, it defaults to the column names from the dataframe.
- Returns
out_treatment_names – Returns (possibly expanded) treatment names.
- Return type
list of str
- const_marginal_ate(X=None)
Calculate the average constant marginal CATE \(E_X[\theta(X)]\).
- Parameters
X ((m, d_x) matrix, optional) – Features for each sample.
- Returns
theta – Average constant marginal CATE of each treatment on each outcome. Note that when Y or featurized-T (or T if treatment_featurizer is None) is a vector rather than a 2-dimensional array, the corresponding singleton dimensions in the output will be collapsed (e.g. if both are vectors, then the output of this method will also be a scalar)
- Return type
(d_y, d_f_t) matrix where d_f_t is the dimension of the featurized treatment. If treatment_featurizer is None, d_f_t = d_t.
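const_marginal_ate averages the per-sample constant marginal CATE matrices θ(X_i), each of shape (d_y, d_f_t), over the rows of X; the note above then collapses singleton dimensions when Y or the (featurized) treatment was passed as a vector. The sketch below represents hypothetical θ(X_i) values as nested lists (one outcome, two featurized treatments); both helpers are illustrative assumptions rather than library code.

```python
from typing import List, Union

Matrix = List[List[float]]


def const_marginal_ate_sketch(thetas: List[Matrix]) -> Matrix:
    """Element-wise mean of per-sample (d_y, d_f_t) matrices theta(X_i)."""
    m = len(thetas)
    d_y, d_f_t = len(thetas[0]), len(thetas[0][0])
    return [[sum(th[i][j] for th in thetas) / m for j in range(d_f_t)]
            for i in range(d_y)]


def collapse(mat: Matrix, y_is_vector: bool,
             t_is_vector: bool) -> Union[float, List[float], Matrix]:
    """Drop the singleton dimension(s) of a (d_y, d_f_t) result."""
    if y_is_vector and t_is_vector:
        return mat[0][0]                  # both singleton: scalar
    if y_is_vector:
        return mat[0]                     # d_y == 1: vector of length d_f_t
    if t_is_vector:
        return [row[0] for row in mat]    # d_f_t == 1: vector of length d_y
    return mat


# Hypothetical theta(X_i) for three samples with d_y = 1 and d_f_t = 2.
thetas = [[[1.0, 0.0]], [[2.0, 1.0]], [[3.0, 2.0]]]
avg = const_marginal_ate_sketch(thetas)
print(avg)                                                  # [[2.0, 1.0]]
print(collapse(avg, y_is_vector=True, t_is_vector=False))   # [2.0, 1.0]
```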
- const_marginal_ate_inference(X=None)
Inference results for the quantities \(E_X[\theta(X)]\) produced by the model. Available only when the fit method was called with inference not None.
- Parameters
X ((m, d_x) matrix, optional) – Features for each sample
- Returns
PopulationSummaryResults – The inference results instance contains prediction and prediction standard error and can on demand calculate confidence interval, z statistic and p value. It can also output a dataframe summary of these inference results.
- Return type
PopulationSummaryResults
- const_marginal_ate_interval(X=None, *, alpha=0.05)
Confidence intervals for the quantities \(E_X[\theta(X)]\) produced by the model. Available only when the fit method was called with inference not None.
- Parameters
X ((m, d_x) matrix, optional) – Features for each sample
alpha (float in [0, 1], default 0.05) – The significance level of the reported interval: the alpha/2, 1-alpha/2 interval, i.e. a 100(1-alpha)% confidence interval, is reported.
- Returns
lower, upper – The lower and the upper bounds of the confidence interval for each quantity.
- Return type
tuple(type of const_marginal_ate(X), type of const_marginal_ate(X))
- const_marginal_effect(X)
Calculate the constant marginal CATE θ(·) conditional on a vector of features X.
- Parameters
X (array_like, shape (n, d_x)) – Feature vector that captures heterogeneity.
- Returns
Theta – Constant marginal CATE of each treatment for each sample.
- Return type
matrix, shape (n, d_f_t), where d_f_t is the dimension of the featurized treatment. If treatment_featurizer is None, d_f_t = d_t.
- const_marginal_effect_inference(X=None)
Inference results for the quantities \(\theta(X)\) produced by the model. Available only when the fit method was called with inference not None.
- Parameters
X ((m, d_x) matrix, optional) – Features for each sample
- Returns
InferenceResults – The inference results instance contains prediction and prediction standard error and can on demand calculate confidence interval, z statistic and p value. It can also output a dataframe summary of these inference results.
- Return type
InferenceResults
- const_marginal_effect_interval(X=None, *, alpha=0.05)
Confidence intervals for the quantities \(\theta(X)\) produced by the model. Available only when the fit method was called with inference not None.
- Parameters
X ((m, d_x) matrix, optional) – Features for each sample
alpha (float in [0, 1], default 0.05) – The significance level of the reported interval: the alpha/2, 1-alpha/2 interval, i.e. a 100(1-alpha)% confidence interval, is reported.
- Returns
lower, upper – The lower and the upper bounds of the confidence interval for each quantity.
- Return type
tuple(type of const_marginal_effect(X), type of const_marginal_effect(X))
- effect(X=None, *, T0=0, T1=1)
Calculate the heterogeneous treatment effect \(\tau(X, T0, T1)\).
The effect is calculated between the two treatment points conditional on a vector of features on a set of m test samples \(\{T0_i, T1_i, X_i\}\).
- Parameters
T0 ((m, d_t) matrix or vector of length m, default 0) – Base treatments for each sample
T1 ((m, d_t) matrix or vector of length m, default 1) – Target treatments for each sample
X ((m, d_x) matrix, optional) – Features for each sample
- Returns
τ – Heterogeneous treatment effects on each outcome for each sample. Note that when Y is a vector rather than a 2-dimensional array, the corresponding singleton dimension will be collapsed (so this method will return a vector).
- Return type
(m, d_y) matrix
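When a treatment_featurizer φ is supplied, the final stage is trained on φ(T), so for a single outcome the effect between two treatment points can be written τ(X, T0, T1) = θ(X) · (φ(T1) − φ(T0)); without a featurizer φ is the identity and this reduces to θ(X)(T1 − T0). The sketch below assumes a hypothetical featurizer φ(t) = (t, t²) and hypothetical θ(X_i) values, and illustrates only that formula.

```python
from typing import Callable, List, Sequence


def featurize(t: float) -> List[float]:
    """Hypothetical treatment featurizer phi(t) = (t, t**2)."""
    return [t, t * t]


def effect_sketch(theta_rows: Sequence[Sequence[float]], T0: Sequence[float],
                  T1: Sequence[float],
                  phi: Callable[[float], List[float]] = featurize) -> List[float]:
    """tau_i = theta(X_i) . (phi(T1_i) - phi(T0_i)) for one outcome."""
    out = []
    for theta, t0, t1 in zip(theta_rows, T0, T1):
        d_phi = [b - a for a, b in zip(phi(t0), phi(t1))]
        out.append(sum(th * d for th, d in zip(theta, d_phi)))
    return out


# theta(X_i) has one entry per featurized treatment component (d_f_t = 2).
theta_rows = [[1.0, 0.5], [2.0, -1.0]]
print(effect_sketch(theta_rows, T0=[0.0, 1.0], T1=[2.0, 3.0]))   # [4.0, -4.0]
```

Both rows move the raw treatment by the same amount (+2), yet the effects differ in sign because the quadratic component contributes differently at each base point.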
- effect_inference(X=None, *, T0=0, T1=1)
Inference results for the quantities \(\tau(X, T0, T1)\) produced by the model. Available only when the fit method was called with inference not None.
- Parameters
X ((m, d_x) matrix, optional) – Features for each sample
T0 ((m, d_t) matrix or vector of length m, default 0) – Base treatments for each sample
T1 ((m, d_t) matrix or vector of length m, default 1) – Target treatments for each sample
- Returns
InferenceResults – The inference results instance contains prediction and prediction standard error and can on demand calculate confidence interval, z statistic and p value. It can also output a dataframe summary of these inference results.
- Return type
InferenceResults
- effect_interval(X=None, *, T0=0, T1=1, alpha=0.05)
Confidence intervals for the quantities \(\tau(X, T0, T1)\) produced by the model. Available only when the fit method was called with inference not None.
- Parameters
X ((m, d_x) matrix, optional) – Features for each sample
T0 ((m, d_t) matrix or vector of length m, default 0) – Base treatments for each sample
T1 ((m, d_t) matrix or vector of length m, default 1) – Target treatments for each sample
alpha (float in [0, 1], default 0.05) – The significance level of the reported interval: the alpha/2, 1-alpha/2 interval, i.e. a 100(1-alpha)% confidence interval, is reported.
- Returns
lower, upper – The lower and the upper bounds of the confidence interval for each quantity.
- Return type
tuple(type of effect(X, T0, T1), type of effect(X, T0, T1))
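With bootstrap-based inference, intervals come from estimates recomputed on resampled data. One common construction, the percentile interval, takes the alpha/2 and 1-alpha/2 empirical quantiles of the replicate estimates. The sketch below shows only that variant (other bootstrap interval types exist, and which one the estimator uses by default is not stated here); the replicate values and helper names are hypothetical.

```python
import random
from typing import Sequence, Tuple


def empirical_quantile(sorted_vals: Sequence[float], q: float) -> float:
    """Linear interpolation between order statistics, for q in [0, 1]."""
    pos = q * (len(sorted_vals) - 1)
    lo = int(pos)
    hi = min(lo + 1, len(sorted_vals) - 1)
    frac = pos - lo
    return sorted_vals[lo] + frac * (sorted_vals[hi] - sorted_vals[lo])


def percentile_interval(replicates: Sequence[float],
                        alpha: float = 0.05) -> Tuple[float, float]:
    """(alpha/2, 1 - alpha/2) empirical quantiles of bootstrap replicate estimates."""
    s = sorted(replicates)
    return empirical_quantile(s, alpha / 2), empirical_quantile(s, 1 - alpha / 2)


# Hypothetical replicate estimates of one effect, e.g. from 200 bootstrap fits.
rng = random.Random(42)
reps = [rng.gauss(2.0, 0.5) for _ in range(200)]
print(percentile_interval(reps))
```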
- fit(Y, T, *, X, W=None, inference='auto')
Build an orthogonal random forest from a training set (Y, T, X, W).
- Parameters
Y (array_like, shape (n, )) – Outcome for the treatment policy.
T (array_like, shape (n, d_t)) – Treatment policy.
X (array_like, shape (n, d_x)) – Feature vector that captures heterogeneity.
W (array_like, shape (n, d_w), optional) – High-dimensional controls.
inference (str, Inference instance, or None) – Method for performing inference. This estimator supports ‘bootstrap’ (or an instance of BootstrapInference) and ‘blb’ (or an instance of BLBInference).
- Returns
self
- Return type
an instance of self.
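fit expects Y of shape (n,), T of shape (n, d_t) (or a vector for a single treatment), X of shape (n, d_x) and optional W of shape (n, d_w). The sketch below generates a small synthetic dataset with these shapes from an assumed data-generating process, Y = θ(X)·T + W₀ + noise with θ(x) = 1 + x₀ and T confounded through W₀, so that a fitted estimator's effect could be compared against a known θ. The generator is hypothetical; only the argument names and shapes come from the signature above.

```python
import random
from typing import List, Tuple


def make_synthetic(n: int = 500, d_x: int = 2, d_w: int = 3, seed: int = 0
                   ) -> Tuple[List[float], List[float], List[List[float]], List[List[float]]]:
    """Hypothetical DGP: T depends on W (confounding); Y = (1 + X0) * T + W0 + noise."""
    rng = random.Random(seed)
    X = [[rng.uniform(0, 1) for _ in range(d_x)] for _ in range(n)]
    W = [[rng.gauss(0, 1) for _ in range(d_w)] for _ in range(n)]
    T = [0.5 * w[0] + rng.gauss(0, 1) for w in W]         # continuous treatment
    Y = [(1.0 + x[0]) * t + w[0] + rng.gauss(0, 0.1)      # true theta(x) = 1 + x0
         for x, w, t in zip(X, W, T)]
    return Y, T, X, W


Y, T, X, W = make_synthetic()
print(len(Y), len(T), len(X), len(X[0]), len(W[0]))   # 500 500 500 2 3
```

With econml and NumPy installed, these lists could be converted with numpy.asarray and passed as est.fit(Y, T, X=X, W=W) on a DMLOrthoForest instance; X must be given by keyword because of the * in the fit signature.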
- marginal_ate(T, X=None)
Calculate the average marginal effect \(E_{T, X}[\partial\tau(T, X)]\).
The marginal effect is calculated around a base treatment point and averaged over the population of X.
- Parameters
T ((m, d_t) matrix) – Base treatments for each sample
X ((m, d_x) matrix, optional) – Features for each sample
- Returns
grad_tau – Average marginal effects on each outcome. Note that when Y or T is a vector rather than a 2-dimensional array, the corresponding singleton dimensions in the output will be collapsed (e.g. if both are vectors, then the output of this method will be a scalar).
- Return type
(d_y, d_t) array
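With a treatment_featurizer φ, the marginal effect at a base treatment T_i is θ(X_i) · φ′(T_i), and marginal_ate averages it over the samples. Assuming a hypothetical featurizer φ(t) = (t, t²), so φ′(t) = (1, 2t), and hypothetical θ(X_i) values for one outcome, the average can be sketched as follows.

```python
from statistics import fmean
from typing import Sequence


def marginal_ate_sketch(theta_rows: Sequence[Sequence[float]],
                        T: Sequence[float]) -> float:
    """Mean over samples of theta(X_i) . phi'(T_i) with phi(t) = (t, t**2)."""
    # phi'(t) = (1, 2t), so each sample contributes theta_1 + 2 * theta_2 * t.
    return fmean(th[0] + 2.0 * th[1] * t for th, t in zip(theta_rows, T))


theta_rows = [[1.0, 0.5], [2.0, -1.0]]    # hypothetical theta(X_i), d_f_t = 2
print(marginal_ate_sketch(theta_rows, T=[0.0, 1.0]))   # 0.5
```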
- marginal_ate_inference(T, X=None)
Inference results for the quantities \(E_{T,X}[\partial \tau(T, X)]\) produced by the model. Available only when the fit method was called with inference not None.
- Parameters
T ((m, d_t) matrix) – Base treatments for each sample
X ((m, d_x) matrix, optional) – Features for each sample
- Returns
PopulationSummaryResults – The inference results instance contains prediction and prediction standard error and can on demand calculate confidence interval, z statistic and p value. It can also output a dataframe summary of these inference results.
- Return type
PopulationSummaryResults
- marginal_ate_interval(T, X=None, *, alpha=0.05)
Confidence intervals for the quantities \(E_{T,X}[\partial \tau(T, X)]\) produced by the model. Available only when the fit method was called with inference not None.
- Parameters
T ((m, d_t) matrix) – Base treatments for each sample
X ((m, d_x) matrix, optional) – Features for each sample
alpha (float in [0, 1], default 0.05) – The significance level of the reported interval: the alpha/2, 1-alpha/2 interval, i.e. a 100(1-alpha)% confidence interval, is reported.
- Returns
lower, upper – The lower and the upper bounds of the confidence interval for each quantity.
- Return type
tuple(type of marginal_ate(T, X), type of marginal_ate(T, X))
- marginal_effect(T, X=None)
Calculate the heterogeneous marginal effect \(\partial\tau(T, X)\).
The marginal effect is calculated around a base treatment point conditional on a vector of features on a set of m test samples \(\{T_i, X_i\}\). If treatment_featurizer is None, the base treatment is ignored in this calculation and the result is equivalent to const_marginal_effect.
- Parameters
T ((m, d_t) matrix) – Base treatments for each sample
X ((m, d_x) matrix, optional) – Features for each sample
- Returns
grad_tau – Heterogeneous marginal effects on each outcome for each sample. Note that when Y or T is a vector rather than a 2-dimensional array, the corresponding singleton dimensions in the output will be collapsed (e.g. if both are vectors, then the output of this method will also be a vector).
- Return type
(m, d_y, d_t) array
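The remark that the base treatment is ignored when treatment_featurizer is None can be checked numerically: differentiating τ with respect to the treatment gives θ(X) everywhere when φ is the identity, but θ(X) · φ′(T) when φ is nonlinear. The sketch below uses a central finite difference on one sample and one outcome, with hypothetical θ values and a hypothetical featurizer φ(t) = (t, t²); it is an illustration, not how the library computes marginal effects.

```python
from typing import Callable, List, Sequence

Featurizer = Callable[[float], List[float]]


def tau(theta: Sequence[float], phi: Featurizer, t0: float, t1: float) -> float:
    """Effect for one sample and one outcome: theta . (phi(t1) - phi(t0))."""
    return sum(th * (b - a) for th, a, b in zip(theta, phi(t0), phi(t1)))


def marginal_effect_fd(theta: Sequence[float], phi: Featurizer,
                       t: float, h: float = 1e-5) -> float:
    """Central finite difference of tau with respect to the treatment at base point t."""
    return (tau(theta, phi, t, t + h) - tau(theta, phi, t, t - h)) / (2 * h)


def identity(t: float) -> List[float]:
    return [t]            # no featurizer: phi(t) = t


def square(t: float) -> List[float]:
    return [t, t * t]     # hypothetical featurizer phi(t) = (t, t^2)


# Without a featurizer the result is theta(X) = 1.5 at any base treatment (up to rounding).
print(marginal_effect_fd([1.5], identity, t=0.0))
print(marginal_effect_fd([1.5], identity, t=10.0))
# With phi(t) = (t, t^2): theta_1 + 2 * theta_2 * t = 1 + 2 * 0.5 * 3 = 4 (up to rounding).
print(marginal_effect_fd([1.0, 0.5], square, t=3.0))
```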
- marginal_effect_inference(T, X=None)
Inference results for the quantities \(\partial \tau(T, X)\) produced by the model. Available only when the fit method was called with inference not None.
- Parameters
T ((m, d_t) matrix) – Base treatments for each sample
X ((m, d_x) matrix, optional) – Features for each sample
- Returns
InferenceResults – The inference results instance contains prediction and prediction standard error and can on demand calculate confidence interval, z statistic and p value. It can also output a dataframe summary of these inference results.
- Return type
InferenceResults
- marginal_effect_interval(T, X=None, *, alpha=0.05)
Confidence intervals for the quantities \(\partial \tau(T, X)\) produced by the model. Available only when the fit method was called with inference not None.
- Parameters
T ((m, d_t) matrix) – Base treatments for each sample
X ((m, d_x) matrix, optional) – Features for each sample
alpha (float in [0, 1], default 0.05) – The significance level of the reported interval: the alpha/2, 1-alpha/2 interval, i.e. a 100(1-alpha)% confidence interval, is reported.
- Returns
lower, upper – The lower and the upper bounds of the confidence interval for each quantity.
- Return type
tuple(type of marginal_effect(T, X), type of marginal_effect(T, X))
- shap_values(X, *, feature_names=None, treatment_names=None, output_names=None, background_samples=100)
SHAP values for the final stage models (const_marginal_effect).
- Parameters
X ((m, d_x) matrix) – Features for each sample. Should have the same shape as the X used to fit the final stage.
feature_names (list of str of length X.shape[1], optional) – The names of input features.
treatment_names (list, optional) – The names of the featurized treatments. In the discrete treatment scenario, the names should not include the name of the baseline treatment (i.e. the control treatment, which by default is the alphabetically smallest).
output_names (list, optional) – The name of the outcome.
background_samples (int, default 100) – How many samples to use to compute the baseline effect. If None, then all samples are used.
- Returns
shap_outs – A nested dictionary that uses each output name (e.g. ‘Y0’, ‘Y1’, … when output_names=None) and each treatment name (e.g. ‘T0’, ‘T1’, … when treatment_names=None) as keys and the shap_values explanation object as value. If the input data at fit time also contain metadata (e.g. are pandas DataFrames), then the column metadata for the treatments, outcomes and features are used instead of the above defaults (unless the user overrides them by explicitly passing the corresponding names).
- Return type
nested dictionary of Explanation object
- property dowhy
Get an instance of DoWhyWrapper to allow other functionalities from the dowhy package (e.g. causal graph, refutation test, etc.).
- Returns
DoWhyWrapper – An instance of DoWhyWrapper
- Return type
instance