2929BinaryCheck = Callable [[float , float ], bool ]
3030
3131
32- def make_eq (v : float ) -> UnaryCheck :
32+ def make_strict_eq (v : float ) -> UnaryCheck :
3333 if math .isnan (v ):
3434 return math .isnan
3535 if v == 0 :
@@ -38,14 +38,14 @@ def make_eq(v: float) -> UnaryCheck:
3838 else :
3939 return ph .is_neg_zero
4040
41- def eq (i : float ) -> bool :
41+ def strict_eq (i : float ) -> bool :
4242 return i == v
4343
44- return eq
44+ return strict_eq
4545
4646
4747def make_neq (v : float ) -> UnaryCheck :
48- eq = make_eq (v )
48+ eq = make_strict_eq (v )
4949
5050 def neq (i : float ) -> bool :
5151 return not eq (i )
@@ -154,7 +154,8 @@ def parse_inline_code(inline_code: str) -> float:
154154 raise ValueParseError (inline_code )
155155
156156
157- r_not = re .compile ("not (?:equal to )?(.+)" )
157+ r_not = re .compile ("not (.+)" )
158+ r_equal_to = re .compile (f"equal to { r_code .pattern } " )
158159r_array_element = re .compile (r"``([+-]?)x([12])_i``" )
159160r_either_code = re .compile (f"either { r_code .pattern } or { r_code .pattern } " )
160161r_gt = re .compile (f"greater than { r_code .pattern } " )
@@ -217,9 +218,6 @@ def from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]:
217218
218219
219220def parse_cond (cond_str : str ) -> Tuple [UnaryCheck , str , FromDtypeFunc ]:
220- if "equal to" in cond_str :
221- raise ValueParseError (cond_str ) # TODO
222-
223221 if m := r_not .match (cond_str ):
224222 cond_str = m .group (1 )
225223 not_cond = True
@@ -232,10 +230,15 @@ def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, FromDtypeFunc]:
232230 strat = None
233231 if m := r_code .match (cond_str ):
234232 value = parse_value (m .group (1 ))
235- cond = make_eq (value )
233+ cond = make_strict_eq (value )
236234 expr_template = "{} == " + m .group (1 )
237235 if not not_cond :
238236 strat = st .just (value )
237+ elif m := r_equal_to .match (cond_str ):
238+ value = parse_value (m .group (1 ))
239+ assert not math .isnan (value ) # sanity check
240+ cond = lambda i : i == value
241+ expr_template = "{} == " + m .group (1 )
239242 elif m := r_gt .match (cond_str ):
240243 value = parse_value (m .group (1 ))
241244 cond = make_gt (value )
@@ -251,7 +254,7 @@ def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, FromDtypeFunc]:
251254 elif m := r_either_code .match (cond_str ):
252255 v1 = parse_value (m .group (1 ))
253256 v2 = parse_value (m .group (2 ))
254- cond = make_or (make_eq (v1 ), make_eq (v2 ))
257+ cond = make_or (make_strict_eq (v1 ), make_strict_eq (v2 ))
255258 expr_template = "{} == " + m .group (1 ) + " or {} == " + m .group (2 )
256259 if not not_cond :
257260 strat = st .sampled_from ([v1 , v2 ])
@@ -334,7 +337,7 @@ def from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]:
334337def parse_result (result_str : str ) -> Tuple [UnaryCheck , str ]:
335338 if m := r_code .match (result_str ):
336339 value = parse_value (m .group (1 ))
337- check_result = make_eq (value ) # type: ignore
340+ check_result = make_strict_eq (value ) # type: ignore
338341 expr = m .group (1 )
339342 elif m := r_approx_value .match (result_str ):
340343 value = parse_value (m .group (1 ))
@@ -573,13 +576,13 @@ def make_eq_other_input_cond(
573576 if eq_to == BinaryCondArg .FIRST :
574577
575578 def cond (i1 : float , i2 : float ) -> bool :
576- eq = make_eq (input_wrapper (i1 ))
579+ eq = make_strict_eq (input_wrapper (i1 ))
577580 return eq (i2 )
578581
579582 elif eq_to == BinaryCondArg .SECOND :
580583
581584 def cond (i1 : float , i2 : float ) -> bool :
582- eq = make_eq (input_wrapper (i2 ))
585+ eq = make_strict_eq (input_wrapper (i2 ))
583586 return eq (i1 )
584587
585588 else :
@@ -599,13 +602,13 @@ def make_eq_input_check_result(
599602 if eq_to == BinaryCondArg .FIRST :
600603
601604 def check_result (i1 : float , i2 : float , result : float ) -> bool :
602- eq = make_eq (input_wrapper (i1 ))
605+ eq = make_strict_eq (input_wrapper (i1 ))
603606 return eq (result )
604607
605608 elif eq_to == BinaryCondArg .SECOND :
606609
607610 def check_result (i1 : float , i2 : float , result : float ) -> bool :
608- eq = make_eq (input_wrapper (i2 ))
611+ eq = make_strict_eq (input_wrapper (i2 ))
609612 return eq (result )
610613
611614 else :
0 commit comments