Skip to content

Commit f09268e

Browse files
authored
Add RandomState support for df sampling (#526)
* Add RandomState support for df sampling * Add typing and styling fixes * Add support for groupby and series * Remove arbitrary space
1 parent 93f3b24 commit f09268e

File tree

5 files changed

+19
-3
lines changed

5 files changed

+19
-3
lines changed

pandas-stubs/_typing.pyi

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,4 +373,13 @@ ValidationOptions: TypeAlias = Literal[
373373
"many_to_many",
374374
"m:m",
375375
]
376+
377+
RandomState: TypeAlias = Union[
378+
int,
379+
ArrayLike,
380+
np.random.Generator,
381+
np.random.BitGenerator,
382+
np.random.RandomState,
383+
]
384+
376385
__all__ = ["npt", "type_t"]

pandas-stubs/core/frame.pyi

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ from pandas._typing import (
9494
NaPosition,
9595
ParquetEngine,
9696
QuantileInterpolation,
97+
RandomState,
9798
ReadBuffer,
9899
Renamer,
99100
ReplaceMethod,
@@ -1916,7 +1917,7 @@ class DataFrame(NDFrame, OpsMixin):
19161917
frac: float | None = ...,
19171918
replace: _bool = ...,
19181919
weights: _str | ListLike | None = ...,
1919-
random_state: int | None = ...,
1920+
random_state: RandomState | None = ...,
19201921
axis: SeriesAxisType | None = ...,
19211922
ignore_index: _bool = ...,
19221923
) -> DataFrame: ...

pandas-stubs/core/groupby/generic.pyi

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ from pandas._typing import (
3333
AxisType,
3434
Level,
3535
ListLike,
36+
RandomState,
3637
Scalar,
3738
)
3839

@@ -302,7 +303,7 @@ class DataFrameGroupBy(GroupBy):
302303
frac: float | None = ...,
303304
replace: bool = ...,
304305
weights: ListLike | None = ...,
305-
random_state: int | None = ...,
306+
random_state: RandomState | None = ...,
306307
) -> DataFrame: ...
307308
def sem(self, ddof: int = ..., numeric_only: bool = ...) -> DataFrame: ...
308309
def shift(

pandas-stubs/core/series.pyi

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ from pandas._typing import (
100100
MaskType,
101101
NaPosition,
102102
QuantileInterpolation,
103+
RandomState,
103104
Renamer,
104105
ReplaceMethod,
105106
Scalar,
@@ -1021,7 +1022,7 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]):
10211022
frac: float | None = ...,
10221023
replace: _bool = ...,
10231024
weights: _str | _ListLike | np.ndarray | None = ...,
1024-
random_state: int | None = ...,
1025+
random_state: RandomState | None = ...,
10251026
axis: SeriesAxisType | None = ...,
10261027
ignore_index: _bool = ...,
10271028
) -> Series[S1]: ...

tests/test_frame.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,10 @@ def test_types_sample() -> None:
231231
# GH 67
232232
check(assert_type(df.sample(frac=0.5), pd.DataFrame), pd.DataFrame)
233233
check(assert_type(df.sample(n=1), pd.DataFrame), pd.DataFrame)
234+
check(
235+
assert_type(df.sample(n=1, random_state=np.random.default_rng()), pd.DataFrame),
236+
pd.DataFrame,
237+
)
234238

235239

236240
def test_types_nlargest_nsmallest() -> None:

0 commit comments

Comments
 (0)