11from __future__ import annotations
22
33from dataclasses import dataclass
4- from typing import Any , Sequence
4+ from typing import Any , Literal , Sequence
55
66import numpy as np
77from numpy .typing import NDArray
1212from optimagic .exceptions import InvalidBoundsError
1313from optimagic .parameters .tree_registry import get_registry
1414from optimagic .typing import PyTree , PyTreeRegistry
15+ from optimagic .utilities import fast_numpy_full
1516
1617
1718@dataclass (frozen = True )
@@ -60,8 +61,8 @@ def pre_process_bounds(
6061
6162
6263def _process_bounds_sequence (bounds : Sequence [tuple [float , float ]]) -> Bounds :
63- lower = np . full (len (bounds ), - np .inf )
64- upper = np . full (len (bounds ), np .inf )
64+ lower = fast_numpy_full (len (bounds ), fill_value = - np .inf )
65+ upper = fast_numpy_full (len (bounds ), fill_value = np .inf )
6566
6667 for i , (lb , ub ) in enumerate (bounds ):
6768 if lb is not None :
@@ -76,14 +77,14 @@ def get_internal_bounds(
7677 bounds : Bounds | None = None ,
7778 registry : PyTreeRegistry | None = None ,
7879 add_soft_bounds : bool = False ,
79- ) -> tuple [NDArray [np .float64 ], NDArray [np .float64 ]]:
80+ ) -> tuple [NDArray [np .float64 ] | None , NDArray [np .float64 ] | None ]:
8081 """Create consolidated and flattened bounds for params.
8182
8283 If params is a DataFrame with value column, the user provided bounds are
8384 extended with bounds from the params DataFrame.
8485
85- If no bounds are available the entry is set to minus np.inf for the lower bound and
86- np.inf for the upper bound.
86+ If no bounds are provided, we return None. If some bounds are available the missing
87+ entries are set to -np.inf for the lower bound and np.inf for the upper bound.
8788
8889 The bounds provided in `bounds` override bounds provided in params if both are
8990 specified (in the case where params is a DataFrame with bounds as a column).
@@ -109,10 +110,11 @@ def get_internal_bounds(
109110 add_soft_bounds = add_soft_bounds ,
110111 )
111112 if fast_path :
112- return _get_fast_path_bounds (
113- params = params ,
114- bounds = bounds ,
115- )
113+ return _get_fast_path_bounds (bounds )
114+
115+ # Handling of None-valued bounds in the slow path needs to be improved. Currently,
116+ # None-valued bounds are replaced with arrays of np.inf and -np.inf, and then
117+ # translated back to None if all entries are non-finite.
116118
117119 registry = get_registry (extended = True ) if registry is None else registry
118120 n_params = len (tree_leaves (params , registry = registry ))
@@ -149,11 +151,18 @@ def get_internal_bounds(
149151 msg = "Invalid bounds. Some lower bounds are larger than upper bounds."
150152 raise InvalidBoundsError (msg )
151153
154+ if np .isinf (lower_flat ).all ():
155+ lower_flat = None # type: ignore[assignment]
156+ if np .isinf (upper_flat ).all ():
157+ upper_flat = None # type: ignore[assignment]
158+
152159 return lower_flat , upper_flat
153160
154161
155162def _update_bounds_and_flatten (
156- nan_tree : PyTree , bounds : PyTree , kind : str
163+ nan_tree : PyTree ,
164+ bounds : PyTree ,
165+ kind : Literal ["lower_bound" , "upper_bound" , "soft_lower_bound" , "soft_upper_bound" ],
157166) -> NDArray [np .float64 ]:
158167 """Flatten bounds array and update it with bounds from params.
159168
@@ -213,7 +222,7 @@ def _is_fast_path(params: PyTree, bounds: Bounds, add_soft_bounds: bool) -> bool
213222 if not _is_1d_array (params ):
214223 out = False
215224
216- for bound in bounds .lower , bounds .upper :
225+ for bound in ( bounds .lower , bounds .upper ) :
217226 if not (_is_1d_array (bound ) or bound is None ):
218227 out = False
219228 return out
@@ -224,21 +233,27 @@ def _is_1d_array(candidate: Any) -> bool:
224233
225234
226235def _get_fast_path_bounds (
227- params : PyTree , bounds : Bounds
228- ) -> tuple [NDArray [np .float64 ], NDArray [np .float64 ]]:
236+ bounds : Bounds ,
237+ ) -> tuple [NDArray [np .float64 ] | None , NDArray [np .float64 ] | None ]:
229238 if bounds .lower is None :
230- # faster than np.full
231- lower_bounds = np .array ([- np .inf ] * len (params ))
239+ lower_bounds = None
232240 else :
233241 lower_bounds = bounds .lower .astype (float )
242+ if np .isinf (lower_bounds ).all ():
243+ lower_bounds = None
234244
235245 if bounds .upper is None :
236- # faster than np.full
237- upper_bounds = np .array ([np .inf ] * len (params ))
246+ upper_bounds = None
238247 else :
239248 upper_bounds = bounds .upper .astype (float )
240-
241- if (lower_bounds > upper_bounds ).any ():
249+ if np .isinf (upper_bounds ).all ():
250+ upper_bounds = None
251+
252+ if (
253+ lower_bounds is not None
254+ and upper_bounds is not None
255+ and (lower_bounds > upper_bounds ).any ()
256+ ):
242257 msg = "Invalid bounds. Some lower bounds are larger than upper bounds."
243258 raise InvalidBoundsError (msg )
244259
0 commit comments