@@ -114,7 +114,7 @@ class providing the base-class of operations.
114114from pandas .core .series import Series
115115from pandas .core .sorting import get_group_index_sorter
116116from pandas .core .util .numba_ import (
117- NUMBA_FUNC_CACHE ,
117+ get_jit_arguments ,
118118 maybe_use_numba ,
119119)
120120
@@ -1247,11 +1247,7 @@ def _resolve_numeric_only(self, numeric_only: bool | lib.NoDefault) -> bool:
12471247 # numba
12481248
12491249 @final
1250- def _numba_prep (self , func , data ):
1251- if not callable (func ):
1252- raise NotImplementedError (
1253- "Numba engine can only be used with a single function."
1254- )
1250+ def _numba_prep (self , data ):
12551251 ids , _ , ngroups = self .grouper .group_info
12561252 sorted_index = get_group_index_sorter (ids , ngroups )
12571253 sorted_ids = algorithms .take_nd (ids , sorted_index , allow_fill = False )
@@ -1271,7 +1267,6 @@ def _numba_agg_general(
12711267 self ,
12721268 func : Callable ,
12731269 engine_kwargs : dict [str , bool ] | None ,
1274- numba_cache_key_str : str ,
12751270 * aggregator_args ,
12761271 ):
12771272 """
@@ -1288,16 +1283,12 @@ def _numba_agg_general(
12881283 with self ._group_selection_context ():
12891284 data = self ._selected_obj
12901285 df = data if data .ndim == 2 else data .to_frame ()
1291- starts , ends , sorted_index , sorted_data = self ._numba_prep (func , df )
1286+ starts , ends , sorted_index , sorted_data = self ._numba_prep (df )
12921287 aggregator = executor .generate_shared_aggregator (
1293- func , engine_kwargs , numba_cache_key_str
1288+ func , ** get_jit_arguments ( engine_kwargs )
12941289 )
12951290 result = aggregator (sorted_data , starts , ends , 0 , * aggregator_args )
12961291
1297- cache_key = (func , numba_cache_key_str )
1298- if cache_key not in NUMBA_FUNC_CACHE :
1299- NUMBA_FUNC_CACHE [cache_key ] = aggregator
1300-
13011292 index = self .grouper .result_index
13021293 if data .ndim == 1 :
13031294 result_kwargs = {"name" : data .name }
@@ -1315,10 +1306,10 @@ def _transform_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs)
13151306 to generate the indices of each group in the sorted data and then passes the
13161307 data and indices into a Numba jitted function.
13171308 """
1318- starts , ends , sorted_index , sorted_data = self ._numba_prep (func , data )
1319-
1309+ starts , ends , sorted_index , sorted_data = self ._numba_prep (data )
1310+ numba_ . validate_udf ( func )
13201311 numba_transform_func = numba_ .generate_numba_transform_func (
1321- kwargs , func , engine_kwargs
1312+ func , ** get_jit_arguments ( engine_kwargs , kwargs )
13221313 )
13231314 result = numba_transform_func (
13241315 sorted_data ,
@@ -1328,11 +1319,6 @@ def _transform_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs)
13281319 len (data .columns ),
13291320 * args ,
13301321 )
1331-
1332- cache_key = (func , "groupby_transform" )
1333- if cache_key not in NUMBA_FUNC_CACHE :
1334- NUMBA_FUNC_CACHE [cache_key ] = numba_transform_func
1335-
13361322 # result values needs to be resorted to their original positions since we
13371323 # evaluated the data sorted by group
13381324 return result .take (np .argsort (sorted_index ), axis = 0 )
@@ -1346,9 +1332,11 @@ def _aggregate_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs)
13461332 to generate the indices of each group in the sorted data and then passes the
13471333 data and indices into a Numba jitted function.
13481334 """
1349- starts , ends , sorted_index , sorted_data = self ._numba_prep (func , data )
1350-
1351- numba_agg_func = numba_ .generate_numba_agg_func (kwargs , func , engine_kwargs )
1335+ starts , ends , sorted_index , sorted_data = self ._numba_prep (data )
1336+ numba_ .validate_udf (func )
1337+ numba_agg_func = numba_ .generate_numba_agg_func (
1338+ func , ** get_jit_arguments (engine_kwargs , kwargs )
1339+ )
13521340 result = numba_agg_func (
13531341 sorted_data ,
13541342 sorted_index ,
@@ -1357,11 +1345,6 @@ def _aggregate_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs)
13571345 len (data .columns ),
13581346 * args ,
13591347 )
1360-
1361- cache_key = (func , "groupby_agg" )
1362- if cache_key not in NUMBA_FUNC_CACHE :
1363- NUMBA_FUNC_CACHE [cache_key ] = numba_agg_func
1364-
13651348 return result
13661349
13671350 # -----------------------------------------------------------------
@@ -1947,7 +1930,7 @@ def mean(
19471930 if maybe_use_numba (engine ):
19481931 from pandas .core ._numba .kernels import sliding_mean
19491932
1950- return self ._numba_agg_general (sliding_mean , engine_kwargs , "groupby_mean" )
1933+ return self ._numba_agg_general (sliding_mean , engine_kwargs )
19511934 else :
19521935 result = self ._cython_agg_general (
19531936 "mean" ,
@@ -2029,9 +2012,7 @@ def std(
20292012 if maybe_use_numba (engine ):
20302013 from pandas .core ._numba .kernels import sliding_var
20312014
2032- return np .sqrt (
2033- self ._numba_agg_general (sliding_var , engine_kwargs , "groupby_std" , ddof )
2034- )
2015+ return np .sqrt (self ._numba_agg_general (sliding_var , engine_kwargs , ddof ))
20352016 else :
20362017 return self ._get_cythonized_result (
20372018 libgroupby .group_var ,
@@ -2085,9 +2066,7 @@ def var(
20852066 if maybe_use_numba (engine ):
20862067 from pandas .core ._numba .kernels import sliding_var
20872068
2088- return self ._numba_agg_general (
2089- sliding_var , engine_kwargs , "groupby_var" , ddof
2090- )
2069+ return self ._numba_agg_general (sliding_var , engine_kwargs , ddof )
20912070 else :
20922071 if ddof == 1 :
20932072 numeric_only = self ._resolve_numeric_only (lib .no_default )
@@ -2180,7 +2159,6 @@ def sum(
21802159 return self ._numba_agg_general (
21812160 sliding_sum ,
21822161 engine_kwargs ,
2183- "groupby_sum" ,
21842162 )
21852163 else :
21862164 numeric_only = self ._resolve_numeric_only (numeric_only )
@@ -2221,9 +2199,7 @@ def min(
22212199 if maybe_use_numba (engine ):
22222200 from pandas .core ._numba .kernels import sliding_min_max
22232201
2224- return self ._numba_agg_general (
2225- sliding_min_max , engine_kwargs , "groupby_min" , False
2226- )
2202+ return self ._numba_agg_general (sliding_min_max , engine_kwargs , False )
22272203 else :
22282204 return self ._agg_general (
22292205 numeric_only = numeric_only ,
@@ -2244,9 +2220,7 @@ def max(
22442220 if maybe_use_numba (engine ):
22452221 from pandas .core ._numba .kernels import sliding_min_max
22462222
2247- return self ._numba_agg_general (
2248- sliding_min_max , engine_kwargs , "groupby_max" , True
2249- )
2223+ return self ._numba_agg_general (sliding_min_max , engine_kwargs , True )
22502224 else :
22512225 return self ._agg_general (
22522226 numeric_only = numeric_only ,
0 commit comments