Skip to content

Commit 5be10c1

Browse files
TimotheeMathieurth
andauthored
Fix Warning for Kmedoids with seuclidean metric (#114)
Co-authored-by: Roman Yurchak <rth.yurchak@gmail.com>
1 parent b7c7115 commit 5be10c1

File tree

2 files changed

+27
-3
lines changed

2 files changed

+27
-3
lines changed

sklearn_extra/cluster/_k_medoids.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,12 @@ def transform(self, X):
324324
check_is_fitted(self, "cluster_centers_")
325325

326326
Y = self.cluster_centers_
327-
return pairwise_distances(X, Y=Y, metric=self.metric)
327+
kwargs = {}
328+
if self.metric == "seuclidean":
329+
kwargs["V"] = np.var(np.vstack([X, Y]), axis=0, ddof=1)
330+
DXY = pairwise_distances(X, Y=Y, metric=self.metric, **kwargs)
331+
332+
return DXY
328333

329334
def predict(self, X):
330335
"""Predict the closest cluster for each sample in X.
@@ -350,10 +355,20 @@ def predict(self, X):
350355

351356
# Return data points to clusters based on which cluster assignment
352357
# yields the smallest distance
353-
return pairwise_distances_argmin(
354-
X, Y=self.cluster_centers_, metric=self.metric
358+
kwargs = {}
359+
if self.metric == "seuclidean":
360+
kwargs["V"] = np.var(
361+
np.vstack([X, self.cluster_centers_]), axis=0, ddof=1
362+
)
363+
pd_argmin = pairwise_distances_argmin(
364+
X,
365+
Y=self.cluster_centers_,
366+
metric=self.metric,
367+
metric_kwargs=kwargs,
355368
)
356369

370+
return pd_argmin
371+
357372
def _compute_inertia(self, distances):
358373
"""Compute inertia of new samples. Inertia is defined as the sum of the
359374
sample distances to closest cluster centers.

sklearn_extra/cluster/tests/test_k_medoids.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,3 +347,12 @@ def test_build():
347347
ske.fit(diss)
348348
assert ske.inertia_ <= 230
349349
assert len(np.unique(ske.labels_)) == 20
350+
351+
352+
def test_seuclidean():
353+
with pytest.warns(None) as record:
354+
km = KMedoids(2, metric="seuclidean", method="pam")
355+
km.fit(np.array([0, 0, 0, 1]).reshape((4, 1)))
356+
km.predict(np.array([0, 0, 0, 1]).reshape((4, 1)))
357+
km.transform(np.array([0, 0, 0, 1]).reshape((4, 1)))
358+
assert len(record) == 0

0 commit comments

Comments
 (0)