@@ -324,7 +324,12 @@ def transform(self, X):
324324 check_is_fitted (self , "cluster_centers_" )
325325
326326 Y = self .cluster_centers_
327- return pairwise_distances (X , Y = Y , metric = self .metric )
327+ kwargs = {}
328+ if self .metric == "seuclidean" :
329+ kwargs ["V" ] = np .var (np .vstack ([X , Y ]), axis = 0 , ddof = 1 )
330+ DXY = pairwise_distances (X , Y = Y , metric = self .metric , ** kwargs )
331+
332+ return DXY
328333
329334 def predict (self , X ):
330335 """Predict the closest cluster for each sample in X.
@@ -350,10 +355,20 @@ def predict(self, X):
350355
351356 # Return data points to clusters based on which cluster assignment
352357 # yields the smallest distance
353- return pairwise_distances_argmin (
354- X , Y = self .cluster_centers_ , metric = self .metric
358+ kwargs = {}
359+ if self .metric == "seuclidean" :
360+ kwargs ["V" ] = np .var (
361+ np .vstack ([X , self .cluster_centers_ ]), axis = 0 , ddof = 1
362+ )
363+ pd_argmin = pairwise_distances_argmin (
364+ X ,
365+ Y = self .cluster_centers_ ,
366+ metric = self .metric ,
367+ metric_kwargs = kwargs ,
355368 )
356369
370+ return pd_argmin
371+
357372 def _compute_inertia (self , distances ):
358373 """Compute inertia of new samples. Inertia is defined as the sum of the
359374 sample distances to closest cluster centers.
0 commit comments