11"""Class to perform under-sampling using balace cascade."""
22from __future__ import print_function
33
4+ import warnings
5+
46import numpy as np
7+
8+ from sklearn .base import ClassifierMixin
9+ from sklearn .neighbors import KNeighborsClassifier
510from sklearn .utils import check_random_state
11+ from sklearn .utils .validation import has_fit_parameter
612
7- from .. base import BaseBinarySampler
13+ from six import string_types
814
9- ESTIMATOR_KIND = ('knn' , 'decision-tree' , 'random-forest' , 'adaboost' ,
10- 'gradient-boosting' , 'linear-svm' )
15+ from ..base import BaseBinarySampler
1116
1217
1318class BalanceCascade (BaseBinarySampler ):
@@ -40,18 +45,29 @@ class BalanceCascade(BaseBinarySampler):
         the training will be selected that could lead to a large number of
         subsets. We can probably deduce this number empirically.
 
-    classifier : str, optional (default='knn')
+    classifier : str, optional (default=None)
         The classifier that will be selected to confront the prediction
         with the real labels. The choices are the following: 'knn',
         'decision-tree', 'random-forest', 'adaboost', 'gradient-boosting'
         and 'linear-svm'.
 
+        NOTE: `classifier` is deprecated from 0.2 and will be removed in
+        0.4. Use `estimator` instead.
+
+    estimator : object, optional (default=KNeighborsClassifier())
+        An estimator inheriting from `sklearn.base.ClassifierMixin` and
+        implementing `predict`.
+
     bootstrap : bool, optional (default=True)
         Whether to bootstrap the data before each iteration.
 
     **kwargs : keywords
         The parameters associated with the classifier provided.
 
+        NOTE: `**kwargs` is deprecated from 0.2 and will be removed in
+        0.4. Use an `estimator` object instead to pass the parameters
+        associated with an estimator.
+
     Attributes
     ----------
     min_c_ : str or int
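
To make the API change concrete, here is a hedged usage sketch (not part of the diff); it assumes the class is importable as `imblearn.ensemble.BalanceCascade`, as in the 0.2 package layout:

```python
# Hedged usage sketch: assumes BalanceCascade is exposed as
# imblearn.ensemble.BalanceCascade.
from sklearn.tree import DecisionTreeClassifier
from imblearn.ensemble import BalanceCascade

# New style: pass a configured classifier object directly.
bc = BalanceCascade(estimator=DecisionTreeClassifier(max_depth=3),
                    random_state=42)

# Old style, kept until removal in 0.4; triggers a DeprecationWarning
# once `fit` runs `_validate_estimator`.
bc_old = BalanceCascade(classifier='decision-tree', random_state=42)
```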
@@ -100,16 +116,97 @@ class BalanceCascade(BaseBinarySampler):
100116 """
101117
102118 def __init__ (self , ratio = 'auto' , return_indices = False , random_state = None ,
103- n_max_subset = None , classifier = 'knn' , bootstrap = True ,
104- ** kwargs ):
119+ n_max_subset = None , classifier = None , estimator = None ,
120+ bootstrap = True , ** kwargs ):
105121 super (BalanceCascade , self ).__init__ (ratio = ratio ,
106122 random_state = random_state )
107123 self .return_indices = return_indices
108124 self .classifier = classifier
125+ self .estimator = estimator
109126 self .n_max_subset = n_max_subset
110127 self .bootstrap = bootstrap
111128 self .kwargs = kwargs
112129
130+ def _validate_estimator (self ):
131+ """Private function to create the classifier"""
132+
133+ if self .classifier is not None :
134+ warnings .warn ('`classifier` will be replaced in version'
135+ ' 0.4. Use a `estimator` instead.' ,
136+ DeprecationWarning )
137+ self .estimator = self .classifier
138+
139+ if (self .estimator is not None and
140+ isinstance (self .estimator , ClassifierMixin ) and
141+ hasattr (self .estimator , 'predict' )):
142+ self .estimator_ = self .estimator
143+ elif self .estimator is None :
144+ self .estimator_ = KNeighborsClassifier ()
145+ # To be removed in 0.4
146+ elif (self .estimator is not None and
147+ isinstance (self .estimator , string_types )):
148+ warnings .warn ('`estimator` will be replaced in version'
149+ ' 0.4. Use a classifier object instead of a string.' ,
150+ DeprecationWarning )
151+ # Define the classifier to use
152+ if self .estimator == 'knn' :
153+ self .estimator_ = KNeighborsClassifier (
154+ ** self .kwargs )
155+ elif self .estimator == 'decision-tree' :
156+ from sklearn .tree import DecisionTreeClassifier
157+ self .estimator_ = DecisionTreeClassifier (
158+ random_state = self .random_state ,
159+ ** self .kwargs )
160+ elif self .estimator == 'random-forest' :
161+ from sklearn .ensemble import RandomForestClassifier
162+ self .estimator_ = RandomForestClassifier (
163+ random_state = self .random_state ,
164+ ** self .kwargs )
165+ elif self .estimator == 'adaboost' :
166+ from sklearn .ensemble import AdaBoostClassifier
167+ self .estimator_ = AdaBoostClassifier (
168+ random_state = self .random_state ,
169+ ** self .kwargs )
170+ elif self .estimator == 'gradient-boosting' :
171+ from sklearn .ensemble import GradientBoostingClassifier
172+ self .estimator_ = GradientBoostingClassifier (
173+ random_state = self .random_state ,
174+ ** self .kwargs )
175+ elif self .estimator == 'linear-svm' :
176+ from sklearn .svm import LinearSVC
177+ self .estimator_ = LinearSVC (random_state = self .random_state ,
178+ ** self .kwargs )
179+ else :
180+ raise NotImplementedError
181+ else :
182+ raise ValueError ('Invalid parameter `estimator`' )
183+
184+ self .logger .debug (self .estimator_ )
185+
186+ def fit (self , X , y ):
187+ """Find the classes statistics before to perform sampling.
188+
189+ Parameters
190+ ----------
191+ X : ndarray, shape (n_samples, n_features)
192+ Matrix containing the data which have to be sampled.
193+
194+ y : ndarray, shape (n_samples, )
195+ Corresponding label for each sample in X.
196+
197+ Returns
198+ -------
199+ self : object,
200+ Return self.
201+
202+ """
203+
204+ super (BalanceCascade , self ).fit (X , y )
205+
206+ self ._validate_estimator ()
207+
208+ return self
209+
113210 def _sample (self , X , y ):
114211 """Resample the dataset.
115212
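
The acceptance rule in `_validate_estimator` boils down to an `isinstance`/`hasattr` check; a standalone sketch of what passes it:

```python
# Standalone sketch mirroring the acceptance condition in
# _validate_estimator above; not library code.
from sklearn.base import ClassifierMixin
from sklearn.svm import LinearSVC

est = LinearSVC()
accepted = isinstance(est, ClassifierMixin) and hasattr(est, 'predict')
print(accepted)  # True: LinearSVC subclasses ClassifierMixin

# A string such as 'linear-svm' instead falls through to the deprecated
# branch and is mapped to the matching classifier with a warning.
```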
@@ -135,42 +232,9 @@ def _sample(self, X, y):
 
         """
 
-        if self.classifier not in ESTIMATOR_KIND:
-            raise NotImplementedError
-
         random_state = check_random_state(self.random_state)
-
-        # Define the classifier to use
-        if self.classifier == 'knn':
-            from sklearn.neighbors import KNeighborsClassifier
-            classifier = KNeighborsClassifier(
-                **self.kwargs)
-        elif self.classifier == 'decision-tree':
-            from sklearn.tree import DecisionTreeClassifier
-            classifier = DecisionTreeClassifier(
-                random_state=random_state,
-                **self.kwargs)
-        elif self.classifier == 'random-forest':
-            from sklearn.ensemble import RandomForestClassifier
-            classifier = RandomForestClassifier(
-                random_state=random_state,
-                **self.kwargs)
-        elif self.classifier == 'adaboost':
-            from sklearn.ensemble import AdaBoostClassifier
-            classifier = AdaBoostClassifier(
-                random_state=random_state,
-                **self.kwargs)
-        elif self.classifier == 'gradient-boosting':
-            from sklearn.ensemble import GradientBoostingClassifier
-            classifier = GradientBoostingClassifier(
-                random_state=random_state,
-                **self.kwargs)
-        elif self.classifier == 'linear-svm':
-            from sklearn.svm import LinearSVC
-            classifier = LinearSVC(random_state=random_state,
-                                   **self.kwargs)
-        else:
-            raise NotImplementedError
+        support_sample_weight = has_fit_parameter(self.estimator_,
+                                                  "sample_weight")
 
         X_resampled = []
         y_resampled = []
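
`has_fit_parameter` inspects the signature of `estimator.fit`; this is what later decides between the weight-based and index-based fitting paths. A quick demonstration:

```python
# Quick demonstration of sklearn.utils.validation.has_fit_parameter,
# which drives the branch between weighted and indexed fitting.
from sklearn.utils.validation import has_fit_parameter
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier

print(has_fit_parameter(KNeighborsClassifier(), 'sample_weight'))   # False
print(has_fit_parameter(DecisionTreeClassifier(), 'sample_weight'))  # True
```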
@@ -185,6 +249,7 @@ def _sample(self, X, y):
         # return them later
         if self.return_indices:
             idx_min = np.flatnonzero(y == self.min_c_)
+            idx_maj = np.flatnonzero(y == self.maj_c_)
 
         # Condition to initialise before the search
         b_subset_search = True
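
The new `idx_maj` array is what makes the returned indices absolute: `idx_sel_from_maj` enumerates positions inside the majority-class subset, so it must be routed through `idx_maj` before concatenation. A toy illustration:

```python
# Toy illustration of the index fix: relative majority-subset positions
# are mapped back to row indices of the original array via idx_maj.
import numpy as np

y = np.array([0, 1, 1, 0, 1, 1, 1])  # 0 = minority, 1 = majority
idx_maj = np.flatnonzero(y == 1)      # [1, 2, 4, 5, 6]
idx_sel_from_maj = np.array([0, 3])   # positions within the majority set

print(idx_maj[idx_sel_from_maj])      # [1, 5]: absolute row indices
```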
@@ -227,27 +292,42 @@ def _sample(self, X, y):
             X_resampled.append(x_data)
             y_resampled.append(y_data)
             if self.return_indices:
-                idx_under.append(np.concatenate((idx_min, idx_sel_from_maj),
+                idx_under.append(np.concatenate((idx_min,
+                                                 idx_maj[idx_sel_from_maj]),
                                                 axis=0))
 
-            if (not (self.classifier == 'knn' or
-                     self.classifier == 'linear-svm') and
-                    self.bootstrap):
-                # Apply a bootstrap on x_data
-                curr_sample_weight = np.ones((y_data.size,), dtype=np.float64)
+            # Get the indices of interest
+            if self.bootstrap:
                 indices = random_state.randint(0, y_data.size, y_data.size)
-                sample_counts = np.bincount(indices, minlength=y_data.size)
-                curr_sample_weight *= sample_counts
+            else:
+                indices = np.arange(y_data.size)
 
-            # Train the classifier using the current data
-            classifier.fit(x_data, y_data, curr_sample_weight)
+            # Draw samples, using sample weights, and then fit
+            if support_sample_weight:
+                self.logger.debug('Sample-weight is supported')
+                curr_sample_weight = np.ones((y_data.size,), dtype=np.float64)
 
+                if self.bootstrap:
+                    self.logger.debug('Go for a bootstrap')
+                    sample_counts = np.bincount(indices, minlength=y_data.size)
+                    curr_sample_weight *= sample_counts
+                else:
+                    self.logger.debug('No bootstrap')
+                    mask = np.zeros(y_data.size, dtype=np.bool)
+                    mask[indices] = True
+                    not_indices_mask = ~mask
+                    curr_sample_weight[not_indices_mask] = 0
+
+                self.estimator_.fit(x_data, y_data,
+                                    sample_weight=curr_sample_weight)
+
+            # Draw samples, using a mask, and then fit
             else:
-                # Train the classifier using the current data
-                classifier.fit(x_data, y_data)
+                self.logger.debug('Sample-weight is not supported')
+                self.estimator_.fit(x_data[indices], y_data[indices])
 
             # Predict using only the majority class
-            pred_label = classifier.predict(N_x[idx_sel_from_maj, :])
+            pred_label = self.estimator_.predict(N_x[idx_sel_from_maj, :])
 
             # Basically let's find which samples have to be retained for the
             # next round
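
The weighted path emulates a bootstrap without copying rows: drawing indices with replacement and feeding the repeat counts as `sample_weight` is equivalent to fitting on the resampled matrix. A minimal sketch of the trick:

```python
# Minimal sketch of the bootstrap-as-weights trick used above.
import numpy as np

rng = np.random.RandomState(0)
n = 8
indices = rng.randint(0, n, n)  # bootstrap draw, with replacement
weights = np.bincount(indices, minlength=n).astype(np.float64)

print(indices)  # sampled positions, possibly repeated
print(weights)  # repeat counts; zero means the row was left out
```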
@@ -288,9 +368,8 @@ def _sample(self, X, y):
                 X_resampled.append(x_data)
                 y_resampled.append(y_data)
                 if self.return_indices:
-                    idx_under.append(np.concatenate((idx_min,
-                                                     idx_sel_from_maj),
-                                                    axis=0))
+                    idx_under.append(np.concatenate(
+                        (idx_min, idx_maj[idx_sel_from_maj]), axis=0))
 
                 self.logger.debug('Creation of the subset #%s', n_subsets)
 
@@ -321,9 +400,8 @@ def _sample(self, X, y):
                 X_resampled.append(x_data)
                 y_resampled.append(y_data)
                 if self.return_indices:
-                    idx_under.append(np.concatenate((idx_min,
-                                                     idx_sel_from_maj),
-                                                    axis=0))
+                    idx_under.append(np.concatenate(
+                        (idx_min, idx_maj[idx_sel_from_maj]), axis=0))
                 self.logger.debug('Creation of the subset #%s', n_subsets)
 
             # We found a new subset, increase the counter
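
Finally, an end-to-end sketch of the resampler after this change, on hypothetical synthetic data; it assumes the 0.2-era `imblearn.ensemble.BalanceCascade` and the `fit_sample` API inherited from the sampler base class:

```python
# End-to-end sketch on synthetic data; import path and fit_sample
# return shapes are assumptions based on the 0.2 API.
from sklearn.datasets import make_classification
from imblearn.ensemble import BalanceCascade

X, y = make_classification(n_samples=500, n_features=10,
                           weights=[0.9, 0.1], random_state=0)

bc = BalanceCascade(return_indices=True, random_state=0)
X_res, y_res, idx_res = bc.fit_sample(X, y)

# With the idx_maj[idx_sel_from_maj] fix, every index in idx_res now
# addresses a row of the original X rather than of the majority subset.
print(len(X_res), idx_res[0].max() < X.shape[0])
```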