@@ -70,17 +70,33 @@ def _get_onedal_params(self, X, y=None, n_neighbors=None):
7070
7171 fptype = np .float64
7272
73- # _fit_method should be set by sklearnex level before calling oneDAL
73+ # Handle _fit_method: use if set by sklearnex, otherwise determine it ourselves
7474 if not hasattr (self , "_fit_method" ) or self ._fit_method is None :
75- raise ValueError (
76- "_fit_method must be set by sklearnex level before calling oneDAL. "
77- "This indicates improper usage - oneDAL neighbors should not be called directly."
78- )
75+ # Direct oneDAL usage - determine method ourselves
76+ method = getattr (self , "algorithm" , "auto" )
77+ n_samples , n_features = X .shape
78+
79+ if method in ["auto" , "ball_tree" ]:
80+ condition = (
81+ self .n_neighbors is not None and self .n_neighbors >= n_samples // 2
82+ )
83+ if getattr (self , "metric" , "minkowski" ) == "precomputed" or n_features > 15 or condition :
84+ fit_method = "brute"
85+ else :
86+ if getattr (self , "effective_metric_" , getattr (self , "metric" , "minkowski" )) == "euclidean" :
87+ fit_method = "kd_tree"
88+ else :
89+ fit_method = "brute"
90+ else :
91+ fit_method = method
92+ else :
93+ # Use the method set by sklearnex level
94+ fit_method = self ._fit_method
7995
8096 return {
8197 "fptype" : fptype ,
8298 "vote_weights" : "uniform" if weights == "uniform" else "distance" ,
83- "method" : self . _fit_method ,
99+ "method" : fit_method ,
84100 "radius" : self .radius ,
85101 "class_count" : class_count ,
86102 "neighbor_count" : self .n_neighbors if n_neighbors is None else n_neighbors ,
0 commit comments