Skip to content

Commit 6a3c5de

Browse files
author
Guillaume Lemaitre
committed
Update the ensemble method
1 parent fdbc6e4 commit 6a3c5de

27 files changed

+57
-294
lines changed

imblearn/ensemble/balance_cascade.py

Lines changed: 28 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,16 @@
44
import numpy as np
55

66
from sklearn.utils import check_X_y
7+
from sklearn.utils import check_random_state
78

8-
from .ensemble_sampler import EnsembleSampler
9+
from ..base import SamplerMixin
910

1011

11-
class BalanceCascade(EnsembleSampler):
12+
ESTIMATOR_KIND = ('knn', 'decision-tree', 'random-forest', 'adaboost',
13+
'gradient-boosting', 'linear-svm')
14+
15+
16+
class BalanceCascade(SamplerMixin):
1217
"""Create an ensemble of balanced sets by iteratively under-sampling the
1318
imbalanced dataset using an estimator.
1419
@@ -27,8 +32,11 @@ class BalanceCascade(EnsembleSampler):
2732
Whether or not to return the indices of the samples randomly
2833
selected from the majority class.
2934
30-
random_state : int or None, optional (default=None)
31-
Seed for random number generation.
35+
random_state : int, RandomState instance or None, optional (default=None)
36+
If int, random_state is the seed used by the random number generator;
37+
If RandomState instance, random_state is the random number generator;
38+
If None, the random number generator is the RandomState instance used
39+
by np.random.
3240
3341
verbose : bool, optional (default=True)
3442
Whether or not to print information about the processing.
@@ -52,15 +60,6 @@ class BalanceCascade(EnsembleSampler):
5260
5361
Attributes
5462
----------
55-
ratio : str or float
56-
If 'auto', the ratio will be defined automatically to balance
57-
the dataset. Otherwise, the ratio is defined as the number
58-
of samples in the minority class over the number of samples
59-
in the majority class.
60-
61-
random_state : int or None
62-
Seed for random number generation.
63-
6463
min_c_ : str or int
6564
The identifier of the minority class.
6665
@@ -91,89 +90,16 @@ class BalanceCascade(EnsembleSampler):
9190
def __init__(self, ratio='auto', return_indices=False, random_state=None,
9291
verbose=True, n_max_subset=None, classifier='knn',
9392
bootstrap=True, **kwargs):
94-
"""Initialise the balance cascade object.
95-
96-
Parameters
97-
----------
98-
ratio : str or float, optional (default='auto')
99-
If 'auto', the ratio will be defined automatically to balance
100-
the dataset. Otherwise, the ratio is defined as the number
101-
of samples in the minority class over the number of samples
102-
in the majority class.
103-
104-
return_indices : bool, optional (default=True)
105-
Whether or not to return the indices of the samples randomly
106-
selected from the majority class.
107-
108-
random_state : int or None, optional (default=None)
109-
Seed for random number generation.
110-
111-
verbose : bool, optional (default=True)
112-
Whether or not to print information about the processing.
113-
114-
n_max_subset : int or None, optional (default=None)
115-
Maximum number of subsets to generate. By default, all data from
116-
the training will be selected that could lead to a large number of
117-
subsets. We can probably reduce this number empirically.
118-
119-
classifier : str, optional (default='knn')
120-
The classifier that will be selected to confront the prediction
121-
with the real labels. The choices are the following: 'knn',
122-
'decision-tree', 'random-forest', 'adaboost', 'gradient-boosting'
123-
and 'linear-svm'.
124-
125-
bootstrap : bool, optional (default=True)
126-
Whether to bootstrap the data before each iteration.
127-
128-
**kwargs : keywords
129-
The parameters associated with the classifier provided.
130-
131-
Returns
132-
-------
133-
None
134-
135-
"""
13693
super(BalanceCascade, self).__init__(ratio=ratio,
137-
return_indices=return_indices,
138-
verbose=verbose,
139-
random_state=random_state)
140-
# Define the classifier to use
141-
list_classifier = ('knn', 'decision-tree', 'random-forest', 'adaboost',
142-
'gradient-boosting', 'linear-svm')
143-
if classifier in list_classifier:
144-
self.classifier = classifier
145-
else:
146-
raise NotImplementedError
94+
verbose=verbose)
95+
self.return_indices = return_indices
96+
self.random_state = random_state
97+
self.classifier = classifier
14798
self.n_max_subset = n_max_subset
14899
self.bootstrap = bootstrap
149100
self.kwargs = kwargs
150101

151-
def fit(self, X, y):
152-
"""Find the classes statistics before to perform sampling.
153-
154-
Parameters
155-
----------
156-
X : ndarray, shape (n_samples, n_features)
157-
Matrix containing the data which have to be sampled.
158-
159-
y : ndarray, shape (n_samples, )
160-
Corresponding label for each sample in X.
161-
162-
Returns
163-
-------
164-
self : object,
165-
Return self.
166-
167-
"""
168-
# Check the consistency of X and y
169-
X, y = check_X_y(X, y)
170-
171-
# Call the parent function
172-
super(BalanceCascade, self).fit(X, y)
173-
174-
return self
175-
176-
def sample(self, X, y):
102+
def _sample(self, X, y):
177103
"""Resample the dataset.
178104
179105
Parameters
@@ -197,10 +123,11 @@ def sample(self, X, y):
197123
containing which samples have been selected.
198124
199125
"""
200-
# Check the consistency of X and y
201-
X, y = check_X_y(X, y)
202126

203-
super(BalanceCascade, self).sample(X, y)
127+
if self.classifier not in ESTIMATOR_KIND:
128+
raise NotImplementedError
129+
130+
random_state = check_random_state(self.random_state)
204131

205132
# Define the classifier to use
206133
if self.classifier == 'knn':
@@ -210,25 +137,26 @@ def sample(self, X, y):
210137
elif self.classifier == 'decision-tree':
211138
from sklearn.tree import DecisionTreeClassifier
212139
classifier = DecisionTreeClassifier(
140+
random_state=random_state,
213141
**self.kwargs)
214142
elif self.classifier == 'random-forest':
215143
from sklearn.ensemble import RandomForestClassifier
216144
classifier = RandomForestClassifier(
217-
random_state=self.random_state,
145+
random_state=random_state,
218146
**self.kwargs)
219147
elif self.classifier == 'adaboost':
220148
from sklearn.ensemble import AdaBoostClassifier
221149
classifier = AdaBoostClassifier(
222-
random_state=self.random_state,
150+
random_state=random_state,
223151
**self.kwargs)
224152
elif self.classifier == 'gradient-boosting':
225153
from sklearn.ensemble import GradientBoostingClassifier
226154
classifier = GradientBoostingClassifier(
227-
random_state=self.random_state,
155+
random_state=random_state,
228156
**self.kwargs)
229157
elif self.classifier == 'linear-svm':
230158
from sklearn.svm import LinearSVC
231-
classifier = LinearSVC(random_state=self.random_state,
159+
classifier = LinearSVC(random_state=random_state,
232160
**self.kwargs)
233161
else:
234162
raise NotImplementedError
@@ -267,8 +195,7 @@ def sample(self, X, y):
267195
# Generate an appropriate number of index to extract
268196
# from the majority class depending of the false classification
269197
# rate of the previous iteration
270-
np.random.seed(self.random_state)
271-
idx_sel_from_maj = np.random.choice(np.nonzero(b_sel_N)[0],
198+
idx_sel_from_maj = random_state.choice(np.nonzero(b_sel_N)[0],
272199
size=num_samples,
273200
replace=False)
274201
idx_sel_from_maj = np.concatenate((idx_mis_class,
@@ -296,7 +223,7 @@ def sample(self, X, y):
296223
self.bootstrap):
297224
# Apply a bootstrap on x_data
298225
curr_sample_weight = np.ones((y_data.size,), dtype=np.float64)
299-
indices = np.random.randint(0, y_data.size, y_data.size)
226+
indices = random_state.randint(0, y_data.size, y_data.size)
300227
sample_counts = np.bincount(indices, minlength=y_data.size)
301228
curr_sample_weight *= sample_counts
302229

imblearn/ensemble/easy_ensemble.py

Lines changed: 11 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55

66
from sklearn.utils import check_X_y
77

8-
from .ensemble_sampler import EnsembleSampler
8+
from ..base import SamplerMixin
99
from ..under_sampling import RandomUnderSampler
1010

1111

12-
class EasyEnsemble(EnsembleSampler):
12+
class EasyEnsemble(SamplerMixin):
1313
"""Create an ensemble sets by iteratively applying random under-sampling.
1414
1515
This method iteratively select a random subset and make an ensemble of the
@@ -27,8 +27,11 @@ class EasyEnsemble(EnsembleSampler):
2727
Whether or not to return the indices of the samples randomly
2828
selected from the majority class.
2929
30-
random_state : int or None, optional (default=None)
31-
Seed for random number generation.
30+
random_state : int, RandomState instance or None, optional (default=None)
31+
If int, random_state is the seed used by the random number generator;
32+
If RandomState instance, random_state is the random number generator;
33+
If None, the random number generator is the RandomState instance used
34+
by np.random.
3235
3336
verbose : bool, optional (default=True)
3437
Whether or not to print information about the processing.
@@ -41,15 +44,6 @@ class EasyEnsemble(EnsembleSampler):
4144
4245
Attributes
4346
----------
44-
ratio : str or float
45-
If 'auto', the ratio will be defined automatically to balance
46-
the dataset. Otherwise, the ratio is defined as the number
47-
of samples in the minority class over the number of samples
48-
in the majority class.
49-
50-
random_state : int or None
51-
Seed for random number generation.
52-
5347
min_c_ : str or int
5448
The identifier of the minority class.
5549
@@ -78,70 +72,14 @@ class EasyEnsemble(EnsembleSampler):
7872

