Skip to content

Commit b26be15

Browse files
glemaitre authored and chkoar committed
Resolve #111 - Handle multiclass/binary targets
1 parent 0c46346 commit b26be15

40 files changed

+651
-59
lines changed

doc/api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,3 +117,4 @@ Functions
117117
:toctree: generated/
118118

119119
pipeline.make_pipeline
120+

doc/whats_new.rst

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,26 @@ Version 0.2
1212
Changelog
1313
---------
1414

15-
- Added support for bumpversion.
16-
- Added doctest in the documentation.
15+
New features
16+
~~~~~~~~~~~~
17+
1718
- Added AllKNN under sampling technique.
1819

20+
API changes summary
21+
~~~~~~~~~~~~~~~~~~~
22+
23+
- Two base classes :class:`BaseBinarySampler` and :class:`BaseMulticlassSampler` have been created to handle the target type and raise a warning when the target type is not supported.
24+
25+
Enhancement
26+
~~~~~~~~~~~
27+
28+
- Added support for bumpversion.
29+
- Validate the type of target in binary samplers. A warning is raised for the moment.
30+
31+
Documentation changes
32+
~~~~~~~~~~~~~~~~~~~~~
33+
34+
- Added doctest in the documentation.
1935

2036
.. _changes_0_1:
2137

imblearn/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
Module which provides methods to under-sample a dataset.
1515
under-sampling
1616
Module which provides methods to under-sample a dataset.
17+
pipeline
18+
Module which allows creating pipelines with scikit-learn estimators.
1719
"""
1820

1921
from .version import _check_module_dependencies, __version__

imblearn/base.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from sklearn.base import BaseEstimator
1616
from sklearn.utils import check_X_y
17+
from sklearn.utils.multiclass import type_of_target
1718
from sklearn.externals import six
1819

1920
from six import string_types
@@ -27,7 +28,7 @@ class SamplerMixin(six.with_metaclass(ABCMeta, BaseEstimator)):
2728
instead.
2829
"""
2930

30-
_estimator_type = "sampler"
31+
_estimator_type = 'sampler'
3132

3233
def __init__(self, ratio='auto'):
3334
"""Initialize this object and its instance variables.
@@ -226,3 +227,74 @@ def __setstate__(self, dict):
226227
logger = logging.getLogger(__name__)
227228
self.__dict__.update(dict)
228229
self.logger = logger
230+
231+
232+
class BaseBinarySampler(six.with_metaclass(ABCMeta, SamplerMixin)):
233+
"""Base class for all binary class samplers.
234+
235+
Warning: This class should not be used directly. Use derived classes
236+
instead.
237+
238+
"""
239+
240+
def fit(self, X, y):
241+
"""Find the class statistics before performing sampling.
242+
243+
Parameters
244+
----------
245+
X : ndarray, shape (n_samples, n_features)
246+
Matrix containing the data which have to be sampled.
247+
248+
y : ndarray, shape (n_samples, )
249+
Corresponding label for each sample in X.
250+
251+
Returns
252+
-------
253+
self : object,
254+
Return self.
255+
256+
"""
257+
258+
super(BaseBinarySampler, self).fit(X, y)
259+
260+
# Check that the target type is binary
261+
if not type_of_target(y) == 'binary':
262+
warnings.warn('The target type should be binary.')
263+
264+
return self
265+
266+
267+
class BaseMulticlassSampler(six.with_metaclass(ABCMeta, SamplerMixin)):
268+
"""Base class for all multiclass samplers.
269+
270+
Warning: This class should not be used directly. Use derived classes
271+
instead.
272+
273+
"""
274+
275+
def fit(self, X, y):
276+
"""Find the class statistics before performing sampling.
277+
278+
Parameters
279+
----------
280+
X : ndarray, shape (n_samples, n_features)
281+
Matrix containing the data which have to be sampled.
282+
283+
y : ndarray, shape (n_samples, )
284+
Corresponding label for each sample in X.
285+
286+
Returns
287+
-------
288+
self : object,
289+
Return self.
290+
291+
"""
292+
293+
super(BaseMulticlassSampler, self).fit(X, y)
294+
295+
# Check that the target type is either binary or multiclass
296+
if not (type_of_target(y) == 'binary' or
297+
type_of_target(y) == 'multiclass'):
298+
warnings.warn('The target type should be binary or multiclass.')
299+
300+
return self

imblearn/combine/smote_enn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44

55
from ..over_sampling import SMOTE
66
from ..under_sampling import EditedNearestNeighbours
7-
from ..base import SamplerMixin
7+
from ..base import BaseBinarySampler
88

99

10-
class SMOTEENN(SamplerMixin):
10+
class SMOTEENN(BaseBinarySampler):
1111
"""Class to perform over-sampling using SMOTE and cleaning using ENN.
1212
1313
Combine over- and under-sampling using SMOTE and Edited Nearest Neighbours.

