econml.cate_interpreter.SingleTreeCateInterpreter
- class econml.cate_interpreter.SingleTreeCateInterpreter(*, include_model_uncertainty=False, uncertainty_level=0.05, uncertainty_only_on_leaves=True, splitter='best', max_depth=None, min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_features=None, random_state=None, max_leaf_nodes=None, min_impurity_decrease=0.0)[source]
Bases:
econml.cate_interpreter._interpreters._SingleTreeInterpreter
An interpreter for the effect estimated by a CATE estimator
- Parameters
include_model_uncertainty (bool, default False) – Whether to include confidence interval information when building a simplified model of the CATE model. If set to True, the CATE estimator must support the const_marginal_ate_inference method.
uncertainty_level (double, default 0.05) – The uncertainty level for the confidence intervals to be constructed and used in the simplified model creation. If value=alpha then a multitask decision tree will be built such that all samples in a leaf have similar target prediction but also similar alpha confidence intervals.
uncertainty_only_on_leaves (bool, default True) – Whether uncertainty information should be displayed only on leaf nodes. If False, then interpretation can be slightly slower, especially for cate models that have a computationally expensive inference method.
splitter (str, default “best”) – The strategy used to choose the split at each node. Supported strategies are “best” to choose the best split and “random” to choose the best random split.
max_depth (int, optional) – The maximum depth of the tree. If None, then nodes are expanded until all leaves are pure or until all leaves contain fewer than min_samples_split samples.
min_samples_split (int, float, default 2) – The minimum number of samples required to split an internal node:
If int, then consider min_samples_split as the minimum number.
If float, then min_samples_split is a fraction and ceil(min_samples_split * n_samples) is the minimum number of samples for each split.
min_samples_leaf (int, float, default 1) – The minimum number of samples required to be at a leaf node. A split point at any depth will only be considered if it leaves at least min_samples_leaf training samples in each of the left and right branches. This may have the effect of smoothing the model, especially in regression.
If int, then consider min_samples_leaf as the minimum number.
If float, then min_samples_leaf is a fraction and ceil(min_samples_leaf * n_samples) is the minimum number of samples for each node.
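The int-or-float rule shared by min_samples_split and min_samples_leaf can be sketched as below. This is only an illustration of the rule as worded above; resolve_min_samples is a hypothetical helper, not part of econml, and the underlying tree implementation may apply additional bounds that are not shown here.

```python
from math import ceil
from numbers import Integral


def resolve_min_samples(value, n_samples):
    """Turn an int-or-float threshold into an absolute sample count.

    Hypothetical helper: an int is used as-is, a float is treated as a
    fraction of n_samples and rounded up, as the parameter text describes.
    """
    if isinstance(value, Integral):
        return int(value)
    return ceil(value * n_samples)


# An absolute count is returned unchanged.
leaf_threshold = resolve_min_samples(5, 250)
# A fraction is scaled by the number of samples and rounded up.
split_threshold = resolve_min_samples(0.25, 200)
print(leaf_threshold, split_threshold)
```

Passing a fraction keeps the threshold proportional to the data, which can be convenient when the same interpreter settings are reused on feature sets of different sizes.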
min_weight_fraction_leaf (float, default 0.) – The minimum weighted fraction of the sum total of weights (of all the input samples) required to be at a leaf node. Samples have equal weight when sample_weight is not provided.
max_features (int, float, {“auto”, “sqrt”, “log2”}, or None, default None) – The number of features to consider when looking for the best split:
If int, then consider max_features features at each split.
If float, then max_features is a fraction and int(max_features * n_features) features are considered at each split.
If “auto”, then max_features=n_features.
If “sqrt”, then max_features=sqrt(n_features).
If “log2”, then max_features=log2(n_features).
If None, then max_features=n_features.
Note: the search for a split does not stop until at least one valid partition of the node samples is found, even if it requires effectively inspecting more than max_features features.
random_state (int, RandomState instance, or None, default None) – If int, random_state is the seed used by the random number generator; if RandomState instance, random_state is the random number generator; if None, the random number generator is the RandomState instance used by np.random.
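The max_features options listed above can be read as a mapping from the argument to a feature count. The sketch below follows that list; resolve_max_features is a hypothetical helper, and truncating the "sqrt" and "log2" results to an integer of at least 1 is an assumption made here so the result is a usable count.

```python
from math import log2, sqrt


def resolve_max_features(max_features, n_features):
    """Map a max_features argument to a number of features (hypothetical helper)."""
    if max_features is None or max_features == "auto":
        return n_features
    if max_features == "sqrt":
        # Assumption: truncate to an integer and keep at least one feature.
        return max(1, int(sqrt(n_features)))
    if max_features == "log2":
        # Assumption: truncate to an integer and keep at least one feature.
        return max(1, int(log2(n_features)))
    if isinstance(max_features, int):
        return max_features
    if isinstance(max_features, float):
        # A float is a fraction of the available features.
        return int(max_features * n_features)
    raise ValueError(f"Unsupported max_features value: {max_features!r}")


for option in (None, "auto", "sqrt", "log2", 3, 0.5):
    print(option, resolve_max_features(option, 16))
```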
max_leaf_nodes (int, optional) – Grow a tree with max_leaf_nodes in best-first fashion. Best nodes are defined as relative reduction in impurity. If None, then the number of leaf nodes is unlimited.
min_impurity_decrease (float, default 0.) – A node will be split if this split induces a decrease of the impurity greater than or equal to this value.
The weighted impurity decrease equation is the following:
N_t / N * (impurity - N_t_R / N_t * right_impurity - N_t_L / N_t * left_impurity)
where N is the total number of samples, N_t is the number of samples at the current node, N_t_L is the number of samples in the left child, and N_t_R is the number of samples in the right child. N, N_t, N_t_R and N_t_L all refer to the weighted sum, if sample_weight is passed.
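As a worked instance of the weighted impurity decrease equation, the sketch below evaluates it for one candidate split and compares the result with a min_impurity_decrease threshold. The function name and all numbers are made up for illustration.

```python
def weighted_impurity_decrease(n, n_t, n_t_l, n_t_r,
                               impurity, left_impurity, right_impurity):
    """Evaluate the weighted impurity decrease for one candidate split.

    n     : total number of samples
    n_t   : samples at the current node
    n_t_l : samples in the left child
    n_t_r : samples in the right child
    """
    return n_t / n * (impurity
                      - n_t_r / n_t * right_impurity
                      - n_t_l / n_t * left_impurity)


# Illustrative values: a node holding 40 of 100 samples, split 10 / 30.
decrease = weighted_impurity_decrease(
    n=100, n_t=40, n_t_l=10, n_t_r=30,
    impurity=0.5, left_impurity=0.2, right_impurity=0.4,
)

# The node is split only if the decrease reaches the threshold.
min_impurity_decrease = 0.05
should_split = decrease >= min_impurity_decrease
print(decrease, should_split)
```

The leading N_t / N factor scales the decrease by the share of samples that reach the node, so the same impurity drop counts for less deep in the tree than near the root.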
- __init__(*, include_model_uncertainty=False, uncertainty_level=0.05, uncertainty_only_on_leaves=True, splitter='best', max_depth=None, min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_features=None, random_state=None, max_leaf_nodes=None, min_impurity_decrease=0.0)[source]
Methods
__init__(*[, include_model_uncertainty, ...])
export_graphviz([out_file, feature_names, ...]) – Export a graphviz dot file representing the learned tree model
interpret(cate_estimator, X) – Interpret the heterogeneity of a CATE estimator when applied to a set of features
plot([ax, title, feature_names, ...]) – Exports policy trees to matplotlib
render(out_file[, format, view, ...]) – Render the tree to a file
Attributes
node_dict_
tree_model_
- export_graphviz(out_file=None, feature_names=None, treatment_names=None, max_depth=None, filled=True, leaves_parallel=True, rotate=False, rounded=True, special_characters=False, precision=3)
Export a graphviz dot file representing the learned tree model
- Parameters
out_file (file object or str, optional) – Handle or name of the output file. If None, the result is returned as a string.
feature_names (list of str, optional) – Names of each of the features.
treatment_names (list of str, optional) – Names of each of the treatments.
max_depth (int, optional) – The maximum tree depth to plot.
filled (bool, default True) – When set to True, paint nodes to indicate majority class for classification, extremity of values for regression, or purity of node for multi-output.
leaves_parallel (bool, default True) – When set to True, draw all leaf nodes at the bottom of the tree.
rotate (bool, default False) – When set to True, orient tree left to right rather than top-down.
rounded (bool, default True) – When set to True, draw node boxes with rounded corners and use Helvetica fonts instead of Times-Roman.
special_characters (bool, default False) – When set to False, ignore special characters for PostScript compatibility.
precision (int, default 3) – Number of digits of precision for floating point in the values of impurity, threshold and value attributes of each node.
- interpret(cate_estimator, X)[source]
Interpret the heterogeneity of a CATE estimator when applied to a set of features
- Parameters
cate_estimator (LinearCateEstimator) – The fitted estimator to interpret.
X (array_like) – The features against which to interpret the estimator; must be compatible shape-wise with the features used to fit the estimator.
- Returns
self
- Return type
object instance
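A minimal end-to-end sketch of interpret follows, assuming econml, numpy and scikit-learn are installed. The synthetic data, the choice of LinearDML as the CATE estimator, the feature names "x0" and "x1", and the helper name fit_and_interpret are all assumptions made for illustration; any fitted estimator of the documented type should work in its place.

```python
def fit_and_interpret(n_samples=1000, seed=0):
    """Fit a CATE estimator on synthetic data and summarise it with one tree."""
    # Imports are kept local so the function can be defined without econml
    # installed; econml, numpy and scikit-learn are needed to call it.
    import numpy as np
    from econml.cate_interpreter import SingleTreeCateInterpreter
    from econml.dml import LinearDML

    rng = np.random.default_rng(seed)
    X = rng.normal(size=(n_samples, 2))
    T = rng.normal(size=n_samples)
    # Synthetic outcome: the treatment effect is 1 + 2 * X[:, 0],
    # so it varies with the first feature only.
    Y = (1.0 + 2.0 * X[:, 0]) * T + X[:, 1] + rng.normal(scale=0.1, size=n_samples)

    # Assumption: LinearDML is used here only as an example CATE estimator.
    est = LinearDML(random_state=seed)
    est.fit(Y, T, X=X)

    # Build a shallow tree that approximates the estimated effect over X.
    intrp = SingleTreeCateInterpreter(max_depth=2, min_samples_leaf=10,
                                      random_state=seed)
    intrp.interpret(est, X)

    # With out_file=None the dot source is returned as a string.
    dot_source = intrp.export_graphviz(out_file=None, feature_names=["x0", "x1"])
    return intrp, dot_source


if __name__ == "__main__":
    interpreter, dot = fit_and_interpret()
    print(dot[:200])
```

Once interpret has run, the same fitted interpreter can be passed to plot or render with the parameters documented for those methods, and the fitted tree is available as tree_model_.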
- plot(ax=None, title=None, feature_names=None, treatment_names=None, max_depth=None, filled=True, rounded=True, precision=3, fontsize=None)
Exports policy trees to matplotlib
- Parameters
ax (matplotlib.axes.Axes, optional) – The axes on which to plot.
title (str, optional) – A title for the final figure to be printed at the top of the page.
feature_names (list of str, optional) – Names of each of the features.
treatment_names (list of str, optional) – Names of each of the treatments.
max_depth (int, optional) – The maximum tree depth to plot.
filled (bool, default True) – When set to True, paint nodes to indicate majority class for classification, extremity of values for regression, or purity of node for multi-output.
rounded (bool, default True) – When set to True, draw node boxes with rounded corners and use Helvetica fonts instead of Times-Roman.
precision (int, default 3) – Number of digits of precision for floating point in the values of impurity, threshold and value attributes of each node.
fontsize (int, optional) – Font size for text
- render(out_file, format='pdf', view=True, feature_names=None, treatment_names=None, max_depth=None, filled=True, leaves_parallel=True, rotate=False, rounded=True, special_characters=False, precision=3)
Render the tree to a file
- Parameters
out_file (str) – The file name to save to.
format (str, default ‘pdf’) – The file format to render to; must be supported by graphviz.
view (bool, default True) – Whether to open the rendered result with the default application.
feature_names (list of str, optional) – Names of each of the features.
treatment_names (list of str, optional) – Names of each of the treatments.
max_depth (int, optional) – The maximum tree depth to plot.
filled (bool, default True) – When set to True, paint nodes to indicate majority class for classification, extremity of values for regression, or purity of node for multi-output.
leaves_parallel (bool, default True) – When set to True, draw all leaf nodes at the bottom of the tree.
rotate (bool, default False) – When set to True, orient tree left to right rather than top-down.
rounded (bool, default True) – When set to True, draw node boxes with rounded corners and use Helvetica fonts instead of Times-Roman.
special_characters (bool, default False) – When set to False, ignore special characters for PostScript compatibility.
precision (int, default 3) – Number of digits of precision for floating point in the values of impurity, threshold and value attributes of each node.