econml.sklearn_extensions.model_selection.WeightedStratifiedKFold

class econml.sklearn_extensions.model_selection.WeightedStratifiedKFold(n_splits=3, n_trials=10, shuffle=False, random_state=None)

Bases: econml.sklearn_extensions.model_selection.WeightedKFold

Stratified K-Folds cross-validator for weighted data.

Provides train/test indices to split data into train/test sets. Splits the dataset into k folds of roughly equal size and roughly equal total weight.
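Both balance criteria can be measured directly on a partition. The following minimal sketch uses a hypothetical helper, `fold_balance`, which is not part of econml. It reports the size of each test fold and that fold's share of the total sample weight. In a perfectly weight-balanced k-way split, every fold's share is 1/k.

```python
import numpy as np


def fold_balance(test_folds, sample_weight):
    """Hypothetical helper: return (sizes, weight_fractions) for a list of
    test-index arrays that together partition the samples."""
    w = np.asarray(sample_weight, dtype=float)
    sizes = np.array([len(fold) for fold in test_folds])
    # Share of the total weight that lands in each fold.
    fractions = np.array([w[fold].sum() for fold in test_folds]) / w.sum()
    return sizes, fractions


w = np.array([4.0, 1.0, 1.0, 1.0, 1.0, 4.0])
# Two 3-sample folds with equal weight (6 and 6): fractions 0.5 and 0.5.
balanced = [np.array([0, 1, 2]), np.array([3, 4, 5])]
# Two 3-sample folds with unequal weight (9 and 3): fractions 0.75 and 0.25.
skewed = [np.array([0, 1, 5]), np.array([2, 3, 4])]
print(fold_balance(balanced, w))
print(fold_balance(skewed, w))
```

The example shows that equal fold sizes do not imply equal fold weights.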

By default, the cross-validator tries sklearn.model_selection.StratifiedKFold up to n_trials times to find a weight-balanced k-way split. If none of these trials yields such a split, it falls back on a more rigorous weight-stratification algorithm.
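The following sketch approximates the two-phase strategy. It is not econml's implementation. The function name `weight_balanced_folds`, the acceptance rule (every fold must hold within `tol` of 1/k of the total weight), the per-trial reseeding, and the fallback heuristic (within each class, sort by weight and deal samples to folds in snake order) are all assumptions made for illustration.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold


def weight_balanced_folds(y, sample_weight, n_splits=3, n_trials=10, tol=0.05, seed=0):
    """Hypothetical sketch: return a list of test-index arrays.

    Assumed acceptance rule: every fold's weight share is within `tol` of
    1/n_splits. Assumed fallback: snake-order dealing by weight per class.
    """
    y = np.asarray(y)
    w = np.asarray(sample_weight, dtype=float)
    target = 1.0 / n_splits

    # Phase 1: retry StratifiedKFold with a different seed on each trial.
    for trial in range(n_trials):
        skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed + trial)
        folds = [test for _, test in skf.split(np.zeros((len(y), 1)), y)]
        fracs = np.array([w[f].sum() for f in folds]) / w.sum()
        if np.all(np.abs(fracs - target) <= tol):
            return folds

    # Phase 2 (assumed fallback): within each class, visit samples from
    # heaviest to lightest and deal them to folds 0..k-1, k-1..0, 0..k-1, ...
    # This spreads heavy samples across folds and keeps class counts even.
    buckets = [[] for _ in range(n_splits)]
    pos = 0  # snake position, carried across classes
    for cls in np.unique(y):
        idx = np.flatnonzero(y == cls)
        for i in idx[np.argsort(-w[idx], kind="stable")]:
            rnd, off = divmod(pos, n_splits)
            buckets[off if rnd % 2 == 0 else n_splits - 1 - off].append(i)
            pos += 1
    return [np.array(sorted(b)) for b in buckets]


y = np.repeat([0, 1], 30)
print([len(f) for f in weight_balanced_folds(y, np.ones(60))])

w = np.ones(60)
w[0] = 1000.0  # one dominant sample: no fold can hold ~1/3 of the weight
print([len(f) for f in weight_balanced_folds(y, w)])
```

In the second call, no StratifiedKFold trial can satisfy the acceptance rule, so the sketch reaches the fallback. The fallback still returns a valid stratified partition, but a single dominant weight cannot be balanced by any assignment.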

Parameters
  • n_splits (int, default 3) – Number of folds. Must be at least 2.

  • n_trials (int, default 10) – Number of times to try sklearn.model_selection.StratifiedKFold before falling back to another weight stratification algorithm.

  • shuffle (bool, default False) – Whether to shuffle the data before splitting it into folds.

  • random_state (int, RandomState instance, or None, default None) – If int, random_state is the seed used by the random number generator; if RandomState instance, random_state is the random number generator; if None, the random number generator is the RandomState instance used by np.random. Used when shuffle is True.
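These three random_state cases follow scikit-learn's convention, which `sklearn.utils.check_random_state` implements. The snippet below shows each case with that function. Whether WeightedStratifiedKFold calls this function internally is an assumption here. The snippet only illustrates the convention the parameter description follows.

```python
import numpy as np
from sklearn.utils import check_random_state

# int: a new RandomState seeded with that int, so draws are reproducible.
a = check_random_state(42)
b = check_random_state(42)
print(a.randint(1000) == b.randint(1000))

# RandomState instance: returned unchanged and used as the generator.
rs = np.random.RandomState(0)
print(check_random_state(rs) is rs)

# None: the global RandomState singleton that backs np.random.
print(check_random_state(None) is np.random.mtrand._rand)
```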

__init__(n_splits=3, n_trials=10, shuffle=False, random_state=None)

Methods

__init__([n_splits, n_trials, shuffle, ...])

get_n_splits(X, y[, groups])

Return the number of splitting iterations in the cross-validator.

split(X, y[, sample_weight])

Generate indices to split data into training and test sets.

get_n_splits(X, y, groups=None)

Return the number of splitting iterations in the cross-validator.

Parameters
  • X (object) – Always ignored, exists for compatibility.

  • y (object) – Always ignored, exists for compatibility.

  • groups (object) – Always ignored, exists for compatibility.

Returns

n_splits – The number of splitting iterations in the cross-validator.

Return type

int

split(X, y, sample_weight=None)

Generate indices to split data into training and test sets.

Parameters
  • X (array_like, shape (n_samples, n_features)) – Training data, where n_samples is the number of samples and n_features is the number of features.

  • y (array_like, shape (n_samples,)) – The target variable for supervised learning problems. Stratification is done based on the y labels.

  • sample_weight (array_like, shape (n_samples,)) – Weights associated with the training data.
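A typical use of split is a cross-validation loop. In each fold, the same train/test indices slice X, y, and the weights. The sketch below uses sklearn.model_selection.StratifiedKFold as a stand-in so that it runs without econml. Note that the stand-in ignores weights when choosing folds. The synthetic data and the choice of LogisticRegression are illustrative assumptions.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold

# Synthetic, illustrative data.
rng = np.random.default_rng(0)
X = rng.normal(size=(90, 2))
y = (X[:, 0] + 0.5 * rng.normal(size=90) > 0).astype(int)
w = rng.uniform(0.5, 2.0, size=90)

# Stand-in splitter. With econml installed, the call documented above would be
# WeightedStratifiedKFold(n_splits=3).split(X, y, sample_weight=w).
cv = StratifiedKFold(n_splits=3, shuffle=True, random_state=0)

scores = []
for train, test in cv.split(X, y):
    model = LogisticRegression()
    # Slice the weights with the same indices as X and y.
    model.fit(X[train], y[train], sample_weight=w[train])
    scores.append(model.score(X[test], y[test], sample_weight=w[test]))
print(np.round(scores, 3))
```

Passing the weights to both fit and score keeps training and evaluation consistent with the weighting used to balance the folds.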