7973
def __init__(self, ratio='auto', return_indices=False, verbose=True,
8074
random_state=None, replacement=False, n_subsets=10):
81-
"""Initialise the easy ensenble object.
82-
83-
Parameters
84-
----------
85-
ratio : str or float, optional (default='auto')
86-
If 'auto', the ratio will be defined automatically to balance
87-
the dataset. Otherwise, the ratio is defined as the number
88-
of samples in the minority class over the number of samples
89-
in the majority class.
90-
91-
return_indices : bool, optional (default=True)
92-
Whether or not to return the indices of the samples randomly
93-
selected from the majority class.
94-
95-
random_state : int or None, optional (default=None)
96-
Seed for random number generation.
97-
98-
verbose : bool, optional (default=True)
99-
Whether or not to print information about the processing.
100-
101-
replacement : bool, optional (default=False)
102-
Whether or not to sample randomly with replacement or not.
103-
104-
n_subsets : int, optional (default=10)
105-
Number of subsets to generate.
106-
107-
Returns
108-
-------
109-
None
110-
111-
"""
11275
super(EasyEnsemble, self).__init__(ratio=ratio,
113-
return_indices=return_indices,
114-
verbose=verbose,
115-
random_state=random_state)
76+
verbose=verbose)
77+
self.return_indices = return_indices
78+
self.random_state = random_state
11679
self.replacement = replacement
11780
self.n_subsets = n_subsets
11881

119-
def fit(self, X, y):
120-
"""Find the classes statistics before to perform sampling.
121-
122-
Parameters
123-
----------
124-
X : ndarray, shape (n_samples, n_features)
125-
Matrix containing the data which have to be sampled.
126-
127-
y : ndarray, shape (n_samples, )
128-
Corresponding label for each sample in X.
129-
130-
Returns
131-
-------
132-
self : object,
133-
Return self.
134-
135-
"""
136-
# Check the consistency of X and y
137-
X, y = check_X_y(X, y)
138-
139-
# Call the parent function
140-
super(EasyEnsemble, self).fit(X, y)
141-
142-
return self
143-
144-
def sample(self, X, y):
82+
def _sample(self, X, y):
14583
"""Resample the dataset.
14684
14785
Parameters
@@ -165,10 +103,6 @@ def sample(self, X, y):
165103
containing the which samples have been selected.
166104
167105
"""
168-
# Check the consistency of X and y
169-
X, y = check_X_y(X, y)
170-
171-
super(EasyEnsemble, self).sample(X, y)
172106

173107
X_resampled = []
174108
y_resampled = []

0 commit comments

Comments
 (0)