imblearn/combine/smote_tomek.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55

66
from ..over_sampling import SMOTE
77
from ..under_sampling import TomekLinks
8-
from ..base import SamplerMixin
8+
from ..base import BaseBinarySampler
99

1010

11-
class SMOTETomek(SamplerMixin):
11+
class SMOTETomek(BaseBinarySampler):
1212
"""Class to perform over-sampling using SMOTE and cleaning using
1313
Tomek links.
1414

imblearn/combine/tests/test_smote_enn.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,3 +131,18 @@ def test_sample_wrong_X():
131131
sm.fit(X, Y)
132132
assert_raises(RuntimeError, sm.sample, np.random.random((100, 40)),
133133
np.array([0] * 50 + [1] * 50))
134+
135+
136+
def test_senn_multiclass_error():
137+
""" Test whether a warning is raised when the target is not of binary
138+
type. """
139+
140+
# continuous case
141+
y = np.linspace(0, 1, 5000)
142+
sm = SMOTEENN(random_state=RND_SEED)
143+
assert_warns(UserWarning, sm.fit, X, y)
144+
145+
# multiclass case
146+
y = np.array([0] * 2000 + [1] * 2000 + [2] * 1000)
147+
sm = SMOTEENN(random_state=RND_SEED)
148+
assert_warns(UserWarning, sm.fit, X, y)

imblearn/combine/tests/test_smote_tomek.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,3 +131,18 @@ def test_sample_wrong_X():
131131
sm.fit(X, Y)
132132
assert_raises(RuntimeError, sm.sample, np.random.random((100, 40)),
133133
np.array([0] * 50 + [1] * 50))
134+
135+
136+
def test_multiclass_error():
137+
""" Test whether a warning is raised when the target is not of binary
138+
type. """
139+
140+
# continuous case
141+
y = np.linspace(0, 1, 5000)
142+
sm = SMOTETomek(random_state=RND_SEED)
143+
assert_warns(UserWarning, sm.fit, X, y)
144+
145+
# multiclass case
146+
y = np.array([0] * 2000 + [1] * 2000 + [2] * 1000)
147+
sm = SMOTETomek(random_state=RND_SEED)
148+
assert_warns(UserWarning, sm.fit, X, y)

imblearn/ensemble/balance_cascade.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@
55

66
from sklearn.utils import check_random_state
77

8-
from ..base import SamplerMixin
8+
from ..base import BaseBinarySampler
99

1010

1111
ESTIMATOR_KIND = ('knn', 'decision-tree', 'random-forest', 'adaboost',
1212
'gradient-boosting', 'linear-svm')
1313

1414

15-
class BalanceCascade(SamplerMixin):
15+
class BalanceCascade(BaseBinarySampler):
1616
"""Create an ensemble of balanced sets by iteratively under-sampling the
1717
imbalanced dataset using an estimator.
1818
@@ -100,6 +100,7 @@ class BalanceCascade(SamplerMixin):
100100
April 2009.
101101
102102
"""
103+
103104
def __init__(self, ratio='auto', return_indices=False, random_state=None,
104105
n_max_subset=None, classifier='knn', bootstrap=True,
105106
**kwargs):

imblearn/ensemble/easy_ensemble.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33

44
import numpy as np
55

6-
from ..base import SamplerMixin
6+
from ..base import BaseMulticlassSampler
77
from ..under_sampling import RandomUnderSampler
88

99

10-
class EasyEnsemble(SamplerMixin):
10+
class EasyEnsemble(BaseMulticlassSampler):
1111
"""Create an ensemble sets by iteratively applying random under-sampling.
1212
1313
This method iteratively select a random subset and make an ensemble of the
@@ -56,6 +56,8 @@ class EasyEnsemble(SamplerMixin):
5656
-----
5757
The method is described in [1]_.
5858
59+
This method supports multiclass target type.
60+
5961
Examples
6062
--------
6163

0 commit comments

Comments
 (0)