Source code for econml.score.ensemble_cate

# Copyright (c) PyWhy contributors. All rights reserved.
# Licensed under the MIT License.

import numpy as np
from sklearn.utils.validation import check_array
from .._cate_estimator import BaseCateEstimator, LinearCateEstimator

[docs]class EnsembleCateEstimator: """ A CATE estimator that represents a weighted ensemble of many CATE estimators. Returns their weighted effect prediction. Parameters ---------- cate_models : list of BaseCateEstimator objects A list of fitted cate estimator objects that will be used in the ensemble. The models are passed by reference, and not copied internally, because we need the fitted objects, so any change to the passed models will affect the internal predictions (e.g. if the input models are refitted). weights : np.ndarray of shape (len(cate_models),) The weight placed on each model. Weights must be non-positive. The ensemble will predict effects based on the weighted average predictions of the cate_models estiamtors, weighted by the corresponding weight in `weights`. """
[docs] def __init__(self, *, cate_models, weights): self.cate_models = cate_models self.weights = weights
[docs] def effect(self, X=None, *, T0=0, T1=1): return np.average([mdl.effect(X=X, T0=T0, T1=T1) for mdl in self.cate_models], weights=self.weights, axis=0)
effect.__doc__ = BaseCateEstimator.effect.__doc__
[docs] def marginal_effect(self, T, X=None): return np.average([mdl.marginal_effect(T, X=X) for mdl in self.cate_models], weights=self.weights, axis=0)
marginal_effect.__doc__ = BaseCateEstimator.marginal_effect.__doc__
[docs] def const_marginal_effect(self, X=None): if np.any([not hasattr(mdl, 'const_marginal_effect') for mdl in self.cate_models]): raise ValueError("One of the base CATE models in parameter `cate_models` does not support " "the `const_marginal_effect` method.") return np.average([mdl.const_marginal_effect(X=X) for mdl in self.cate_models], weights=self.weights, axis=0)
const_marginal_effect.__doc__ = LinearCateEstimator.const_marginal_effect.__doc__ @property def cate_models(self): return self._cate_models @cate_models.setter def cate_models(self, value): if (not isinstance(value, list)) or (not np.all([isinstance(model, BaseCateEstimator) for model in value])): raise ValueError('Parameter `cate_models` should be a list of `BaseCateEstimator` objects.') self._cate_models = value @property def weights(self): return self._weights @weights.setter def weights(self, value): weights = check_array(value, accept_sparse=False, ensure_2d=False, allow_nd=False, dtype='numeric', force_all_finite=True) if np.any(weights < 0): raise ValueError("All weights in parameter `weights` must be non-negative.") self._weights = weights