Skip to content

Commit 84e6af4

Browse files
authored
Merge pull request #115 from dvro/datasets
[MRG] adding make_imbalance function
2 parents b26be15 + a7c6158 commit 84e6af4

File tree

5 files changed

+246
-0
lines changed

5 files changed

+246
-0
lines changed
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
"""
2+
===========================
3+
make_imbalance function
4+
===========================
5+
6+
An illustration of the make_imbalance function
7+
8+
"""
9+
10+
print(__doc__)
11+
12+
import numpy as np
13+
14+
import matplotlib.pyplot as plt
15+
import seaborn as sns
16+
sns.set()
17+
18+
# Define some color for the plotting
19+
almost_black = '#262626'
20+
palette = sns.color_palette()
21+
22+
from sklearn.datasets import make_moons
23+
from imblearn.datasets import make_imbalance
24+
25+
26+
# Generate the dataset
X, y = make_moons(n_samples=200, shuffle=True, noise=0.5, random_state=10)

# Build a 2 x 3 grid of subplots and flatten it into a plain list of axes
f, axs = plt.subplots(2, 3)
axs = list(axs.ravel())

# (class value, legend label, face color) for each class to draw
class_styles = ((0, "Class #0", palette[0]),
                (1, "Class #1", palette[2]))

# First panel: the original, balanced data
for class_value, class_label, class_color in class_styles:
    axs[0].scatter(X[y == class_value, 0], X[y == class_value, 1],
                   label=class_label, alpha=0.5, edgecolor=almost_black,
                   facecolor=class_color, linewidth=0.15)
axs[0].set_title('Original set')

# Remaining panels: one imbalanced version of the data per ratio
ratios = [0.9, 0.75, 0.5, 0.25, 0.1]
for ratio, ax in zip(ratios, axs[1:]):
    X_, y_ = make_imbalance(X, y, ratio=ratio, min_c_=1)

    for class_value, class_label, class_color in class_styles:
        ax.scatter(X_[y_ == class_value, 0], X_[y_ == class_value, 1],
                   label=class_label, alpha=0.5, edgecolor=almost_black,
                   facecolor=class_color, linewidth=0.15)
    ax.set_title('make_imbalance ratio ({})'.format(ratio))

plt.show()

imblearn/datasets/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
"""
2+
The :mod:`imblearn.datasets` provides methods to generate
3+
imbalanced data.
4+
"""
5+
6+
from .imbalance import make_imbalance
7+
8+
__all__ = ['make_imbalance']

imblearn/datasets/imbalance.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
"""Transform a dataset into an imbalanced dataset."""
2+
3+
import numpy as np
4+
5+
from collections import Counter
6+
7+
from sklearn.utils import check_X_y
8+
from sklearn.utils import check_random_state
9+
10+
def make_imbalance(X, y, ratio, min_c_=None, random_state=None):
    """Turn a dataset into an imbalanced dataset at a specific ratio.

    Samples of the designated minority class are randomly dropped (drawn
    without replacement) until the requested ratio is reached. Samples of
    every other class are kept untouched.

    Parameters
    ----------
    X : ndarray, shape (n_samples, n_features)
        Matrix containing the data to be imbalanced.

    y : ndarray, shape (n_samples, )
        Corresponding label for each sample in X.

    ratio : float
        The desired ratio given by the number of samples in the minority
        class over the number of samples in all the other classes taken
        together (i.e. the majority class for a binary problem). Must be
        such that 0.0 < ratio < 1.0.

    min_c_ : str or int, optional (default=None)
        The identifier of the class to be the minority class.
        If None, min_c_ is set to be the current minority class.

    random_state : int, RandomState instance or None, optional (default=None)
        If int, random_state is the seed used by the random number generator;
        If RandomState instance, random_state is the random number generator;
        If None, the random number generator is the RandomState instance used
        by np.random.

    Returns
    -------
    X_resampled : ndarray, shape (n_samples_new, n_features)
        The array containing the imbalanced data. The selected minority
        samples come first, followed by all the other samples.

    y_resampled : ndarray, shape (n_samples_new)
        The corresponding label of `X_resampled`

    Raises
    ------
    ValueError
        If `ratio` is not a number strictly between 0 and 1, if `min_c_`
        is not a label of `y`, if the data is already more imbalanced than
        the desired ratio, or if the desired ratio would leave no minority
        sample.
    """
    # Non-numeric ratios (list, str, None, ...) cannot be ordered against a
    # float on Python 3; report them as a bad value like any other bad ratio.
    try:
        ratio_in_range = 0.0 < ratio < 1.0
    except TypeError:
        ratio_in_range = False
    if not ratio_in_range:
        raise ValueError('ratio value must be such that 0.0 < ratio < 1.0')

    X, y = check_X_y(X, y)

    random_state = check_random_state(random_state)

    stats_c_ = Counter(y)

    if min_c_ is None:
        min_c_ = min(stats_c_, key=stats_c_.get)
    elif min_c_ not in stats_c_:
        # Without this check an unknown label would be reported as a
        # misleading "imbalance ratio" problem further down.
        raise ValueError(
            'min_c_={!r} is not a label present in y'.format(min_c_))

    # Number of minority samples to keep, relative to every other sample.
    n_min_samples = int(np.count_nonzero(y != min_c_) * ratio)
    if n_min_samples > stats_c_[min_c_]:
        raise ValueError(
            'Current imbalance ratio of data is lower than desired ratio!')
    if n_min_samples == 0:
        raise ValueError('Not enough samples for desired ratio!')

    mask = y == min_c_

    idx_maj = np.where(~mask)[0]
    idx_min = np.where(mask)[0]
    idx_min = random_state.choice(idx_min, size=n_min_samples, replace=False)
    idx = np.concatenate((idx_min, idx_maj), axis=0)

    X_resampled, y_resampled = X[idx, :], y[idx]

    return X_resampled, y_resampled
