Skip to content

Commit 02da9e9

Browse files
committed
fix: fix test
1 parent e9e7306 commit 02da9e9

File tree

2 files changed

+12
-28
lines changed

2 files changed

+12
-28
lines changed

onedal/neighbors/neighbors.py

Lines changed: 6 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -70,33 +70,17 @@ def _get_onedal_params(self, X, y=None, n_neighbors=None):
7070

7171
fptype = np.float64
7272

73-
# Handle _fit_method: use if set by sklearnex, otherwise determine it ourselves
73+
# _fit_method should be set by sklearnex level before calling oneDAL
7474
if not hasattr(self, "_fit_method") or self._fit_method is None:
75-
# Direct oneDAL usage - determine method ourselves
76-
method = getattr(self, "algorithm", "auto")
77-
n_samples, n_features = X.shape
78-
79-
if method in ["auto", "ball_tree"]:
80-
condition = (
81-
self.n_neighbors is not None and self.n_neighbors >= n_samples // 2
82-
)
83-
if getattr(self, "metric", "minkowski") == "precomputed" or n_features > 15 or condition:
84-
fit_method = "brute"
85-
else:
86-
if getattr(self, "effective_metric_", getattr(self, "metric", "minkowski")) == "euclidean":
87-
fit_method = "kd_tree"
88-
else:
89-
fit_method = "brute"
90-
else:
91-
fit_method = method
92-
else:
93-
# Use the method set by sklearnex level
94-
fit_method = self._fit_method
75+
raise ValueError(
76+
"_fit_method must be set by sklearnex level before calling oneDAL. "
77+
"This indicates improper usage - oneDAL neighbors should not be called directly."
78+
)
9579

9680
return {
9781
"fptype": fptype,
9882
"vote_weights": "uniform" if weights == "uniform" else "distance",
99-
"method": fit_method,
83+
"method": self._fit_method,
10084
"radius": self.radius,
10185
"class_count": class_count,
10286
"neighbor_count": self.n_neighbors if n_neighbors is None else n_neighbors,

onedal/neighbors/tests/test_knn_classification.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,15 @@
1919
from numpy.testing import assert_array_equal
2020
from sklearn import datasets
2121

22-
from onedal.neighbors import KNeighborsClassifier
22+
from sklearnex.neighbors import KNeighborsClassifier
2323
from onedal.tests.utils._device_selection import get_queues
2424

2525

2626
@pytest.mark.parametrize("queue", get_queues())
2727
def test_iris(queue):
2828
iris = datasets.load_iris()
29-
clf = KNeighborsClassifier(2).fit(iris.data, iris.target, queue=queue)
30-
assert clf.score(iris.data, iris.target, queue=queue) > 0.9
29+
clf = KNeighborsClassifier(2).fit(iris.data, iris.target)
30+
assert clf.score(iris.data, iris.target) > 0.9
3131
assert_array_equal(clf.classes_, np.sort(clf.classes_))
3232

3333

@@ -36,14 +36,14 @@ def test_pickle(queue):
3636
if queue and queue.sycl_device.is_gpu:
3737
pytest.skip("KNN classifier pickling for the GPU sycl_queue is buggy.")
3838
iris = datasets.load_iris()
39-
clf = KNeighborsClassifier(2).fit(iris.data, iris.target, queue=queue)
40-
expected = clf.predict(iris.data, queue=queue)
39+
clf = KNeighborsClassifier(2).fit(iris.data, iris.target)
40+
expected = clf.predict(iris.data)
4141

4242
import pickle
4343

4444
dump = pickle.dumps(clf)
4545
clf2 = pickle.loads(dump)
4646

4747
assert type(clf2) == clf.__class__
48-
result = clf2.predict(iris.data, queue=queue)
48+
result = clf2.predict(iris.data)
4949
assert_array_equal(expected, result)

0 commit comments

Comments
 (0)