econml.sklearn_extensions.linear_model.StatsModelsLinearRegression
- class econml.sklearn_extensions.linear_model.StatsModelsLinearRegression(fit_intercept=True, cov_type='HC0', *, enable_federation=False)
Bases: _StatsModelsWrapper
Class which mimics weighted linear regression from the statsmodels package.
However, unlike statsmodels WLS, this class also supports sample variances in addition to sample weights, which enables more accurate inference when working with summarized data.
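To see why sample variances matter for summarized data, the following is a minimal NumPy sketch and does not call econml. It assumes raw observations that share a small number of distinct covariate rows, and checks two standard identities: weighted least squares on the group means (weighted by group size) recovers the raw-data coefficients, and the raw residual sum of squares can be rebuilt from the means only once the within-group variances are added back. All variable names are illustrative, and the variance here is the population variance (`ddof=0`); which convention the library expects is not stated in this page.

```python
import numpy as np

rng = np.random.default_rng(0)

# Raw data: four distinct covariate rows, each observed several times.
x_levels = np.array([[0.0], [1.0], [2.0], [3.0]])
counts = np.array([5, 8, 3, 6])
X_raw = np.repeat(x_levels, counts, axis=0)
y_raw = 1.5 + 2.0 * X_raw[:, 0] + rng.normal(size=X_raw.shape[0])


def add_const(X):
    """Prepend a column of ones so the first coefficient is the intercept."""
    return np.column_stack([np.ones(len(X)), X])


# Ordinary least squares on the raw observations.
beta_raw, *_ = np.linalg.lstsq(add_const(X_raw), y_raw, rcond=None)
rss_raw = np.sum((y_raw - add_const(X_raw) @ beta_raw) ** 2)

# Summarize each group by its mean and (population) variance.
groups = np.split(y_raw, np.cumsum(counts)[:-1])
y_mean = np.array([g.mean() for g in groups])
y_var = np.array([g.var() for g in groups])

# Weighted least squares on the means: scale each row by sqrt(group size).
w = np.sqrt(counts)
A = add_const(x_levels)
beta_sum, *_ = np.linalg.lstsq(A * w[:, None], y_mean * w, rcond=None)

# Residual sum of squares rebuilt from the summaries alone.
rss_sum = np.sum(counts * (y_mean - A @ beta_sum) ** 2) + np.sum(counts * y_var)
```

Without the `y_var` term, the residual sum of squares computed from the means alone would understate the noise in the original observations, which is the information the sample variances restore.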
- Parameters:
fit_intercept (bool, default True) – Whether to fit an intercept in this model
cov_type (string, default “HC0”) – The covariance approach to use. Supported values are “HC0”, “HC1”, and “nonrobust”.
enable_federation (bool, default False) – Whether to enable federation (aggregating this model’s results with other models in a distributed setting). This requires additional memory proportional to the number of columns in X to the fourth power.
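As a rough guide to what the three cov_type values usually denote, the sketch below computes the textbook versions of each covariance estimator for an unweighted regression using NumPy only. It is an illustration under that assumption and not the library's implementation, which also has to account for sample weights and sample variances; the function name `ols_covariances` is hypothetical.

```python
import numpy as np


def ols_covariances(X, y):
    """Textbook OLS coefficient covariances for an (n, k) design matrix X."""
    n, k = X.shape
    XtX_inv = np.linalg.inv(X.T @ X)
    beta = XtX_inv @ X.T @ y
    resid = y - X @ beta

    # "nonrobust": homoskedastic estimate, sigma^2 * (X'X)^-1.
    nonrobust = XtX_inv * (resid @ resid) / (n - k)

    # "HC0": sandwich estimate with squared residuals in the middle.
    meat = X.T @ (X * resid[:, None] ** 2)
    hc0 = XtX_inv @ meat @ XtX_inv

    # "HC1": HC0 with a degrees-of-freedom correction.
    hc1 = hc0 * n / (n - k)
    return beta, {"nonrobust": nonrobust, "HC0": hc0, "HC1": hc1}


rng = np.random.default_rng(1)
X_demo = np.column_stack([np.ones(50), rng.normal(size=(50, 2))])
y_demo = X_demo @ np.array([1.0, 2.0, -1.0]) + rng.normal(size=50)
beta_demo, covs = ols_covariances(X_demo, y_demo)
```

The square roots of the diagonal of each matrix are the corresponding coefficient standard errors.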
Methods
__init__([fit_intercept, cov_type, ...])
aggregate(models) – Aggregate multiple models into one.
coef__interval([alpha]) – Get a confidence interval bounding the fitted coefficients.
fit(X, y[, sample_weight, freq_weight, ...]) – Fits the model.
get_metadata_routing() – Get metadata routing of this object.
get_params([deep]) – Get parameters for this estimator.
intercept__interval([alpha]) – Get a confidence interval bounding the intercept(s) (or 0 if no intercept was fit).
predict(X) – Predicts the output given an array of instances.
predict_interval(X[, alpha]) – Get a confidence interval bounding the prediction.
prediction_stderr(X) – Get the standard error of the predictions.
set_fit_request(*[, freq_weight, ...]) – Request metadata passed to the fit method.
set_params(**params) – Set the parameters of this estimator.
Attributes
coef_ – Get the model's coefficients on the covariates.
coef_stderr_ – Get the standard error of the fitted coefficients.
intercept_ – Get the intercept(s) (or 0 if no intercept was fit).
intercept_stderr_ – Get the standard error of the intercept(s) (or 0 if no intercept was fit).
- static aggregate(models: List[StatsModelsLinearRegression])
Aggregate multiple models into one.
- Parameters:
models (list of StatsModelsLinearRegression) – The models to aggregate
- Returns:
agg_model – The aggregated model
- Return type:
StatsModelsLinearRegression
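The idea that separately fitted linear models can be combined without sharing raw data can be illustrated with pooled sufficient statistics. The NumPy sketch below is an assumption about the general technique, not a description of how aggregate is implemented; the helper names `sufficient_stats` and `solve_from_stats` are hypothetical.

```python
import numpy as np


def sufficient_stats(X, y):
    """Per-partition statistics that are enough to recover OLS coefficients."""
    return X.T @ X, X.T @ y


def solve_from_stats(stats):
    """Sum the per-partition statistics and solve the normal equations."""
    XtX = sum(s[0] for s in stats)
    Xty = sum(s[1] for s in stats)
    return np.linalg.solve(XtX, Xty)


rng = np.random.default_rng(2)
X_all = np.column_stack([np.ones(90), rng.normal(size=(90, 2))])
y_all = X_all @ np.array([0.5, 1.0, -2.0]) + rng.normal(size=90)

# Three partitions, as if each were held by a different party.
parts = [(X_all[i:i + 30], y_all[i:i + 30]) for i in range(0, 90, 30)]
beta_pooled_stats = solve_from_stats([sufficient_stats(X, y) for X, y in parts])

# Reference: a single fit on all of the data.
beta_full, *_ = np.linalg.lstsq(X_all, y_all, rcond=None)
```

Each partition only contributes matrices whose size depends on the number of columns, never on the number of rows.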
- coef__interval(alpha=0.05)
Get a confidence interval bounding the fitted coefficients.
- Parameters:
alpha (float, default 0.05) – The significance level. The interval is bounded by the alpha/2-quantile and the (1-alpha/2)-quantile of the parameter distribution, giving a (1-alpha) confidence interval.
- Returns:
coef__interval – The lower and upper bounds of the confidence interval of the coefficients
- Return type:
{tuple ((p, d) array, (p,d) array), tuple ((d,) array, (d,) array)}
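The role of alpha can be made concrete with a small standard-library sketch. It assumes the parameter distribution is approximated by a normal distribution centered at the estimate with the given standard error; that distributional choice is an assumption of this example, and `normal_interval` is a hypothetical helper rather than part of the class.

```python
from statistics import NormalDist


def normal_interval(estimate, stderr, alpha=0.05):
    """Return the alpha/2 and (1 - alpha/2) quantiles of N(estimate, stderr^2)."""
    z = NormalDist().inv_cdf(1 - alpha / 2)
    return estimate - z * stderr, estimate + z * stderr


# A coefficient of 2.0 with standard error 0.5 at the default alpha.
lower, upper = normal_interval(2.0, 0.5, alpha=0.05)

# A smaller alpha asks for higher confidence and so gives a wider interval.
lower_99, upper_99 = normal_interval(2.0, 0.5, alpha=0.01)
```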
- fit(X, y, sample_weight=None, freq_weight=None, sample_var=None)
Fits the model.
- Parameters:
X ((N, d) nd array_like) – covariates
y ({(N,), (N, p)} nd array_like) – output variable(s)
sample_weight ((N,) array_like or None) – Individual weights for each sample. If None, it assumes equal weight.
freq_weight ((N,) array_like of int or None) – Weight for the observation. Observation i is treated as the mean outcome of freq_weight[i] independent observations. When sample_var is not None, this should be provided.
sample_var ({(N,), (N, p)} nd array_like or None) – Variance of the outcome(s) of the original freq_weight[i] observations that were used to compute the mean outcome represented by observation i.
- Returns:
self
- Return type:
StatsModelsLinearRegression
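One way to prepare summarized inputs of the shape described above is sketched below with NumPy. The helper `summarize` is hypothetical and handles a single outcome only; it collapses rows with identical covariates into one row each, returning the group mean, the group size, and the within-group variance. The variance uses the population convention (`ddof=0`), which is an assumption, since this page does not state which convention is expected.

```python
import numpy as np


def summarize(X, y):
    """Collapse duplicate covariate rows into (X, mean, count, variance)."""
    X_unique, inverse, counts = np.unique(
        X, axis=0, return_inverse=True, return_counts=True
    )
    inverse = inverse.ravel()  # guard against shape differences across NumPy versions
    y_mean = np.bincount(inverse, weights=y, minlength=len(counts)) / counts
    sq_dev = (y - y_mean[inverse]) ** 2
    y_var = np.bincount(inverse, weights=sq_dev, minlength=len(counts)) / counts
    return X_unique, y_mean, counts, y_var


X_raw = np.array([[0, 1], [0, 1], [1, 0], [1, 0], [1, 0]])
y_raw = np.array([1.0, 3.0, 2.0, 4.0, 6.0])
X_sum, y_mean, freq, y_var = summarize(X_raw, y_raw)
```

Under these assumptions the four arrays line up with the X, y, freq_weight, and sample_var arguments of fit, in that order.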
- get_metadata_routing()
Get metadata routing of this object.
Please check User Guide on how the routing mechanism works.
- Returns:
routing – A MetadataRequest encapsulating routing information.
- Return type:
MetadataRequest
- get_params(deep=True)
Get parameters for this estimator.
- Parameters:
deep (bool, default=True) – If True, will return the parameters for this estimator and contained subobjects that are estimators.
- Returns:
params – Parameter names mapped to their values.
- Return type:
dict
- intercept__interval(alpha=0.05)
Get a confidence interval bounding the intercept(s) (or 0 if no intercept was fit).
- Parameters:
alpha (float, default 0.05) – The significance level. The interval is bounded by the alpha/2-quantile and the (1-alpha/2)-quantile of the parameter distribution, giving a (1-alpha) confidence interval.
- Returns:
intercept__interval – The lower and upper bounds of the confidence interval of the intercept(s)
- Return type:
- predict(X)
Predicts the output given an array of instances.
- Parameters:
X ((n, d) array_like) – The covariates on which to predict
- Returns:
predictions – The predicted mean outcomes
- Return type:
{(n,) array, (n,p) array}
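The two possible output shapes follow from the shapes of coef_ and intercept_ documented on this page. The NumPy sketch below assumes the usual linear form, predictions equal to X times the transposed coefficients plus the intercept; `linear_predict` is a hypothetical stand-in and not the method itself.

```python
import numpy as np


def linear_predict(X, coef, intercept):
    """Linear prediction for coef of shape (d,) or (p, d)."""
    X = np.asarray(X, dtype=float)
    coef = np.asarray(coef, dtype=float)
    return X @ coef.T + intercept


X_new = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # n = 3, d = 2

# Single outcome: coef is (d,), intercept is a float, result is (n,).
single = linear_predict(X_new, [1.0, -1.0], 0.5)

# Two outcomes: coef is (p, d), intercept is (p,), result is (n, p).
multi = linear_predict(X_new, [[1.0, -1.0], [0.0, 2.0]], np.array([0.5, 1.0]))
```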
- predict_interval(X, alpha=0.05)
Get a confidence interval bounding the prediction.
- Parameters:
X ((n, d) array_like) – The covariates on which to predict
alpha (float, default 0.05) – The significance level. The interval is bounded by the alpha/2-quantile and the (1-alpha/2)-quantile of the parameter distribution, giving a (1-alpha) confidence interval.
- Returns:
prediction_intervals – The lower and upper bounds of the confidence intervals of the predicted mean outcomes
- Return type:
{tuple ((n,) array, (n,) array), tuple ((n,p) array, (n,p) array)}
- prediction_stderr(X)
Get the standard error of the predictions.
- Parameters:
X ((n, d) array_like) – The covariates at which to predict
- Returns:
prediction_stderr – The standard error of each coordinate of the output at each point we predict
- Return type:
(n, p) array_like
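For a linear model, the standard error of a predicted mean is commonly obtained by propagating the parameter covariance through the covariates. The NumPy sketch below shows that calculation for a single outcome, assuming a known parameter covariance matrix whose first row and column belong to the intercept. Both the function `prediction_stderr_sketch` and that layout are assumptions made for illustration.

```python
import numpy as np


def prediction_stderr_sketch(X, param_cov, fit_intercept=True):
    """sqrt(x' V x) for each row x of X, with a leading 1 if an intercept is fit."""
    X = np.asarray(X, dtype=float)
    if fit_intercept:
        X = np.column_stack([np.ones(X.shape[0]), X])
    # Row-wise quadratic form without building the full (n, n) matrix.
    return np.sqrt(np.einsum("ni,ij,nj->n", X, param_cov, X))


# Intercept variance 0.04, slope variance 0.01, no covariance between them.
V = np.diag([0.04, 0.01])
stderr = prediction_stderr_sketch(np.array([[0.0], [2.0]]), V)
```

In this example the uncertainty grows as the covariate moves away from zero, because the slope's variance is multiplied by the squared covariate value.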
- set_fit_request(*, freq_weight: bool | None | str = '$UNCHANGED$', sample_var: bool | None | str = '$UNCHANGED$', sample_weight: bool | None | str = '$UNCHANGED$') StatsModelsLinearRegression
Request metadata passed to the fit method.
Note that this method is only relevant if enable_metadata_routing=True (see sklearn.set_config()). Please see User Guide on how the routing mechanism works.
The options for each parameter are:
True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.
False: metadata is not requested and the meta-estimator will not pass it to fit.
None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
Note
This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a Pipeline. Otherwise it has no effect.
- Parameters:
freq_weight (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for freq_weight parameter in fit.
sample_var (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for sample_var parameter in fit.
sample_weight (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for sample_weight parameter in fit.
- Returns:
self – The updated object.
- Return type:
StatsModelsLinearRegression
- set_params(**params)
Set the parameters of this estimator.
The method works on simple estimators as well as on nested objects (such as Pipeline). The latter have parameters of the form <component>__<parameter> so that it’s possible to update each component of a nested object.
- Parameters:
**params (dict) – Estimator parameters.
- Returns:
self – Estimator instance.
- Return type:
estimator instance
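The `<component>__<parameter>` naming can be seen in a short scikit-learn sketch. To keep the example independent of econml, scikit-learn's own LinearRegression stands in for the documented class; the step names "scale" and "reg" are arbitrary choices made for this example.

```python
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# A nested object: each step is addressed by the name given here.
pipe = Pipeline([("scale", StandardScaler()), ("reg", LinearRegression())])

# "<step name>__<parameter>" reaches into the named step.
pipe.set_params(reg__fit_intercept=False, scale__with_mean=False)

# get_params exposes the same nested names.
nested = pipe.get_params(deep=True)
```

set_params returns the estimator itself, so calls can be chained.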
- property coef_
Get the model’s coefficients on the covariates.
- Returns:
coef_ – The coefficients of the variables in the linear regression. If the label y was p-dimensional, then the result is a matrix of coefficients whose j-th row contains the coefficients corresponding to the j-th coordinate of the label.
- Return type:
{(d,), (p, d)} nd array_like
- property coef_stderr_
Get the standard error of the fitted coefficients.
- Returns:
coef_stderr_ – The standard error of the coefficients
- Return type:
{(d,), (p, d)} nd array_like
- property intercept_
Get the intercept(s) (or 0 if no intercept was fit).
- Returns:
intercept_ – The intercept of the linear regression. If the label y was p-dimensional, then the result is a vector whose j-th entry contains the intercept corresponding to the j-th coordinate of the label.
- Return type:
float or (p,) nd array_like