|
35 | 35 | UnaryParam = Param[str, Callable[[Array], Array], st.SearchStrategy[Array]] |
36 | 36 |
|
37 | 37 |
|
38 | | -def make_unary_params(func_name: str, dtypes: Sequence[DataType]) -> List[UnaryParam]: |
| 38 | +def make_unary_params( |
| 39 | + elwise_func_name: str, dtypes: Sequence[DataType] |
| 40 | +) -> List[UnaryParam]: |
39 | 41 | strat = xps.arrays(dtype=st.sampled_from(dtypes), shape=hh.shapes()) |
40 | | - |
41 | | - func = getattr(xp, func_name) |
42 | | - op = func_to_op[func_name] |
43 | | - |
44 | | - def op_func(x: Array) -> Array: |
45 | | - return getattr(x, op)() |
46 | | - |
| 42 | + func = getattr(xp, elwise_func_name) |
| 43 | + op_name = func_to_op[elwise_func_name] |
| 44 | + op = lambda x: getattr(x, op_name)() |
47 | 45 | return [ |
48 | | - pytest.param(func_name, func, strat, id=func_name), |
49 | | - pytest.param(op, op_func, strat, id=op), |
| 46 | + pytest.param(elwise_func_name, func, strat, id=elwise_func_name), |
| 47 | + pytest.param(op_name, op, strat, id=op_name), |
50 | 48 | ] |
51 | 49 |
|
52 | 50 |
|
@@ -79,101 +77,90 @@ class FuncType(Enum): |
79 | 77 | OP = auto() |
80 | 78 | IOP = auto() |
81 | 79 |
|
82 | | - @classmethod |
83 | | - def from_name(cls, name: str): |
84 | | - if name in dh.binary_op_to_symbol.keys(): |
85 | | - return cls.OP |
86 | | - elif name in dh.inplace_op_to_symbol.keys(): |
87 | | - return cls.IOP |
88 | | - else: |
89 | | - return cls.FUNC |
90 | | - |
91 | | - |
92 | | -def _make_binary_param( |
93 | | - func_name: str, right_is_scalar: bool, dtypes: Sequence[DataType] |
94 | | -) -> BinaryParam: |
95 | | - func_type = FuncType.from_name(func_name) |
96 | | - |
97 | | - if right_is_scalar: |
98 | | - left_sym = "x" |
99 | | - right_sym = "s" |
100 | | - else: |
101 | | - left_sym = "x1" |
102 | | - right_sym = "x2" |
103 | 80 |
|
| 81 | +def make_binary_params( |
| 82 | + elwise_func_name: str, dtypes: Sequence[DataType] |
| 83 | +) -> List[BinaryParam]: |
104 | 84 | dtypes_strat = st.sampled_from(dtypes) |
105 | | - shared_dtypes = st.shared(dtypes_strat) |
106 | | - if right_is_scalar: |
107 | | - left_strat = xps.arrays(dtype=shared_dtypes, shape=hh.shapes()) |
108 | | - right_strat = shared_dtypes.flatmap(lambda d: xps.from_dtype(d, **finite_kw)) |
109 | | - else: |
110 | | - if func_type is FuncType.IOP: |
111 | | - shared_shapes = st.shared(hh.shapes()) |
112 | | - left_strat = xps.arrays(dtype=shared_dtypes, shape=shared_shapes) |
113 | | - right_strat = xps.arrays(dtype=shared_dtypes, shape=shared_shapes) |
114 | | - else: |
115 | | - left_strat, right_strat = hh.two_mutual_arrays(dtypes) |
116 | | - |
117 | | - if func_type is FuncType.FUNC: |
118 | | - func = getattr(xp, func_name) |
119 | | - else: |
120 | | - op_sym = all_op_to_symbol[func_name] |
121 | | - expr = f"{left_sym} {op_sym} {right_sym}" |
122 | | - |
123 | | - if func_type is FuncType.OP: |
124 | | - |
125 | | - def func(l: Array, r: Union[Scalar, Array]) -> Array: |
126 | | - locals_ = {} |
127 | | - locals_[left_sym] = l |
128 | | - locals_[right_sym] = r |
129 | | - return eval(expr, locals_) |
130 | 85 |
|
| 86 | + def make_param( |
| 87 | + func_name: str, func_type: FuncType, right_is_scalar: bool |
| 88 | + ) -> BinaryParam: |
| 89 | + if right_is_scalar: |
| 90 | + left_sym = "x" |
| 91 | + right_sym = "s" |
131 | 92 | else: |
| 93 | + left_sym = "x1" |
| 94 | + right_sym = "x2" |
| 95 | + |
| 96 | + shared_dtypes = st.shared(dtypes_strat) |
| 97 | + if right_is_scalar: |
| 98 | + left_strat = xps.arrays(dtype=shared_dtypes, shape=hh.shapes()) |
| 99 | + right_strat = shared_dtypes.flatmap( |
| 100 | + lambda d: xps.from_dtype(d, **finite_kw) |
| 101 | + ) |
| 102 | + else: |
| 103 | + if func_type is FuncType.IOP: |
| 104 | + shared_shapes = st.shared(hh.shapes()) |
| 105 | + left_strat = xps.arrays(dtype=shared_dtypes, shape=shared_shapes) |
| 106 | + right_strat = xps.arrays(dtype=shared_dtypes, shape=shared_shapes) |
| 107 | + else: |
| 108 | + left_strat, right_strat = hh.two_mutual_arrays(dtypes) |
| 109 | + |
| 110 | + if func_type is FuncType.FUNC: |
| 111 | + func = getattr(xp, func_name) |
| 112 | + else: |
| 113 | + op_sym = all_op_to_symbol[func_name] |
| 114 | + expr = f"{left_sym} {op_sym} {right_sym}" |
| 115 | + if func_type is FuncType.OP: |
132 | 116 |
|
133 | | - def func(l: Array, r: Union[Scalar, Array]) -> Array: |
134 | | - locals_ = {} |
135 | | - locals_[left_sym] = ah.asarray(l, copy=True) # prevents left mutating |
136 | | - locals_[right_sym] = r |
137 | | - exec(expr, locals_) |
138 | | - return locals_[left_sym] |
| 117 | + def func(l: Array, r: Union[Scalar, Array]) -> Array: |
| 118 | + locals_ = {} |
| 119 | + locals_[left_sym] = l |
| 120 | + locals_[right_sym] = r |
| 121 | + return eval(expr, locals_) |
139 | 122 |
|
140 | | - func.__name__ = func_name # for repr |
| 123 | + else: |
141 | 124 |
|
142 | | - if func_type is FuncType.IOP: |
143 | | - res_name = left_sym |
144 | | - else: |
145 | | - res_name = "out" |
| 125 | + def func(l: Array, r: Union[Scalar, Array]) -> Array: |
| 126 | + locals_ = {} |
| 127 | + locals_[left_sym] = ah.asarray( |
| 128 | + l, copy=True |
| 129 | + ) # prevents left mutating |
| 130 | + locals_[right_sym] = r |
| 131 | + exec(expr, locals_) |
| 132 | + return locals_[left_sym] |
146 | 133 |
|
147 | | - f_id = func_name |
148 | | - if right_is_scalar: |
149 | | - f_id += "(x, s)" |
150 | | - else: |
151 | | - f_id += "(x1, x2)" |
152 | | - |
153 | | - return pytest.param( |
154 | | - func_name, |
155 | | - func, |
156 | | - left_sym, |
157 | | - left_strat, |
158 | | - right_sym, |
159 | | - right_strat, |
160 | | - right_is_scalar, |
161 | | - res_name, |
162 | | - id=f_id, |
163 | | - ) |
| 134 | + func.__name__ = func_name # for repr |
164 | 135 |
|
| 136 | + if func_type is FuncType.IOP: |
| 137 | + res_name = left_sym |
| 138 | + else: |
| 139 | + res_name = "out" |
| 140 | + |
| 141 | + return pytest.param( |
| 142 | + func_name, |
| 143 | + func, |
| 144 | + left_sym, |
| 145 | + left_strat, |
| 146 | + right_sym, |
| 147 | + right_strat, |
| 148 | + right_is_scalar, |
| 149 | + res_name, |
| 150 | + id=f"{func_name}({left_sym}, {right_sym})", |
| 151 | + ) |
165 | 152 |
|
166 | | -def make_binary_params(func_name: str, dtypes: Sequence[DataType]) -> List[BinaryParam]: |
167 | | - op = func_to_op[func_name] |
| 153 | + op_name = func_to_op[elwise_func_name] |
168 | 154 | params = [ |
169 | | - _make_binary_param(func_name, False, dtypes), |
170 | | - _make_binary_param(op, False, dtypes), |
171 | | - _make_binary_param(op, True, dtypes), |
| 155 | + make_param(elwise_func_name, FuncType.FUNC, False), |
| 156 | + make_param(op_name, FuncType.OP, False), |
| 157 | + make_param(op_name, FuncType.OP, True), |
172 | 158 | ] |
173 | | - iop = f"__i{op[2:]}" |
174 | | - if iop in dh.inplace_op_to_symbol.keys(): |
175 | | - params.append(_make_binary_param(iop, False, dtypes)) |
176 | | - params.append(_make_binary_param(iop, True, dtypes)) |
| 159 | + iop_name = f"__i{op_name[2:]}" |
| 160 | + if iop_name in dh.inplace_op_to_symbol.keys(): |
| 161 | + params.append(make_param(iop_name, FuncType.IOP, False)) |
| 162 | + params.append(make_param(iop_name, FuncType.IOP, True)) |
| 163 | + |
177 | 164 | return params |
178 | 165 |
|
179 | 166 |
|
|
0 commit comments