1010"""
1111
1212import math
13+ from enum import Enum , auto
1314from typing import Callable , List , Sequence , Union
1415
1516import pytest
2526from . import xps
2627from .typing import Array , DataType , Param , Scalar
2728
28-
2929# We might as well use this implementation rather than xp.broadcast_shapes()
3030from .test_broadcasting import broadcast_shapes
3131
@@ -74,38 +74,53 @@ def op_func(x: Array) -> Array:
7474]
7575
7676
77+ class FuncType (Enum ):
78+ FUNC = auto ()
79+ OP = auto ()
80+ IOP = auto ()
81+
82+ @classmethod
83+ def from_name (cls , name : str ):
84+ if name in dh .binary_op_to_symbol .keys ():
85+ return cls .OP
86+ elif name in dh .inplace_op_to_symbol .keys ():
87+ return cls .IOP
88+ else :
89+ return cls .FUNC
90+
91+
7792def _make_binary_param (
7893 func_name : str , right_is_scalar : bool , dtypes : Sequence [DataType ]
7994) -> BinaryParam :
80- if func_name in dh .binary_op_to_symbol .keys ():
81- func_type = "op"
82- elif func_name in dh .inplace_op_to_symbol .keys ():
83- func_type = "iop"
84- else :
85- func_type = "func"
95+ func_type = FuncType .from_name (func_name )
8696
87- left_sym , right_sym = ("x" , "s" ) if right_is_scalar else ("x1" , "x2" )
97+ if right_is_scalar :
98+ left_sym = "x"
99+ right_sym = "s"
100+ else :
101+ left_sym = "x1"
102+ right_sym = "x2"
88103
89104 dtypes_strat = st .sampled_from (dtypes )
90105 shared_dtypes = st .shared (dtypes_strat )
91106 if right_is_scalar :
92107 left_strat = xps .arrays (dtype = shared_dtypes , shape = hh .shapes ())
93108 right_strat = shared_dtypes .flatmap (lambda d : xps .from_dtype (d , ** finite_kw ))
94109 else :
95- if func_type == "iop" :
110+ if func_type is FuncType . IOP :
96111 shared_shapes = st .shared (hh .shapes ())
97112 left_strat = xps .arrays (dtype = shared_dtypes , shape = shared_shapes )
98113 right_strat = xps .arrays (dtype = shared_dtypes , shape = shared_shapes )
99114 else :
100115 left_strat , right_strat = hh .two_mutual_arrays (dtypes )
101116
102- if func_type == "func" :
117+ if func_type is FuncType . FUNC :
103118 func = getattr (xp , func_name )
104119 else :
105120 op_sym = all_op_to_symbol [func_name ]
106121 expr = f"{ left_sym } { op_sym } { right_sym } "
107122
108- if func_type == "op" :
123+ if func_type is FuncType . OP :
109124
110125 def func (l : Array , r : Union [Scalar , Array ]) -> Array :
111126 locals_ = {}
@@ -124,7 +139,7 @@ def func(l: Array, r: Union[Scalar, Array]) -> Array:
124139
125140 func .__name__ = func_name # for repr
126141
127- if func_type == "iop" :
142+ if func_type is FuncType . IOP :
128143 res_name = left_sym
129144 else :
130145 res_name = "out"
0 commit comments