@@ -709,19 +709,13 @@ def test_equal(
709709 _left = ah .asarray (_left , dtype = promoted_dtype )
710710 _right = ah .asarray (_right , dtype = promoted_dtype )
711711
712- if dh .is_int_dtype (promoted_dtype ):
713- scalar_func = int
714- elif dh .is_float_dtype (promoted_dtype ):
715- scalar_func = float
716- else :
717- scalar_func = bool
718-
712+ scalar_type = dh .get_scalar_type (promoted_dtype )
719713 for idx in ah .ndindex (shape ):
720714 x1_idx = _left [idx ]
721715 x2_idx = _right [idx ]
722716 out_idx = out [idx ]
723717 assert out_idx .shape == x1_idx .shape == x2_idx .shape # sanity check
724- assert bool (out_idx ) == (scalar_func (x1_idx ) == scalar_func (x2_idx ))
718+ assert bool (out_idx ) == (scalar_type (x1_idx ) == scalar_type (x2_idx ))
725719
726720
727721@given (xps .arrays (dtype = xps .floating_dtypes (), shape = hh .shapes ()))
@@ -840,18 +834,13 @@ def test_greater(
840834 _left = ah .asarray (_left , dtype = promoted_dtype )
841835 _right = ah .asarray (_right , dtype = promoted_dtype )
842836
843- if dh .is_int_dtype (promoted_dtype ):
844- scalar_func = int
845- elif dh .is_float_dtype (promoted_dtype ):
846- scalar_func = float
847- else :
848- scalar_func = bool
837+ scalar_type = dh .get_scalar_type (promoted_dtype )
849838 for idx in ah .ndindex (shape ):
850839 out_idx = out [idx ]
851840 x1_idx = _left [idx ]
852841 x2_idx = _right [idx ]
853842 assert out_idx .shape == x1_idx .shape == x2_idx .shape # sanity check
854- assert bool (out_idx ) == (scalar_func (x1_idx ) > scalar_func (x2_idx ))
843+ assert bool (out_idx ) == (scalar_type (x1_idx ) > scalar_type (x2_idx ))
855844
856845
857846@pytest .mark .parametrize (
@@ -886,18 +875,13 @@ def test_greater_equal(
886875 _left = ah .asarray (_left , dtype = promoted_dtype )
887876 _right = ah .asarray (_right , dtype = promoted_dtype )
888877
889- if dh .is_int_dtype (promoted_dtype ):
890- scalar_func = int
891- elif dh .is_float_dtype (promoted_dtype ):
892- scalar_func = float
893- else :
894- scalar_func = bool
878+ scalar_type = dh .get_scalar_type (promoted_dtype )
895879 for idx in ah .ndindex (shape ):
896880 out_idx = out [idx ]
897881 x1_idx = _left [idx ]
898882 x2_idx = _right [idx ]
899883 assert out_idx .shape == x1_idx .shape == x2_idx .shape # sanity check
900- assert bool (out_idx ) == (scalar_func (x1_idx ) >= scalar_func (x2_idx ))
884+ assert bool (out_idx ) == (scalar_type (x1_idx ) >= scalar_type (x2_idx ))
901885
902886
903887@given (xps .arrays (dtype = xps .numeric_dtypes (), shape = hh .shapes ()))
@@ -983,19 +967,13 @@ def test_less(
983967 _left = ah .asarray (_left , dtype = promoted_dtype )
984968 _right = ah .asarray (_right , dtype = promoted_dtype )
985969
986- if dh .is_int_dtype (promoted_dtype ):
987- scalar_func = int
988- elif dh .is_float_dtype (promoted_dtype ):
989- scalar_func = float
990- else :
991- scalar_func = bool
992-
970+ scalar_type = dh .get_scalar_type (promoted_dtype )
993971 for idx in ah .ndindex (shape ):
994972 x1_idx = _left [idx ]
995973 x2_idx = _right [idx ]
996974 out_idx = out [idx ]
997975 assert out_idx .shape == x1_idx .shape == x2_idx .shape # sanity check
998- assert bool (out_idx ) == (scalar_func (x1_idx ) < scalar_func (x2_idx ))
976+ assert bool (out_idx ) == (scalar_type (x1_idx ) < scalar_type (x2_idx ))
999977
1000978
1001979@pytest .mark .parametrize (
@@ -1030,19 +1008,13 @@ def test_less_equal(
10301008 _left = ah .asarray (_left , dtype = promoted_dtype )
10311009 _right = ah .asarray (_right , dtype = promoted_dtype )
10321010
1033- if dh .is_int_dtype (promoted_dtype ):
1034- scalar_func = int
1035- elif dh .is_float_dtype (promoted_dtype ):
1036- scalar_func = float
1037- else :
1038- scalar_func = bool
1039-
1011+ scalar_type = dh .get_scalar_type (promoted_dtype )
10401012 for idx in ah .ndindex (shape ):
10411013 x1_idx = _left [idx ]
10421014 x2_idx = _right [idx ]
10431015 out_idx = out [idx ]
10441016 assert out_idx .shape == x1_idx .shape == x2_idx .shape # sanity check
1045- assert bool (out_idx ) == (scalar_func (x1_idx ) <= scalar_func (x2_idx ))
1017+ assert bool (out_idx ) == (scalar_type (x1_idx ) <= scalar_type (x2_idx ))
10461018
10471019
10481020@given (xps .arrays (dtype = xps .floating_dtypes (), shape = hh .shapes ()))
@@ -1241,19 +1213,13 @@ def test_not_equal(
12411213 _left = ah .asarray (_left , dtype = promoted_dtype )
12421214 _right = ah .asarray (_right , dtype = promoted_dtype )
12431215
1244- if dh .is_int_dtype (promoted_dtype ):
1245- scalar_func = int
1246- elif dh .is_float_dtype (promoted_dtype ):
1247- scalar_func = float
1248- else :
1249- scalar_func = bool
1250-
1216+ scalar_type = dh .get_scalar_type (promoted_dtype )
12511217 for idx in ah .ndindex (shape ):
12521218 out_idx = out [idx ]
12531219 x1_idx = _left [idx ]
12541220 x2_idx = _right [idx ]
12551221 assert out_idx .shape == x1_idx .shape == x2_idx .shape # sanity check
1256- assert bool (out_idx ) == (scalar_func (x1_idx ) != scalar_func (x2_idx ))
1222+ assert bool (out_idx ) == (scalar_type (x1_idx ) != scalar_type (x2_idx ))
12571223
12581224
12591225@pytest .mark .parametrize (
0 commit comments