@@ -311,6 +311,7 @@ def wrap_strat_as_from_dtype(strat: st.SearchStrategy[float]) -> FromDtypeFunc:
311311 """
312312 Wraps an elements strategy as a xps.from_dtype()-like function
313313 """
314+
314315 def from_dtype (dtype : DataType , ** kw ) -> st .SearchStrategy [float ]:
315316 assert len (kw ) == 0 # sanity check
316317 return strat
@@ -553,23 +554,6 @@ class UnaryCase(Case):
553554 cond : UnaryCheck
554555 check_result : UnaryResultCheck
555556
556- @classmethod
557- def from_strings (cls , cond_str : str , result_str : str ):
558- cond , cond_expr_template , cond_from_dtype = parse_cond (cond_str )
559- cond_expr = cond_expr_template .replace ("{}" , "x_i" )
560- _check_result , result_expr = parse_result (result_str )
561-
562- def check_result (i : float , result : float ) -> bool :
563- return _check_result (result )
564-
565- return cls (
566- cond_expr = cond_expr ,
567- cond = cond ,
568- cond_from_dtype = cond_from_dtype ,
569- result_expr = result_expr ,
570- check_result = check_result ,
571- )
572-
573557
574558r_unary_case = re .compile ("If ``x_i`` is (.+), the result is (.+)" )
575559r_even_int_round_case = re .compile (
@@ -578,7 +562,7 @@ def check_result(i: float, result: float) -> bool:
578562)
579563
580564
581- def trailing_halves_from_dtype (dtype : DataType ):
565+ def trailing_halves_from_dtype (dtype : DataType ) -> st . SearchStrategy [ float ] :
582566 m , M = dh .dtype_ranges [dtype ]
583567 return st .integers (math .ceil (m ) // 2 , math .floor (M ) // 2 ).map (lambda n : n * 0.5 )
584568
@@ -594,6 +578,13 @@ def trailing_halves_from_dtype(dtype: DataType):
594578)
595579
596580
581+ def make_unary_check_result (check_just_result : UnaryCheck ) -> UnaryResultCheck :
582+ def check_result (i : float , result : float ) -> bool :
583+ return check_just_result (result )
584+
585+ return check_result
586+
587+
597588def parse_unary_docstring (docstring : str ) -> List [UnaryCase ]:
598589 match = r_special_cases .search (docstring )
599590 if match is None :
@@ -608,10 +599,22 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]:
608599 continue
609600 if m := r_unary_case .search (case ):
610601 try :
611- case = UnaryCase .from_strings (* m .groups ())
602+ cond , cond_expr_template , cond_from_dtype = parse_cond (m .group (1 ))
603+ _check_result , result_expr = parse_result (m .group (2 ))
612604 except ParseError as e :
613605 warn (f"not machine-readable: '{ e .value } '" )
614606 continue
607+ cond_expr = cond_expr_template .replace ("{}" , "x_i" )
608+ # Do not define check_result in this function's body - see
609+ # parse_binary_case comment.
610+ check_result = make_unary_check_result (_check_result )
611+ case = UnaryCase (
612+ cond_expr = cond_expr ,
613+ cond = cond ,
614+ cond_from_dtype = cond_from_dtype ,
615+ result_expr = result_expr ,
616+ check_result = check_result ,
617+ )
615618 cases .append (case )
616619 elif m := r_even_int_round_case .search (case ):
617620 cases .append (even_int_round_case )
@@ -741,7 +744,7 @@ def check_result(i1: float, i2: float, result: float) -> bool:
741744 return check_result
742745
743746
744- def make_check_result (check_just_result : UnaryCheck ) -> BinaryResultCheck :
747+ def make_binary_check_result (check_just_result : UnaryCheck ) -> BinaryResultCheck :
745748 def check_result (i1 : float , i2 : float , result : float ) -> bool :
746749 return check_just_result (result )
747750
@@ -843,12 +846,12 @@ def partial_cond(i1: float, i2: float) -> bool:
843846
844847 else :
845848 unary_cond , expr_template , cond_from_dtype = parse_cond (value_str )
846- # Do not define partial_cond via the def keyword, as one
847- # partial_cond definition can mess up previous definitions
848- # in the partial_conds list. This is a hard-limitation of
849- # using local functions with the same name and that use the same
850- # outer variables (i.e. unary_cond). Use def in a called
851- # function avoids this problem.
849+ # Do not define partial_cond via the def keyword or lambda
850+ # expressions, as one partial_cond definition can mess up
851+ # previous definitions in the partial_conds list. This is a
852+ # hard-limitation of using local functions with the same name
853+ # and that use the same outer variables (i.e. unary_cond). Use
854+ # def in a called function avoids this problem.
852855 input_wrapper = None
853856 if m := r_input .match (input_str ):
854857 x_no = m .group (1 )
@@ -924,7 +927,7 @@ def _x2_cond_from_dtype(dtype) -> st.SearchStrategy[float]:
924927 if result_m is None :
925928 raise ParseError (case_m .group (2 ))
926929 result_str = result_m .group (1 )
927- # Like with partial_cond, do not define check_result via the def keyword
930+ # Like with partial_cond, do not define check_result in this function's body.
928931 if m := r_array_element .match (result_str ):
929932 sign , x_no = m .groups ()
930933 result_expr = f"{ sign } x{ x_no } _i"
@@ -933,7 +936,7 @@ def _x2_cond_from_dtype(dtype) -> st.SearchStrategy[float]:
933936 )
934937 else :
935938 _check_result , result_expr = parse_result (result_m .group (1 ))
936- check_result = make_check_result (_check_result )
939+ check_result = make_binary_check_result (_check_result )
937940
938941 cond_expr = " and " .join (partial_exprs )
939942
0 commit comments