econml.cate_interpreter.SingleTreeCateInterpreter

class econml.cate_interpreter.SingleTreeCateInterpreter(*, include_model_uncertainty=False, uncertainty_level=0.1, uncertainty_only_on_leaves=True, splitter='best', max_depth=None, min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_features=None, random_state=None, max_leaf_nodes=None, min_impurity_decrease=0.0)

Bases: econml.cate_interpreter._interpreters._SingleTreeInterpreter

An interpreter for the effect estimated by a CATE estimator.

Parameters
  • include_model_uncertainty (bool, optional, default False) – Whether to include confidence interval information when building a simplified model of the CATE model. If True, the CATE estimator must support the const_marginal_ate_inference method.

  • uncertainty_level (float, optional, default 0.1) – The uncertainty level for the confidence intervals that are constructed and used when building the simplified model. If set to alpha, a multitask decision tree is built such that all samples in a leaf have similar target predictions and also similar confidence intervals at level alpha.

  • uncertainty_only_on_leaves (bool, optional, default True) – Whether uncertainty information should be displayed only on leaf nodes. If False, interpretation can be slightly slower, especially for CATE models with a computationally expensive inference method.

  • splitter (string, optional, default “best”) – The strategy used to choose the split at each node. Supported strategies are “best” to choose the best split and “random” to choose the best random split.

  • max_depth (int or None, optional, default None) – The maximum depth of the tree. If None, nodes are expanded until all leaves are pure or until all leaves contain fewer than min_samples_split samples.

  • min_samples_split (int or float, optional, default 2) – The minimum number of samples required to split an internal node:

    • If int, then consider min_samples_split as the minimum number.

    • If float, then min_samples_split is a fraction and ceil(min_samples_split * n_samples) is the minimum number of samples for each split.

  • min_samples_leaf (int or float, optional, default 1) – The minimum number of samples required to be at a leaf node. A split point at any depth is only considered if it leaves at least min_samples_leaf training samples in each of the left and right branches. This may have the effect of smoothing the model, especially in regression.

    • If int, then consider min_samples_leaf as the minimum number.

    • If float, then min_samples_leaf is a fraction and ceil(min_samples_leaf * n_samples) is the minimum number of samples for each node.

  • min_weight_fraction_leaf (float, optional, default 0.) – The minimum weighted fraction of the sum total of weights (of all the input samples) required to be at a leaf node. Samples have equal weight when sample_weight is not provided.

  • max_features (int, float, {“auto”, “sqrt”, “log2”} or None, optional, default None) – The number of features to consider when looking for the best split:

    • If int, then consider max_features features at each split.

    • If float, then max_features is a fraction and int(max_features * n_features) features are considered at each split.

    • If “auto”, then max_features=n_features.

    • If “sqrt”, then max_features=sqrt(n_features).

    • If “log2”, then max_features=log2(n_features).

    • If None, then max_features=n_features.

    Note: the search for a split does not stop until at least one valid partition of the node samples is found, even if this requires effectively inspecting more than max_features features.

  • random_state (int, RandomState instance or None, optional, default None) – If int, random_state is the seed used by the random number generator; if a RandomState instance, random_state is the random number generator; if None, the random number generator is the RandomState instance used by np.random.

  • max_leaf_nodes (int or None, optional, default None) – Grow a tree with max_leaf_nodes in best-first fashion. Best nodes are defined by their relative reduction in impurity. If None, the number of leaf nodes is unlimited.

  • min_impurity_decrease (float, optional, default 0.) – A node will be split if this split induces a decrease of the impurity greater than or equal to this value.

    The weighted impurity decrease equation is the following:

    N_t / N * (impurity - N_t_R / N_t * right_impurity
                        - N_t_L / N_t * left_impurity)
    

    where N is the total number of samples, N_t is the number of samples at the current node, N_t_L is the number of samples in the left child, and N_t_R is the number of samples in the right child. N, N_t, N_t_R and N_t_L all refer to the weighted sum, if sample_weight is passed.
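The counting rules in the parameter list above can be made concrete with a small pure-Python sketch. It follows the formulas stated for min_samples_split, min_samples_leaf, max_features and min_impurity_decrease; the helper names (resolve_min_samples, resolve_max_features, weighted_impurity_decrease) are hypothetical and not part of econml, and truncating the sqrt/log2 results to integers is an assumption of this sketch (the underlying scikit-learn tree may apply additional lower bounds).

```python
import math


def resolve_min_samples(value, n_samples):
    """Hypothetical helper for min_samples_split / min_samples_leaf:
    an int is used as-is, a float is a fraction of n_samples rounded up."""
    if isinstance(value, float):
        return math.ceil(value * n_samples)
    return value


def resolve_max_features(value, n_features):
    """Hypothetical helper: number of features inspected per split.
    Truncating sqrt/log2 to an int is an assumption of this sketch."""
    if value is None or value == "auto":
        return n_features
    if value == "sqrt":
        return int(math.sqrt(n_features))
    if value == "log2":
        return int(math.log2(n_features))
    if isinstance(value, float):
        return int(value * n_features)
    return value  # plain int


def weighted_impurity_decrease(n, n_t, n_t_l, n_t_r,
                               impurity, left_impurity, right_impurity):
    """N_t / N * (impurity - N_t_R / N_t * right_impurity
                            - N_t_L / N_t * left_impurity)"""
    return n_t / n * (impurity
                      - n_t_r / n_t * right_impurity
                      - n_t_l / n_t * left_impurity)


print(resolve_min_samples(0.01, 1050))   # ceil(10.5) -> 11
print(resolve_max_features("sqrt", 10))  # int(3.16...) -> 3

decrease = weighted_impurity_decrease(n=100, n_t=40, n_t_l=10, n_t_r=30,
                                      impurity=1.0, left_impurity=0.5,
                                      right_impurity=0.2)
# The split is made only if decrease >= min_impurity_decrease.
print(decrease >= 0.1)  # decrease is about 0.29, so True
```

Note how a fractional min_samples_leaf is rounded up, so 1% of 1050 samples already requires 11 samples per leaf.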

__init__(*, include_model_uncertainty=False, uncertainty_level=0.1, uncertainty_only_on_leaves=True, splitter='best', max_depth=None, min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_features=None, random_state=None, max_leaf_nodes=None, min_impurity_decrease=0.0)

Initialize self. See help(type(self)) for accurate signature.

Methods

  • __init__(*[, include_model_uncertainty, …]) – Initialize self.

  • export_graphviz([out_file, feature_names, …]) – Export a Graphviz DOT file representing the learned tree model.

  • interpret(cate_estimator, X) – Interpret the heterogeneity of a CATE estimator when applied to a set of features.

  • plot([ax, title, feature_names, …]) – Export the learned tree to matplotlib.

  • render(out_file[, format, view, …]) – Render the tree to a file.

Attributes

node_dict_

tree_model_

export_graphviz(out_file=None, feature_names=None, treatment_names=None, max_depth=None, filled=True, leaves_parallel=True, rotate=False, rounded=True, special_characters=False, precision=3)

Export a Graphviz DOT file representing the learned tree model.

Parameters
  • out_file (file object or string, optional, default None) – Handle or name of the output file. If None, the result is returned as a string.

  • feature_names (list of strings, optional, default None) – Names of each of the features.

  • treatment_names (list of strings, optional, default None) – Names of each of the treatments.

  • max_depth (int or None, optional, default None) – The maximum tree depth to plot.

  • filled (bool, optional, default True) – When set to True, paint nodes to indicate majority class for classification, extremity of values for regression, or purity of node for multi-output.

  • leaves_parallel (bool, optional, default True) – When set to True, draw all leaf nodes at the bottom of the tree.

  • rotate (bool, optional, default False) – When set to True, orient tree left to right rather than top-down.

  • rounded (bool, optional, default True) – When set to True, draw node boxes with rounded corners and use Helvetica fonts instead of Times-Roman.

  • special_characters (bool, optional, default False) – When set to False, ignore special characters for PostScript compatibility.

  • precision (int, optional, default 3) – Number of digits of precision for floating point in the values of impurity, threshold and value attributes of each node.

interpret(cate_estimator, X)

Interpret the heterogeneity of a CATE estimator when applied to a set of features.

Parameters
  • cate_estimator (LinearCateEstimator) – The fitted estimator to interpret.

  • X (array-like) – The features against which to interpret the estimator; must be compatible shape-wise with the features used to fit the estimator.

Returns

self

Return type

object instance

plot(ax=None, title=None, feature_names=None, treatment_names=None, max_depth=None, filled=True, rounded=True, precision=3, fontsize=None)

Export the learned tree to matplotlib.

Parameters
  • ax (matplotlib.axes.Axes, optional, default None) – The axes on which to plot

  • title (string, optional, default None) – A title for the final figure to be printed at the top of the page.

  • feature_names (list of strings, optional, default None) – Names of each of the features.

  • treatment_names (list of strings, optional, default None) – Names of each of the treatments.

  • max_depth (int or None, optional, default None) – The maximum tree depth to plot.

  • filled (bool, optional, default True) – When set to True, paint nodes to indicate majority class for classification, extremity of values for regression, or purity of node for multi-output.

  • rounded (bool, optional, default True) – When set to True, draw node boxes with rounded corners and use Helvetica fonts instead of Times-Roman.

  • precision (int, optional, default 3) – Number of digits of precision for floating point in the values of impurity, threshold and value attributes of each node.

  • fontsize (int, optional, default None) – Font size for text.

render(out_file, format='pdf', view=True, feature_names=None, treatment_names=None, max_depth=None, filled=True, leaves_parallel=True, rotate=False, rounded=True, special_characters=False, precision=3)

Render the tree to a file.

Parameters
  • out_file (string) – Name of the file to save to.

  • format (string, optional, default ‘pdf’) – The file format to render to; must be supported by graphviz

  • view (bool, optional, default True) – Whether to open the rendered result with the default application.

  • feature_names (list of strings, optional, default None) – Names of each of the features.

  • treatment_names (list of strings, optional, default None) – Names of each of the treatments.

  • max_depth (int or None, optional, default None) – The maximum tree depth to plot.

  • filled (bool, optional, default True) – When set to True, paint nodes to indicate majority class for classification, extremity of values for regression, or purity of node for multi-output.

  • leaves_parallel (bool, optional, default True) – When set to True, draw all leaf nodes at the bottom of the tree.

  • rotate (bool, optional, default False) – When set to True, orient tree left to right rather than top-down.

  • rounded (bool, optional, default True) – When set to True, draw node boxes with rounded corners and use Helvetica fonts instead of Times-Roman.

  • special_characters (bool, optional, default False) – When set to False, ignore special characters for PostScript compatibility.

  • precision (int, optional, default 3) – Number of digits of precision for floating point in the values of impurity, threshold and value attributes of each node.