econml.sklearn_extensions.model_selection.WeightedStratifiedKFold

class econml.sklearn_extensions.model_selection.WeightedStratifiedKFold(n_splits=3, n_trials=10, shuffle=False, random_state=None)

Bases: econml.sklearn_extensions.model_selection.WeightedKFold

Stratified K-Folds cross-validator for weighted data.

Provides train/test indices to split data into train/test sets. The dataset is split into k folds of roughly equal size and roughly equal total weight, with the class proportions of y approximately preserved in each fold.

By default, the splitter tries sklearn.model_selection.StratifiedKFold up to n_trials times to find a k-way split whose folds are balanced in total weight. If none of the trials produces such a split, it falls back to a more rigorous weight-stratification algorithm.
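The trial-and-check idea can be sketched as follows. This is a minimal illustration under stated assumptions, not EconML's implementation: the function name weight_balanced_stratified_split, the relative tolerance tol used to decide whether a split counts as weight-balanced, and the choice to reshuffle with a fresh seed on every trial are all hypothetical. The fallback algorithm is not reproduced; the sketch simply returns None when every trial fails.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold


def weight_balanced_stratified_split(X, y, sample_weight, n_splits=3,
                                     n_trials=10, tol=0.1, random_state=None):
    # Hypothetical helper illustrating the idea; not EconML's code.
    sample_weight = np.asarray(sample_weight, dtype=float)
    target = sample_weight.sum() / n_splits  # ideal total weight per test fold
    rng = np.random.RandomState(random_state)
    for _ in range(n_trials):
        # Each trial reshuffles StratifiedKFold with a fresh seed drawn from rng.
        seed = int(rng.randint(np.iinfo(np.int32).max))
        skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
        folds = list(skf.split(X, y))
        fold_weights = np.array([sample_weight[test].sum() for _, test in folds])
        # Accept the split if every test fold is within tol (relative) of target.
        if np.all(np.abs(fold_weights - target) <= tol * target):
            return folds
    # Every trial failed; a real splitter would now run its fallback
    # weight-stratification algorithm (not sketched here).
    return None


rng = np.random.RandomState(0)
X = rng.normal(size=(300, 2))
y = rng.randint(0, 2, size=300)
w = rng.uniform(0.5, 1.5, size=300)

folds = weight_balanced_stratified_split(X, y, w, random_state=0)
print([round(float(w[test].sum()), 1) for _, test in folds])
```

Reshuffling on each trial matters in this sketch: an unshuffled StratifiedKFold is deterministic, so repeating it would return the same split every time.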

Parameters
  • n_splits (int, default=3) – Number of folds. Must be at least 2.

  • n_trials (int, default=10) – Number of times to try sklearn.model_selection.StratifiedKFold before falling back to another weight stratification algorithm.

  • shuffle (boolean, default=False) – Whether to shuffle the data before splitting it into folds.

  • random_state (int, RandomState instance or None, optional, default=None) – If int, random_state is the seed used by the random number generator; if RandomState instance, random_state is the random number generator; if None, the random number generator is the RandomState instance used by np.random. Used when shuffle=True.
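The three accepted forms of random_state follow the scikit-learn convention, which sklearn.utils.check_random_state implements. The sketch below shows how each form resolves. Whether WeightedStratifiedKFold normalizes its argument through this exact helper is an assumption, and np.random.mtrand._rand is a private NumPy attribute used here only to identify the global generator.

```python
import numpy as np
from sklearn.utils import check_random_state

# int: a new RandomState seeded with that value, so equal seeds give equal draws.
perm_a = check_random_state(42).permutation(10)
perm_b = check_random_state(42).permutation(10)
print(np.array_equal(perm_a, perm_b))

# RandomState instance: returned unchanged and used directly.
rs = np.random.RandomState(0)
print(check_random_state(rs) is rs)

# None: the global RandomState singleton behind the np.random.* functions.
print(check_random_state(None) is np.random.mtrand._rand)
```

Passing an int is the usual way to make shuffled splits reproducible across runs; passing None ties the result to NumPy's global random state.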

__init__(n_splits=3, n_trials=10, shuffle=False, random_state=None)

Initialize self. See help(type(self)) for accurate signature.

Methods

__init__([n_splits, n_trials, shuffle, …]) – Initialize self.

split(X, y[, sample_weight]) – Generate indices to split data into training and test sets.

split(X, y, sample_weight=None)

Generate indices to split data into training and test sets.

Parameters
  • X (array-like, shape (n_samples, n_features)) – Training data, where n_samples is the number of samples and n_features is the number of features.

  • y (array-like, shape (n_samples,)) – The target variable for supervised learning problems.

  • sample_weight (array-like, shape (n_samples,), optional, default=None) – Weights associated with the training data.
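One way to inspect the folds that split produces is to sum sample_weight over each test fold and check its label mix. The helper describe_folds below is hypothetical. It is demonstrated with scikit-learn's StratifiedKFold as a stand-in so the snippet runs without EconML, and it assumes that WeightedStratifiedKFold.split likewise yields (train, test) pairs of integer index arrays, following the scikit-learn splitter convention.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold


def describe_folds(splits, y, sample_weight):
    """Hypothetical helper: for each (train, test) pair, return the test fold's
    share of the total weight and its fraction of samples with y == 1."""
    y = np.asarray(y)
    w = np.asarray(sample_weight, dtype=float)
    rows = []
    for train, test in splits:
        # Within a single split, train and test indices must be disjoint.
        if np.intersect1d(train, test).size:
            raise ValueError("train and test indices overlap")
        rows.append((w[test].sum() / w.sum(), y[test].mean()))
    return rows


rng = np.random.RandomState(0)
X = rng.normal(size=(120, 3))
y = np.repeat([0, 1], 60)             # balanced binary labels
w = rng.uniform(0.5, 1.5, size=120)   # illustrative positive weights

# Stand-in splitter. With EconML installed, one would pass
# WeightedStratifiedKFold(n_splits=3).split(X, y, sample_weight=w) instead.
for share, pos_rate in describe_folds(StratifiedKFold(n_splits=3).split(X, y), y, w):
    print(f"weight share={share:.3f}  positive rate={pos_rate:.2f}")
```

For a weight-balanced stratified splitter, each weight share should sit close to 1 / n_splits and each positive rate close to the overall rate of y.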