econml.policy.DRPolicyTree
- class econml.policy.DRPolicyTree(*, model_regression='auto', model_propensity='auto', featurizer=None, min_propensity=1e-06, categories='auto', cv=2, mc_iters=None, mc_agg='mean', max_depth=None, min_samples_split=10, min_samples_leaf=5, min_weight_fraction_leaf=0.0, max_features='auto', min_impurity_decrease=0.0, min_balancedness_tol=0.45, honest=True, random_state=None)
Bases: econml.policy._drlearner._BaseDRPolicyLearner
Policy learner that uses doubly-robust correction techniques to account for covariate shift (selection bias) between the treatment arms.
In this estimator, the policy is estimated by first constructing doubly robust estimates of the counterfactual outcomes
\[Y_{i, t}^{DR} = E[Y | X_i, W_i, T_i=t] + \frac{Y_i - E[Y | X_i, W_i, T_i=t]}{Pr[T_i=t | X_i, W_i]} \cdot 1\{T_i=t\}\]
and then optimizing the objective
\[V(\pi) = \sum_i \sum_t \pi_t(X_i) \cdot (Y_{i, t}^{DR} - Y_{i, 0}^{DR})\]
with the constraint that, for each \(X_i\), exactly one of the \(\pi_t(X_i)\) is 1 and the rest are 0.
Thus, if we estimate the nuisance functions \(h(X, W, T) = E[Y | X, W, T]\) and \(p_t(X, W)=Pr[T=t | X, W]\) in the first stage, we can estimate the final-stage policy by constructing a decision tree that maximizes the objective \(V(\pi)\).
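As a minimal NumPy sketch of the two formulas above (not EconML's internal code), the helpers below, whose names are hypothetical, take first-stage predictions mu[i, t] of E[Y | X_i, W_i, T_i=t] and propensity[i, t] of Pr[T_i=t | X_i, W_i] as given arrays, build the doubly robust outcomes, and evaluate V(π) for a deterministic policy given as one treatment index per sample. The clipping step mirrors the min_propensity parameter.

```python
import numpy as np

def dr_pseudo_outcomes(y, t, mu, propensity, min_propensity=1e-6):
    """Doubly robust estimates Y_{i,t}^{DR} for every sample i and treatment t.

    y:          (n,) observed outcomes
    t:          (n,) observed treatment indices in {0, ..., k-1}, 0 = baseline
    mu:         (n, k) first-stage estimates of E[Y | X_i, W_i, T_i = t]
    propensity: (n, k) first-stage estimates of Pr[T_i = t | X_i, W_i]
    """
    y = np.asarray(y, dtype=float)
    mu = np.asarray(mu, dtype=float)
    n, k = mu.shape
    # Clip propensities from below to avoid dividing by (near) zero.
    p = np.clip(np.asarray(propensity, dtype=float), min_propensity, None)
    # Indicator 1{T_i = t} as an (n, k) one-hot matrix.
    onehot = np.zeros((n, k))
    onehot[np.arange(n), np.asarray(t)] = 1.0
    # The residual correction only enters the column of the observed treatment.
    return mu + (y[:, None] - mu) / p * onehot

def policy_value(y_dr, policy):
    """V(pi) = sum_i sum_t pi_t(X_i) * (Y_{i,t}^DR - Y_{i,0}^DR), where the
    policy assigns exactly one treatment index to each sample."""
    n = y_dr.shape[0]
    gains = y_dr - y_dr[:, [0]]  # gain of each arm relative to the baseline
    return gains[np.arange(n), np.asarray(policy)].sum()
```

Choosing the argmax of each row of gains would maximize V(π) sample by sample; the policy tree instead restricts π to be constant on each leaf of a tree over X.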
The problem of estimating the nuisance function \(p\) is a simple multi-class classification problem of predicting the label \(T\) from \(X, W\). This class takes as input the parameter model_propensity, which is an arbitrary scikit-learn classifier that is internally used to solve this classification problem. The second nuisance function \(h\) is a simple regression problem, and this class takes as input the parameter model_regression, which is an arbitrary scikit-learn regressor that is internally used to solve this regression problem.
- Parameters
model_propensity (scikit-learn classifier or ‘auto’, default ‘auto’) – Estimator for Pr[T=t | X, W]. Trained by regressing treatments on (features, controls) concatenated. Must implement fit and predict_proba methods. The fit method must be able to accept X and T, where T is a shape (n,) array. If ‘auto’, LogisticRegressionCV will be chosen.
model_regression (scikit-learn regressor or ‘auto’, default ‘auto’) – Estimator for E[Y | X, W, T]. Trained by regressing Y on (features, controls, one-hot-encoded treatments) concatenated. The one-hot-encoding excludes the baseline treatment. Must implement fit and predict methods. If different models per treatment arm are desired, see the MultiModelWrapper helper class. If ‘auto’, WeightedLassoCV/WeightedMultiTaskLassoCV will be chosen.
featurizer (transformer, optional) – Must support fit_transform and transform. Used to create composite features in the final CATE regression. It is ignored if X is None. The final CATE will be trained on the outcome of featurizer.fit_transform(X). If featurizer=None, then CATE is trained on X.
min_propensity (float, default 1e-6) – The minimum propensity at which to clip propensity estimates to avoid dividing by zero.
categories (‘auto’ or list, default ‘auto’) – The categories to use when encoding discrete treatments (or ‘auto’ to use the unique sorted values). The first category will be treated as the control treatment.
cv (int, cross-validation generator or an iterable, default 2) – Determines the cross-validation splitting strategy. Possible inputs for cv are:
None, to use the default 3-fold cross-validation,
integer, to specify the number of folds.
An iterable yielding (train, test) splits as arrays of indices.
For integer/None inputs, if the treatment is discrete, StratifiedKFold is used; otherwise, KFold is used (with a random shuffle in either case).
Unless an iterable is used, we call split(concat[W, X], T) to generate the splits. If all W, X are None, then we call split(ones((T.shape[0], 1)), T).
mc_iters (int, optional) – The number of times to rerun the first stage models to reduce the variance of the nuisances.
mc_agg ({‘mean’, ‘median’}, default ‘mean’) – How to aggregate the nuisance value for each sample across the mc_iters monte carlo iterations of cross-fitting.
max_depth (int or None, optional) – The maximum depth of the tree. If None, then nodes are expanded until all leaves are pure or until all leaves contain fewer than min_samples_split samples.
min_samples_split (int, float, default 10) – The minimum number of splitting samples required to split an internal node.
If int, then consider min_samples_split as the minimum number.
If float, then min_samples_split is a fraction and ceil(min_samples_split * n_samples) are the minimum number of samples for each split.
min_samples_leaf (int, float, default 5) – The minimum number of samples required to be at a leaf node. A split point at any depth will only be considered if it leaves at least min_samples_leaf splitting samples in each of the left and right branches. This may have the effect of smoothing the model, especially in regression. After construction, the tree is also pruned so that there are at least min_samples_leaf estimation samples on each leaf.
If int, then consider min_samples_leaf as the minimum number.
If float, then min_samples_leaf is a fraction and ceil(min_samples_leaf * n_samples) are the minimum number of samples for each node.
min_weight_fraction_leaf (float, default 0.) – The minimum weighted fraction of the sum total of weights (of all splitting samples) required to be at a leaf node. Samples have equal weight when sample_weight is not provided. After construction, the tree is pruned so that the fraction of the sum total weight of the estimation samples contained in each leaf node is at least min_weight_fraction_leaf.
max_features (int, float, str, or None, default “auto”) – The number of features to consider when looking for the best split:
If int, then consider max_features features at each split.
If float, then max_features is a fraction and int(max_features * n_features) features are considered at each split.
If “auto”, then max_features=n_features.
If “sqrt”, then max_features=sqrt(n_features).
If “log2”, then max_features=log2(n_features).
If None, then max_features=n_features.
Note: the search for a split does not stop until at least one valid partition of the node samples is found, even if it requires effectively inspecting more than max_features features.
min_impurity_decrease (float, default 0.) – A node will be split if this split induces a decrease of the impurity greater than or equal to this value.
The weighted impurity decrease equation is the following:
N_t / N * (impurity - N_t_R / N_t * right_impurity - N_t_L / N_t * left_impurity)
where N is the total number of split samples, N_t is the number of split samples at the current node, N_t_L is the number of split samples in the left child, and N_t_R is the number of split samples in the right child. N, N_t, N_t_R and N_t_L all refer to the weighted sum, if sample_weight is passed.
min_balancedness_tol (float in [0, .5], default .45) – How imbalanced a split we can tolerate. This enforces that each split leaves at least (.5 - min_balancedness_tol) fraction of samples on each side of the split; or fraction of the total weight of samples, when sample_weight is not None. The default value ensures that at least 5% of the parent node weight falls on each side of the split. Set it to .5 for no balancedness constraint and to 0.0 for perfectly balanced splits. For the formal inference theory to be valid, the minimum fraction (.5 - min_balancedness_tol) has to be a positive constant bounded away from zero.
honest (bool, default True) – Whether to use honest trees, i.e. half of the samples are used for creating the tree structure and the other half for the estimation at the leaves. If False, then all samples are used for both parts.
random_state (int, RandomState instance, or None, default None) – If int, random_state is the seed used by the random number generator; if RandomState instance, random_state is the random number generator; if None, the random number generator is the RandomState instance used by np.random.
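The split-quality formula for min_impurity_decrease and the balancedness constraint for min_balancedness_tol can be transcribed directly into plain Python. This is an illustrative sketch of the documented rules with hypothetical helper names, not EconML's tree-building code; weighted sums can be passed in place of counts when sample_weight is used.

```python
def weighted_impurity_decrease(n, n_t, n_t_l, n_t_r,
                               impurity, left_impurity, right_impurity):
    """N_t / N * (impurity - N_t_R / N_t * right_impurity - N_t_L / N_t * left_impurity)."""
    return n_t / n * (impurity
                      - n_t_r / n_t * right_impurity
                      - n_t_l / n_t * left_impurity)

def split_allowed(decrease, min_impurity_decrease=0.0):
    # A node is split only if the decrease is at least the threshold.
    return decrease >= min_impurity_decrease

def is_balanced(n_t_l, n_t_r, min_balancedness_tol=0.45):
    """Each child must receive at least (.5 - min_balancedness_tol) of the parent's
    samples (or weight): tol=.5 imposes no constraint, tol=0 forces a 50/50 split."""
    n_t = n_t_l + n_t_r
    min_frac = 0.5 - min_balancedness_tol
    return n_t_l >= min_frac * n_t and n_t_r >= min_frac * n_t
```

With the default tolerance of .45, a 5/95 split of a 100-sample node passes the balancedness check while a 4/96 split does not.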
- __init__(*, model_regression='auto', model_propensity='auto', featurizer=None, min_propensity=1e-06, categories='auto', cv=2, mc_iters=None, mc_agg='mean', max_depth=None, min_samples_split=10, min_samples_leaf=5, min_weight_fraction_leaf=0.0, max_features='auto', min_impurity_decrease=0.0, min_balancedness_tol=0.45, honest=True, random_state=None)
Methods
__init__(*[, model_regression, ...])
export_graphviz(*[, out_file, ...]) – Export a graphviz dot file representing the learned tree model.
feature_importances([max_depth, ...]) – Normalized importance of each feature; splits of depth larger than max_depth are not used in this calculation.
fit(Y, T, *[, X, W, sample_weight, groups]) – Estimate a policy model from data.
plot(*[, feature_names, treatment_names, ...]) – Exports policy trees to matplotlib.
policy_feature_names(*[, feature_names]) – Get the output feature names.
policy_treatment_names(*[, treatment_names]) – Get the names of the treatments.
predict(X) – Get recommended treatment for each sample.
predict_proba(X) – Predict the probability of recommending each treatment.
predict_value(X) – Get effect values for each non-baseline treatment and for each sample.
render(out_file, *[, format, view, ...]) – Render the tree to a file.
Attributes
feature_importances_
policy_model_ – The trained final stage policy model.
- export_graphviz(*, out_file=None, feature_names=None, treatment_names=None, max_depth=None, filled=True, leaves_parallel=True, rotate=False, rounded=True, special_characters=False, precision=3)
Export a graphviz dot file representing the learned tree model
- Parameters
out_file (file object or str, optional) – Handle or name of the output file. If None, the result is returned as a string.
feature_names (list of str, optional) – Names of each of the features.
treatment_names (list of str, optional) – Names of each of the treatments, including the baseline treatment.
max_depth (int or None, optional) – The maximum tree depth to plot.
filled (bool, default True) – When set to True, paint nodes to indicate majority class for classification, extremity of values for regression, or purity of node for multi-output.
leaves_parallel (bool, default True) – When set to True, draw all leaf nodes at the bottom of the tree.
rotate (bool, default False) – When set to True, orient tree left to right rather than top-down.
rounded (bool, default True) – When set to True, draw node boxes with rounded corners and use Helvetica fonts instead of Times-Roman.
special_characters (bool, default False) – When set to False, ignore special characters for PostScript compatibility.
precision (int, default 3) – Number of digits of precision for floating point in the values of impurity, threshold and value attributes of each node.
- feature_importances(max_depth=4, depth_decay_exponent=2.0)
- Parameters
max_depth (int, default 4) – Splits of depth larger than max_depth are not used in this calculation.
depth_decay_exponent (float, default 2.0) – The contribution of each split to the total score is re-weighted by 1 / (1 + `depth`)**depth_decay_exponent, i.e. 1 / (1 + `depth`)**2.0 with the default value.
- Returns
feature_importances_ – Normalized total parameter heterogeneity inducing importance of each feature
- Return type
ndarray of shape (n_features,)
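The depth re-weighting can be sketched as below, under two stated assumptions: the root split has depth 0, and each split's unweighted contribution (score) is already known. How EconML computes that per-split heterogeneity score is not described here, so the (feature_index, depth, score) triples are a hypothetical input and the function name is illustrative only.

```python
import numpy as np

def depth_weighted_importances(splits, n_features, max_depth=4, depth_decay_exponent=2.0):
    """splits: iterable of (feature_index, depth, score) triples, one per split node.

    Assumes the root has depth 0 and that 'score' is the split's raw contribution.
    """
    total = np.zeros(n_features)
    for feature, depth, score in splits:
        if depth > max_depth:
            continue  # splits deeper than max_depth are ignored
        total[feature] += score / (1 + depth) ** depth_decay_exponent
    s = total.sum()
    # Normalize so that the importances sum to 1 (when any split contributed).
    return total / s if s > 0 else total
```

A root split on feature 0 with score 1 and a depth-1 split on feature 1 with score 4 end up equally important, because the deeper split is down-weighted by (1 + 1)**2 = 4.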
- fit(Y, T, *, X=None, W=None, sample_weight=None, groups=None)
Estimate a policy model from data.
- Parameters
Y ((n,) vector of length n) – Outcomes for each sample
T ((n,) vector of length n) – Treatments for each sample
X ((n, d_x) matrix, optional) – Features for each sample
W ((n, d_w) matrix, optional) – Controls for each sample
sample_weight ((n,) vector, optional) – Weights for each sample
groups ((n,) vector, optional) – All rows corresponding to the same group will be kept together during splitting. If groups is not None, the cv argument passed to this class’s initializer must support a ‘groups’ argument to its split method.
- Returns
self
- Return type
object instance
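To make the groups guarantee concrete, the hypothetical helper below assigns whole groups to folds so that no group straddles a train/test boundary. It is a sketch of the property, not EconML's splitting logic; in practice, a splitter whose split method accepts groups, such as scikit-learn's GroupKFold, satisfies the requirement on cv stated above.

```python
import numpy as np

def group_kfold_indices(groups, n_splits=2):
    """Yield (train, test) index arrays such that no group is split across folds."""
    groups = np.asarray(groups)
    unique_groups = np.unique(groups)
    # Assign whole groups to folds in round-robin order.
    fold_of_group = {g: i % n_splits for i, g in enumerate(unique_groups)}
    fold = np.array([fold_of_group[g] for g in groups])
    for k in range(n_splits):
        test = np.flatnonzero(fold == k)
        train = np.flatnonzero(fold != k)
        yield train, test
```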
- plot(*, feature_names=None, treatment_names=None, ax=None, title=None, max_depth=None, filled=True, rounded=True, precision=3, fontsize=None)
Exports policy trees to matplotlib
- Parameters
ax (matplotlib.axes.Axes, optional) – The axes on which to plot.
title (str, optional) – A title for the final figure to be printed at the top of the page.
feature_names (list of str, optional) – Names of each of the features.
treatment_names (list of str, optional) – Names of each of the treatments, including the baseline/control.
max_depth (int or None, optional) – The maximum tree depth to plot.
filled (bool, default True) – When set to True, paint nodes to indicate majority class for classification, extremity of values for regression, or purity of node for multi-output.
rounded (bool, default True) – When set to True, draw node boxes with rounded corners and use Helvetica fonts instead of Times-Roman.
precision (int, default 3) – Number of digits of precision for floating point in the values of impurity, threshold and value attributes of each node.
fontsize (int, optional) – Font size for text.
- policy_feature_names(*, feature_names=None)
Get the output feature names.
- Parameters
feature_names (list of str of length X.shape[1] or None) – The names of the input features. If None and X is a dataframe, it defaults to the column names from the dataframe.
- Returns
out_feature_names – The names of the output features on which the policy model is fitted.
- Return type
list of str or None
- policy_treatment_names(*, treatment_names=None)
Get the names of the treatments.
- Parameters
treatment_names (list of str of length n_categories) – The names of the treatments (including the baseline). If None, then values are auto-generated based on input metadata.
- Returns
out_treatment_names – The names of the treatments including the baseline/control treatment.
- Return type
list of str
- predict(X)
Get recommended treatment for each sample.
- Parameters
X (array_like of shape (n_samples, n_features)) – The input samples.
- Returns
treatment – The index of the recommended treatment in the same order as in categories, or in lexicographic order if categories=’auto’. 0 corresponds to the baseline/control treatment. For ensemble policy models, recommended treatments are aggregated from each model in the ensemble and the treatment that receives the most votes is returned. Use predict_proba to get the fraction of models in the ensemble that recommend each treatment for each sample.
- Return type
array_like of shape (n_samples,)
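Because predict returns indices rather than labels, mapping them back takes one indexing step. The sketch below assumes categories='auto' and uses made-up treatment labels and a made-up prediction array in place of a fitted model.

```python
import numpy as np

# Hypothetical raw treatment labels observed during fit.
T = np.array(["drug_b", "control", "drug_a", "control"])

# With categories='auto', the categories are the unique values in sorted
# (lexicographic) order; position 0 is treated as the baseline/control.
categories = np.unique(T)

# Hypothetical output of predict(X) for three samples: indices into categories.
recommended = np.array([0, 2, 1])
recommended_labels = categories[recommended]
```

Here "control" happens to sort first and so serves as the baseline; passing an explicit list as categories controls which treatment is index 0.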
- predict_proba(X)
Predict the probability of recommending each treatment
- Parameters
X (array_like of shape (n_samples, n_features)) – The input samples.
- Returns
treatment_proba – The probability of each treatment recommendation
- Return type
array_like of shape (n_samples, n_treatments)
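The ensemble aggregation described under predict (majority vote, with predict_proba reporting the fraction of models recommending each treatment) can be sketched in NumPy as follows. The helper name is hypothetical, and the tie-breaking rule (lowest index wins, via argmax) is an assumption rather than documented behavior.

```python
import numpy as np

def aggregate_votes(member_predictions, n_treatments):
    """member_predictions: (n_models, n_samples) array of recommended treatment indices."""
    votes = np.asarray(member_predictions)
    n_models, n_samples = votes.shape
    proba = np.zeros((n_samples, n_treatments))
    for preds in votes:
        # Each model casts one vote per sample for its recommended treatment.
        proba[np.arange(n_samples), preds] += 1.0
    proba /= n_models  # fraction of models recommending each treatment
    # Majority vote; argmax resolves ties toward the lower treatment index.
    return proba, proba.argmax(axis=1)
```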
- predict_value(X)
Get effect values for each non-baseline treatment and for each sample.
- Parameters
X (array_like of shape (n_samples, n_features)) – The input samples.
- Returns
values – The predicted average value for each sample and for each non-baseline treatment, as compared to the baseline treatment value and based on the feature neighborhoods defined by the trees.
- Return type
array_like of shape (n_samples, n_treatments - 1)
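Since predict_value reports effects relative to the baseline, a natural decision rule, and the one the objective V(π) rewards within a leaf, is to pick the arm with the largest value while treating the baseline as 0. The sketch below assumes that rule with a hypothetical helper name; it is not taken from EconML's source.

```python
import numpy as np

def recommend_from_values(values):
    """values: (n_samples, n_treatments - 1) effects relative to the baseline."""
    values = np.asarray(values, dtype=float)
    # The baseline has value 0 relative to itself; prepend it as column 0.
    full = np.hstack([np.zeros((values.shape[0], 1)), values])
    # Index 0 (baseline) wins whenever no treatment has a positive effect.
    return full.argmax(axis=1)
```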
- render(out_file, *, format='pdf', view=True, feature_names=None, treatment_names=None, max_depth=None, filled=True, leaves_parallel=True, rotate=False, rounded=True, special_characters=False, precision=3)
Render the tree to a file
- Parameters
out_file – The file name to save to.
format (str, default ‘pdf’) – The file format to render to; must be supported by graphviz.
view (bool, default True) – Whether to open the rendered result with the default application.
feature_names (list of str, optional) – Names of each of the features.
treatment_names (list of str, optional) – Names of each of the treatments, including the baseline/control.
max_depth (int or None, optional) – The maximum tree depth to plot.
filled (bool, default True) – When set to True, paint nodes to indicate majority class for classification, extremity of values for regression, or purity of node for multi-output.
leaves_parallel (bool, default True) – When set to True, draw all leaf nodes at the bottom of the tree.
rotate (bool, default False) – When set to True, orient tree left to right rather than top-down.
rounded (bool, default True) – When set to True, draw node boxes with rounded corners and use Helvetica fonts instead of Times-Roman.
special_characters (bool, default False) – When set to False, ignore special characters for PostScript compatibility.
precision (int, default 3) – Number of digits of precision for floating point in the values of impurity, threshold and value attributes of each node.
- property policy_model_
The trained final stage policy model