74+

imblearn/datasets/tests/__init__.py

Whitespace-only changes.
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
"""Test the make_imbalance function."""
2+
from __future__ import print_function
3+
4+
import numpy as np
5+
from numpy.testing import assert_raises
6+
from numpy.testing import assert_equal
7+
8+
from collections import Counter
9+
10+
from imblearn.datasets import make_imbalance
11+
12+
# Generate a global dataset to use
13+
X = np.random.random((1000, 2))
14+
Y = np.zeros(1000)
15+
Y[500:] = 1
16+
17+
def test_make_imbalance_bad_ratio():
    """Check that an error is raised for every unsupported ratio
    argument."""
    min_c_ = 1

    # In order: a zero ratio, a negative ratio, a ratio greater than 1 and
    # a ratio given as a list, which is not supported.
    bad_ratios = (0.0, -2.0, 2.0, [.5, .5])
    for ratio in bad_ratios:
        assert_raises(ValueError, make_imbalance, X, Y, ratio, min_c_)
37+
38+
39+
def test_make_imbalance_invalid_ratio():
    """Check that an error is raised when the desired ratio is higher
    than the current ratio of the data."""
    # A single sample of class 1 among samples of class 0
    labels = np.zeros((X.shape[0], ))
    labels[0] = 1

    assert_raises(ValueError, make_imbalance, X, labels, 0.5)
48+
49+
def test_make_imbalance_single_class():
    """Check that an error is raised when the labels hold a single class."""
    single_class_labels = np.zeros((X.shape[0], ))
    assert_raises(ValueError, make_imbalance, X, single_class_labels, 0.5)
54+
55+
def test_make_imbalance_1():
    """Test make_imbalance with ratio=0.5 and class 1 as the minority."""
    X_, y_ = make_imbalance(X, Y, ratio=0.5, min_c_=1)
    counter = Counter(y_)
    assert_equal(counter[0], 500)
    assert_equal(counter[1], 250)
    # ``X_i in X`` is true as soon as one coordinate matches anywhere in X,
    # so compare whole rows to check that every sample really comes from X.
    assert np.all([np.any(np.all(X == X_i, axis=1)) for X_i in X_])
62+
63+
def test_make_imbalance_2():
    """Test make_imbalance with ratio=0.25 and class 1 as the minority."""
    X_, y_ = make_imbalance(X, Y, ratio=0.25, min_c_=1)
    counter = Counter(y_)
    assert_equal(counter[0], 500)
    assert_equal(counter[1], 125)
    # ``X_i in X`` is true as soon as one coordinate matches anywhere in X,
    # so compare whole rows to check that every sample really comes from X.
    assert np.all([np.any(np.all(X == X_i, axis=1)) for X_i in X_])
70+
71+
def test_make_imbalance_3():
    """Test make_imbalance with ratio=0.1 and class 1 as the minority."""
    X_, y_ = make_imbalance(X, Y, ratio=0.1, min_c_=1)
    counter = Counter(y_)
    assert_equal(counter[0], 500)
    assert_equal(counter[1], 50)
    # ``X_i in X`` is true as soon as one coordinate matches anywhere in X,
    # so compare whole rows to check that every sample really comes from X.
    assert np.all([np.any(np.all(X == X_i, axis=1)) for X_i in X_])
78+
79+
def test_make_imbalance_4():
    """Test make_imbalance with ratio=0.01 and class 1 as the minority."""
    X_, y_ = make_imbalance(X, Y, ratio=0.01, min_c_=1)
    counter = Counter(y_)
    assert_equal(counter[0], 500)
    assert_equal(counter[1], 5)
    # ``X_i in X`` is true as soon as one coordinate matches anywhere in X,
    # so compare whole rows to check that every sample really comes from X.
    assert np.all([np.any(np.all(X == X_i, axis=1)) for X_i in X_])
86+
87+
def test_make_imbalance_5():
    """Test make_imbalance with ratio=0.01 and class 0 as the minority."""
    X_, y_ = make_imbalance(X, Y, ratio=0.01, min_c_=0)
    counter = Counter(y_)
    assert_equal(counter[1], 500)
    assert_equal(counter[0], 5)
    # ``X_i in X`` is true as soon as one coordinate matches anywhere in X,
    # so compare whole rows to check that every sample really comes from X.
    assert np.all([np.any(np.all(X == X_i, axis=1)) for X_i in X_])
94+
95+
def test_make_imbalance_multiclass():
    """Test make_imbalance with multiclass data"""
    # Make y to be multiclass: 100 samples of class 0, 400 of class 1 and
    # 500 of class 2
    y_ = np.zeros(1000)
    y_[100:500] = 1
    y_[500:] = 2

    # Resample the data; the expected 90 samples of class 0 are 10% of the
    # 900 samples belonging to the two other classes
    X_, y_ = make_imbalance(X, y_, ratio=0.1, min_c_=0)
    counter = Counter(y_)
    assert_equal(counter[0], 90)
    assert_equal(counter[1], 400)
    assert_equal(counter[2], 500)
    # ``X_i in X`` is true as soon as one coordinate matches anywhere in X,
    # so compare whole rows to check that every sample really comes from X.
    assert np.all([np.any(np.all(X == X_i, axis=1)) for X_i in X_])

0 commit comments

Comments
 (0)