# Source code for econml.federated_learning

# Copyright (c) PyWhy contributors. All rights reserved.
# Licensed under the MIT License.

import numpy as np
from sklearn import clone

from econml.utilities import check_input_arrays
from ._cate_estimator import (LinearCateEstimator, TreatmentExpansionMixin,
                              StatsModelsCateEstimatorMixin, StatsModelsCateEstimatorDiscreteMixin)
from .dml import LinearDML
from .inference import StatsModelsInference, StatsModelsInferenceDiscrete
from .sklearn_extensions.linear_model import StatsModelsLinearRegression
from typing import List

# TODO: This could be extended to also work with our sparse and 2SLS estimators,
#       if we add an aggregate method to them
#       Remember to update the docs if this changes

class FederatedEstimator(TreatmentExpansionMixin, LinearCateEstimator):
    """
    A class for federated learning using LinearDML, LinearDRIV, and LinearDRLearner estimators.

    The final-stage models of the supplied estimators are combined with
    ``StatsModelsLinearRegression.aggregate``, and an inference object of the same
    kind as the inputs' is built around the aggregated model(s).

    Parameters
    ----------
    estimators : list of LinearDML, LinearDRIV, or LinearDRLearner
        List of estimators to aggregate (all of the same type), which must already have been fit.

    Raises
    ------
    ValueError
        If `estimators` is empty.
    AssertionError
        If the estimators do not all use ``StatsModelsInference`` or all use
        ``StatsModelsInferenceDiscrete``, or if they do not share a single covariance type.
    """

    def __init__(self, estimators: List[LinearDML]):
        """
        Aggregate the already-fitted `estimators` into a single estimator.

        See the class docstring for the parameter description and the exceptions raised.
        """
        # Fail early with a clear message rather than an IndexError from `estimators[0]` below
        if len(estimators) == 0:
            raise ValueError("FederatedEstimator requires at least one fitted estimator")

        self.estimators = estimators
        dummy_est = clone(self.estimators[0], safe=False)  # used to extract various attributes later
        infs = [est._inference for est in self.estimators]

        # These checks are raised explicitly (instead of using `assert` statements) so that they
        # are still enforced when Python runs with -O; AssertionError is kept as the exception
        # type so that existing callers observe the same failure as before.
        same_inference_type = (
            all(isinstance(inf, StatsModelsInference) for inf in infs) or
            all(isinstance(inf, StatsModelsInferenceDiscrete) for inf in infs)
        )
        if not same_inference_type:
            raise AssertionError(
                "All estimators must use either StatsModelsInference or StatsModelsInferenceDiscrete")

        cov_types = {inf.cov_type for inf in infs}
        if len(cov_types) != 1:
            raise AssertionError(f"All estimators must use the same covariance type, got {cov_types}")

        if isinstance(infs[0], StatsModelsInference):
            # A single final model per estimator: aggregate them into one
            inf = StatsModelsInference(cov_type=cov_types.pop())
            cate_est_type = StatsModelsCateEstimatorMixin
            self.model_final_ = StatsModelsLinearRegression.aggregate([est.model_final_ for est in self.estimators])
            inf.model_final = self.model_final_
            inf.bias_part_of_coef = dummy_est.bias_part_of_coef
        else:
            # A list of final models per estimator: aggregate position by position across estimators
            inf = StatsModelsInferenceDiscrete(cov_type=cov_types.pop())
            cate_est_type = StatsModelsCateEstimatorDiscreteMixin
            self.fitted_models_final = [
                StatsModelsLinearRegression.aggregate(models)
                for models in zip(*[est.fitted_models_final for est in self.estimators])]
            inf.fitted_models_final = self.fitted_models_final

        # mix in the appropriate inference class
        self.__class__ = type("FederatedEstimator", (FederatedEstimator, cate_est_type), {})

        # assign all of the attributes from the dummy estimator that would normally be assigned during fitting
        # TODO: This seems hacky; is there a better abstraction to maintain these?
        #       This should also include bias_part_of_coef, model_final_, and fitted_models_final above
        inf.featurizer = dummy_est.featurizer_ if hasattr(dummy_est, 'featurizer_') else None
        inf._est = self
        self._d_t = inf._d_t = dummy_est._d_t
        self._d_y = inf._d_y = dummy_est._d_y
        # Fall back to 1 when the stored treatment/outcome shape is empty or None
        self.d_t = inf.d_t = inf._d_t[0] if inf._d_t else 1
        self.d_y = inf.d_y = inf._d_y[0] if inf._d_y else 1
        self._d_t_in = inf._d_t_in = dummy_est._d_t_in
        self.fit_cate_intercept_ = inf.fit_cate_intercept = dummy_est.fit_cate_intercept
        self._inference = inf

        # Assign treatment expansion attributes
        self.transformer = dummy_est.transformer

    # Methods needed to implement the LinearCateEstimator interface

    def const_marginal_effect(self, X=None):
        """
        Return the point estimate of the constant marginal effect for the features `X`.

        `X` is first passed through ``check_input_arrays``; the result is the
        ``point_estimate`` of ``const_marginal_effect_inference(X)`` computed by the
        aggregated inference object.
        """
        X, = check_input_arrays(X)
        return self._inference.const_marginal_effect_inference(X).point_estimate

    def fit(self, *args, **kwargs):
        """
        This method should not be called; it is included only for compatibility with the CATE estimation APIs

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError("FederatedEstimator does not support fit")

    # Methods needed to implement the LinearFinalModelCateEstimatorMixin
    # NOTE(review): this is a plain method here, whereas __init__ reads
    # `dummy_est.bias_part_of_coef` as an attribute - confirm whether this was meant to be a property.
    def bias_part_of_coef(self):
        """Return the ``bias_part_of_coef`` value stored on the aggregated inference object."""
        return self._inference.bias_part_of_coef