@@ -1449,6 +1449,13 @@ cdef group_cummin_max(iu_64_floating_t[:, ::1] out,
14491449 """
14501450 cdef:
14511451 iu_64_floating_t[:, ::1 ] accum
1452+ Py_ssize_t i, j, N, K
1453+ iu_64_floating_t val, mval, na_val
1454+ uint8_t[:, ::1 ] seen_na
1455+ intp_t lab
1456+ bint na_possible
1457+ bint uses_mask = mask is not None
1458+ bint isna_entry
14521459
14531460 accum = np.empty((ngroups, (< object > values).shape[1 ]), dtype = values.dtype)
14541461 if iu_64_floating_t is int64_t:
@@ -1458,40 +1465,18 @@ cdef group_cummin_max(iu_64_floating_t[:, ::1] out,
14581465 else :
14591466 accum[:] = - np.inf if compute_max else np.inf
14601467
1461- if mask is not None :
1462- masked_cummin_max(out, values, mask, labels, accum, skipna, compute_max)
1463- else :
1464- cummin_max(out, values, labels, accum, skipna, is_datetimelike, compute_max)
1465-
1466-
1467- @ cython.boundscheck (False )
1468- @ cython.wraparound (False )
1469- cdef cummin_max(iu_64_floating_t[:, ::1 ] out,
1470- ndarray[iu_64_floating_t, ndim= 2 ] values,
1471- const intp_t[::1 ] labels,
1472- iu_64_floating_t[:, ::1 ] accum,
1473- bint skipna,
1474- bint is_datetimelike,
1475- bint compute_max):
1476- """
1477- Compute the cumulative minimum/maximum of columns of `values`, in row groups
1478- `labels`.
1479- """
1480- cdef:
1481- Py_ssize_t i, j, N, K
1482- iu_64_floating_t val, mval, na_val
1483- uint8_t[:, ::1 ] seen_na
1484- intp_t lab
1485- bint na_possible
1486-
1487- if iu_64_floating_t is float64_t or iu_64_floating_t is float32_t:
1468+ if uses_mask:
1469+ na_possible = True
1470+ # Will never be used, just to avoid uninitialized warning
1471+ na_val = 0
1472+ elif iu_64_floating_t is float64_t or iu_64_floating_t is float32_t:
14881473 na_val = NaN
14891474 na_possible = True
14901475 elif is_datetimelike:
14911476 na_val = NPY_NAT
14921477 na_possible = True
1493- # Will never be used, just to avoid uninitialized warning
14941478 else :
1479+ # Will never be used, just to avoid uninitialized warning
14951480 na_val = 0
14961481 na_possible = False
14971482
@@ -1505,56 +1490,21 @@ cdef cummin_max(iu_64_floating_t[:, ::1] out,
15051490 if lab < 0 :
15061491 continue
15071492 for j in range (K):
1493+
15081494 if not skipna and na_possible and seen_na[lab, j]:
1509- out[i, j] = na_val
1495+ if uses_mask:
1496+ mask[i, j] = 1 # FIXME: shouldn't alter inplace
1497+ else :
1498+ out[i, j] = na_val
15101499 else :
15111500 val = values[i, j]
1512- if not _treat_as_na(val, is_datetimelike):
1513- mval = accum[lab, j]
1514- if compute_max:
1515- if val > mval:
1516- accum[lab, j] = mval = val
1517- else :
1518- if val < mval:
1519- accum[lab, j] = mval = val
1520- out[i, j] = mval
1521- else :
1522- seen_na[lab, j] = 1
1523- out[i, j] = val
1524-
15251501
1526- @ cython.boundscheck (False )
1527- @ cython.wraparound (False )
1528- cdef masked_cummin_max(iu_64_floating_t[:, ::1 ] out,
1529- ndarray[iu_64_floating_t, ndim= 2 ] values,
1530- uint8_t[:, ::1 ] mask,
1531- const intp_t[::1 ] labels,
1532- iu_64_floating_t[:, ::1 ] accum,
1533- bint skipna,
1534- bint compute_max):
1535- """
1536- Compute the cumulative minimum/maximum of columns of `values`, in row groups
1537- `labels` with a masked algorithm.
1538- """
1539- cdef:
1540- Py_ssize_t i, j, N, K
1541- iu_64_floating_t val, mval
1542- uint8_t[:, ::1 ] seen_na
1543- intp_t lab
1502+ if uses_mask:
1503+ isna_entry = mask[i, j]
1504+ else :
1505+ isna_entry = _treat_as_na(val, is_datetimelike)
15441506
1545- N, K = (< object > values).shape
1546- seen_na = np.zeros((< object > accum).shape, dtype = np.uint8)
1547- with nogil:
1548- for i in range (N):
1549- lab = labels[i]
1550- if lab < 0 :
1551- continue
1552- for j in range (K):
1553- if not skipna and seen_na[lab, j]:
1554- mask[i, j] = 1
1555- else :
1556- if not mask[i, j]:
1557- val = values[i, j]
1507+ if not isna_entry:
15581508 mval = accum[lab, j]
15591509 if compute_max:
15601510 if val > mval:
@@ -1565,6 +1515,7 @@ cdef masked_cummin_max(iu_64_floating_t[:, ::1] out,
15651515 out[i, j] = mval
15661516 else :
15671517 seen_na[lab, j] = 1
1518+ out[i, j] = val
15681519
15691520
15701521@ cython.boundscheck (False )
0 commit comments