econml.sklearn_extensions.model_selection.WeightedStratifiedKFold
- class econml.sklearn_extensions.model_selection.WeightedStratifiedKFold(n_splits=3, n_trials=10, shuffle=False, random_state=None)
Bases: econml.sklearn_extensions.model_selection.WeightedKFold
Stratified K-Folds cross-validator for weighted data.
Provides train/test indices to split data into train/test sets. The dataset is split into k folds of roughly equal size and roughly equal total weight.
By default, sklearn.model_selection.StratifiedKFold is tried up to n_trials times to find a weight-balanced k-way split. If none of these trials yields such a split, the splitter falls back on a more rigorous weight stratification algorithm.
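The trial-then-fallback strategy can be illustrated with a pure-Python sketch. This is not econml's implementation: the round-robin dealing only approximates sklearn.model_selection.StratifiedKFold, and the balance test (the hypothetical `tol`), the greedy fallback and every function name below are assumptions chosen for illustration.

```python
import random
from collections import defaultdict


def _indices_by_class(y):
    """Group sample indices by class label (hypothetical helper)."""
    groups = defaultdict(list)
    for i, label in enumerate(y):
        groups[label].append(i)
    return list(groups.values())


def stratified_folds(y, n_splits, rng):
    """One 'trial': deal each class's shuffled samples round-robin over the folds."""
    folds = [0] * len(y)
    offset = 0  # carry the position so leftover samples spread across folds
    for indices in _indices_by_class(y):
        rng.shuffle(indices)
        for pos, i in enumerate(indices):
            folds[i] = (offset + pos) % n_splits
        offset += len(indices)
    return folds


def fold_weights(folds, w, n_splits):
    """Total sample weight assigned to each fold."""
    totals = [0.0] * n_splits
    for f, wi in zip(folds, w):
        totals[f] += wi
    return totals


def is_balanced(totals, tol):
    # Hypothetical criterion: every fold within tol (relative) of the mean weight.
    target = sum(totals) / len(totals)
    return max(abs(t - target) for t in totals) <= tol * target


def greedy_folds(y, w, n_splits):
    """Hypothetical fallback: heaviest samples first, each to the lightest fold
    among those holding the fewest samples of its class, so per-class counts
    across folds differ by at most one."""
    folds = [0] * len(y)
    totals = [0.0] * n_splits
    for indices in _indices_by_class(y):
        counts = [0] * n_splits
        for i in sorted(indices, key=lambda j: w[j], reverse=True):
            f = min(range(n_splits), key=lambda k: (counts[k], totals[k]))
            folds[i] = f
            counts[f] += 1
            totals[f] += w[i]
    return folds


def weighted_stratified_folds(y, w, n_splits=3, n_trials=10, tol=0.1, seed=None):
    """Try n_trials stratified assignments, then fall back to the greedy one."""
    rng = random.Random(seed)
    for _ in range(n_trials):
        folds = stratified_folds(y, n_splits, rng)
        if is_balanced(fold_weights(folds, w, n_splits), tol):
            return folds
    return greedy_folds(y, w, n_splits)
```

With uniform weights the first trial is already balanced. With one dominant weight (for example `[100, 1, 1, 1, 1, 1]`) no trial can pass, so the fallback runs; it still cannot split a single heavy sample, which is why any practical balance criterion has to tolerate some imbalance.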
- Parameters
n_splits (int, default 3) – Number of folds. Must be at least 2.
n_trials (int, default 10) – Number of times to try sklearn.model_selection.StratifiedKFold before falling back to another weight stratification algorithm.
shuffle (bool, default False) – Whether to shuffle the data before splitting it into folds.
random_state (int, RandomState instance, or None, default None) – If int, random_state is the seed used by the random number generator; if a RandomState instance, random_state is the random number generator; if None, the random number generator is the RandomState instance used by np.random. Used when shuffle == True.
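The three accepted forms of random_state can be illustrated with a small resolver. It mirrors the behavior of sklearn.utils.check_random_state; the name resolve_random_state is hypothetical and not part of econml, and `np.random.mtrand._rand` is a private NumPy attribute (the same one scikit-learn relies on for the global generator).

```python
import numbers

import numpy as np


def resolve_random_state(random_state):
    """Hypothetical helper: map the documented random_state options to a RandomState."""
    if random_state is None:
        # The global RandomState instance behind the np.random functions.
        return np.random.mtrand._rand
    if isinstance(random_state, numbers.Integral):
        # An int seeds a fresh, independent generator.
        return np.random.RandomState(random_state)
    if isinstance(random_state, np.random.RandomState):
        # An existing instance is used as-is (and its state advances).
        return random_state
    raise ValueError(f"{random_state!r} cannot be used to seed a RandomState")
```

Passing an int makes shuffled splits reproducible across runs, while passing None ties them to the global NumPy seed set with np.random.seed.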
- __init__(n_splits=3, n_trials=10, shuffle=False, random_state=None)
Methods
__init__([n_splits, n_trials, shuffle, ...])
get_n_splits(X, y[, groups]) – Return the number of splitting iterations in the cross-validator.
split(X, y[, sample_weight]) – Generate indices to split data into training and test set.
- get_n_splits(X, y, groups=None)
Return the number of splitting iterations in the cross-validator.
- Parameters
X (object) – Always ignored, exists for compatibility.
y (object) – Always ignored, exists for compatibility.
groups (object) – Always ignored, exists for compatibility.
- Returns
n_splits – Returns the number of splitting iterations in the cross-validator.
- Return type
int
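Because all three arguments are ignored, get_n_splits simply reports the n_splits value given to the constructor. A minimal sketch of that contract follows; the class name is hypothetical, and defaulting X and y to None is a simplification (the documented method takes them positionally).

```python
class HypotheticalSplitter:
    """Minimal stand-in showing the get_n_splits contract; not econml code."""

    def __init__(self, n_splits=3):
        # Mirrors the documented constraint on n_splits.
        if n_splits < 2:
            raise ValueError("n_splits must be at least 2")
        self.n_splits = n_splits

    def get_n_splits(self, X=None, y=None, groups=None):
        # X, y and groups are accepted only so the signature matches other
        # cross-validators; the answer depends on n_splits alone.
        return self.n_splits


# Usage: preallocate one score slot per fold.
cv = HypotheticalSplitter(n_splits=5)
scores = [0.0] * cv.get_n_splits([[1.0], [2.0]], [0, 1])
```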
- split(X, y, sample_weight=None)
Generate indices to split data into training and test set.
- Parameters
X (array_like, shape (n_samples, n_features)) – Training data, where n_samples is the number of samples and n_features is the number of features.
y (array_like, shape (n_samples,)) – The target variable for supervised learning problems.
sample_weight (array_like, shape (n_samples,)) – Weights associated with the training data.
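split follows the usual scikit-learn cross-validator contract: it yields one (train, test) pair of index arrays per fold, and the test sets together cover every sample exactly once. The sketch below shows only that contract, starting from a precomputed fold label per sample. The helper name and the data are hypothetical, and it yields Python lists where scikit-learn splitters yield NumPy arrays.

```python
def split_from_folds(folds, n_splits):
    """Yield one (train, test) pair per fold from a per-sample fold label."""
    for k in range(n_splits):
        test = [i for i, f in enumerate(folds) if f == k]
        train = [i for i, f in enumerate(folds) if f != k]
        yield train, test


# Usage: check how much of the total weight each test fold receives.
folds = [0, 1, 2, 0, 1, 2]  # hypothetical fold label for each of 6 samples
sample_weight = [2.0, 1.0, 1.0, 1.0, 2.0, 2.0]
test_weights = [sum(sample_weight[i] for i in test)
                for _, test in split_from_folds(folds, 3)]
```

Here each test fold carries a total weight of 3.0, which is the property a weighted stratified splitter aims for.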