Skip to content

Commit 38e260f

Browse files
FIX _estimator_type in robust estimators (#108)
1 parent 55ad1f1 commit 38e260f

File tree

2 files changed

+11
-14
lines changed

2 files changed

+11
-14
lines changed

doc/changelog.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
Changelog
22
=========
33

4+
Unreleased
5+
----------
6+
7+
- Fix `_estimator_type` for :class:`~sklearn_extra.robust` estimators. Fix
8+
misbehavior of scikit-learn's :class:`~sklearn.model_selection.cross_val_score` and
9+
:class:`~sklearn.grid_search.GridSearchCV` for :class:`~sklearn_extra.robust.RobustWeightedClassifier`
410

511
Version 0.2.0
612
-------------

sklearn_extra/robust/robust_weighted_estimator.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -790,6 +790,8 @@ def fit(self, X, y):
790790
sgd_args = self.sgd_args
791791

792792
# Define the base estimator
793+
X, y = self._validate_data(X, y, y_numeric=False)
794+
793795
base_robust_estimator_ = _RobustWeightedEstimator(
794796
SGDClassifier(**sgd_args, eta0=self.eta0),
795797
weighting=self.weighting,
@@ -874,10 +876,6 @@ def predict_proba(self):
874876
def _predict_proba(self, X):
875877
return self.base_estimator_.predict_proba(X)
876878

877-
@property
878-
def _estimator_type(self):
879-
return self.base_estimator._estimator_type
880-
881879
def score(self, X, y=None):
882880
"""Returns the score on the given data, using
883881
``base_estimator_.score``.
@@ -1103,6 +1101,8 @@ def fit(self, X, y):
11031101

11041102
# Define the base estimator
11051103

1104+
X, y = self._validate_data(X, y, y_numeric=True)
1105+
11061106
self.base_estimator_ = _RobustWeightedEstimator(
11071107
SGDRegressor(**sgd_args, eta0=self.eta0),
11081108
weighting=self.weighting,
@@ -1142,10 +1142,6 @@ def predict(self, X):
11421142
check_is_fitted(self, attributes=["base_estimator_"])
11431143
return self.base_estimator_.predict(X)
11441144

1145-
@property
1146-
def _estimator_type(self):
1147-
return self.base_estimator._estimator_type
1148-
11491145
def score(self, X, y=None):
11501146
"""Returns the score on the given data, using
11511147
``base_estimator_.score``.
@@ -1350,14 +1346,13 @@ def fit(self, X, y=None):
13501346
kmeans_args = {}
13511347
else:
13521348
kmeans_args = self.kmeans_args
1353-
X = check_array(
1349+
X = self._validate_data(
13541350
X,
13551351
accept_sparse="csr",
13561352
dtype=[np.float64, np.float32],
13571353
order="C",
13581354
accept_large_sparse=False,
13591355
)
1360-
13611356
self.base_estimator_ = _RobustWeightedEstimator(
13621357
MiniBatchKMeans(
13631358
self.n_clusters,
@@ -1404,10 +1399,6 @@ def predict(self, X):
14041399
check_is_fitted(self, attributes=["base_estimator_"])
14051400
return self.base_estimator_.predict(X)
14061401

1407-
@property
1408-
def _estimator_type(self):
1409-
return self.base_estimator._estimator_type
1410-
14111402
def score(self, X, y=None):
14121403
"""Returns the score on the given data, using
14131404
``base_estimator_.score``.

0 commit comments

Comments
 (0)