From cd5f1eb156f28dbef96b3bceb585eabdc4ff9494 Mon Sep 17 00:00:00 2001 From: Rob Bowden Date: Mon, 12 Aug 2019 04:29:53 -0400 Subject: [PATCH] BUG: Fix identification of deserialized np.nan --- missingpy/knnimpute.py | 8 ++++---- missingpy/missforest.py | 9 ++++----- missingpy/pairwise_external.py | 4 ++-- missingpy/utils.py | 4 +++- 4 files changed, 13 insertions(+), 12 deletions(-) diff --git a/missingpy/knnimpute.py b/missingpy/knnimpute.py index a7f3d1c..ea2ae17 100644 --- a/missingpy/knnimpute.py +++ b/missingpy/knnimpute.py @@ -17,6 +17,8 @@ from .pairwise_external import _get_mask from .pairwise_external import _MASKED_METRICS +from .utils import is_nan + __all__ = [ 'KNNImputer', ] @@ -193,8 +195,7 @@ def fit(self, X, y=None): """ # Check data integrity and calling arguments - force_all_finite = False if self.missing_values in ["NaN", - np.nan] else True + force_all_finite = not is_nan(self.missing_values) if not force_all_finite: if self.metric not in _MASKED_METRICS and not callable( self.metric): @@ -250,8 +251,7 @@ def transform(self, X): """ check_is_fitted(self, ["fitted_X_", "statistics_"]) - force_all_finite = False if self.missing_values in ["NaN", - np.nan] else True + force_all_finite = not is_nan(self.missing_values) X = check_array(X, accept_sparse=False, dtype=FLOAT_DTYPES, force_all_finite=force_all_finite, copy=self.copy) diff --git a/missingpy/missforest.py b/missingpy/missforest.py index d0d2843..f1abe7d 100644 --- a/missingpy/missforest.py +++ b/missingpy/missforest.py @@ -13,6 +13,8 @@ from .pairwise_external import _get_mask +from .utils import is_nan + __all__ = [ 'MissForest', ] @@ -434,9 +436,7 @@ def fit(self, X, y=None, cat_vars=None): """ # Check data integrity and calling arguments - force_all_finite = False if self.missing_values in ["NaN", - np.nan] else True - + force_all_finite = not is_nan(self.missing_values) X = check_array(X, accept_sparse=False, dtype=np.float64, force_all_finite=force_all_finite, copy=self.copy) @@ -499,8 +499,7 @@ def transform(self, X): check_is_fitted(self, ["cat_vars_", "num_vars_", "statistics_"]) # Check data integrity - force_all_finite = False if self.missing_values in ["NaN", - np.nan] else True + force_all_finite = not is_nan(self.missing_values) X = check_array(X, accept_sparse=False, dtype=np.float64, force_all_finite=force_all_finite, copy=self.copy) diff --git a/missingpy/pairwise_external.py b/missingpy/pairwise_external.py index f81c207..1b55a7f 100644 --- a/missingpy/pairwise_external.py +++ b/missingpy/pairwise_external.py @@ -48,7 +48,7 @@ from sklearn.metrics.pairwise import _parallel_pairwise from sklearn.utils import check_array -from .utils import masked_euclidean_distances +from .utils import is_nan, masked_euclidean_distances _MASKED_METRICS = ['masked_euclidean'] _VALID_METRICS += ['masked_euclidean'] @@ -56,7 +56,7 @@ def _get_mask(X, value_to_mask): """Compute the boolean mask X == missing_values.""" - if value_to_mask == "NaN" or np.isnan(value_to_mask): + if is_nan(value_to_mask): return np.isnan(X) else: return X == value_to_mask diff --git a/missingpy/utils.py b/missingpy/utils.py index f045710..1f8333d 100644 --- a/missingpy/utils.py +++ b/missingpy/utils.py @@ -4,6 +4,8 @@ import numpy as np +def is_nan(n): + return n == "NaN" or isinstance(n, float) and np.isnan(n) def masked_euclidean_distances(X, Y=None, squared=False, missing_values="NaN", copy=True): @@ -92,7 +94,7 @@ def masked_euclidean_distances(X, Y=None, squared=False, raise ValueError("One or more rows only contain missing values.") # else: - if missing_values not in ["NaN", np.nan] and ( + if not is_nan(missing_values) and ( np.any(np.isnan(X)) or (Y is not X and np.any(np.isnan(Y)))): raise ValueError( "NaN values present but missing_value = {0}".format(