Skip to content

Commit e9e7306

Browse files
committed
fix: add fit emthod logic in onedla
1 parent debfcdf commit e9e7306

File tree

1 file changed

+22
-6
lines changed

1 file changed

+22
-6
lines changed

onedal/neighbors/neighbors.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,17 +70,33 @@ def _get_onedal_params(self, X, y=None, n_neighbors=None):
7070

7171
fptype = np.float64
7272

73-
# _fit_method should be set by sklearnex level before calling oneDAL
73+
# Handle _fit_method: use if set by sklearnex, otherwise determine it ourselves
7474
if not hasattr(self, "_fit_method") or self._fit_method is None:
75-
raise ValueError(
76-
"_fit_method must be set by sklearnex level before calling oneDAL. "
77-
"This indicates improper usage - oneDAL neighbors should not be called directly."
78-
)
75+
# Direct oneDAL usage - determine method ourselves
76+
method = getattr(self, "algorithm", "auto")
77+
n_samples, n_features = X.shape
78+
79+
if method in ["auto", "ball_tree"]:
80+
condition = (
81+
self.n_neighbors is not None and self.n_neighbors >= n_samples // 2
82+
)
83+
if getattr(self, "metric", "minkowski") == "precomputed" or n_features > 15 or condition:
84+
fit_method = "brute"
85+
else:
86+
if getattr(self, "effective_metric_", getattr(self, "metric", "minkowski")) == "euclidean":
87+
fit_method = "kd_tree"
88+
else:
89+
fit_method = "brute"
90+
else:
91+
fit_method = method
92+
else:
93+
# Use the method set by sklearnex level
94+
fit_method = self._fit_method
7995

8096
return {
8197
"fptype": fptype,
8298
"vote_weights": "uniform" if weights == "uniform" else "distance",
83-
"method": self._fit_method,
99+
"method": fit_method,
84100
"radius": self.radius,
85101
"class_count": class_count,
86102
"neighbor_count": self.n_neighbors if n_neighbors is None else n_neighbors,

0 commit comments

Comments
 (0)