11import cmath
22import math
3+ from functools import wraps
34from inspect import getfullargspec
4- from typing import Any , Dict , Optional , Sequence , Tuple , Union
5+ from typing import Any , Callable , Dict , Optional , Sequence , Tuple , Union
56
67from . import _array_module as xp
78from . import dtype_helpers as dh
@@ -122,6 +123,7 @@ def assert_dtype(
122123 >>> assert_dtype('sum', in_dtype=x, out_dtype=out.dtype, expected=default_int)
123124
124125 """
126+ __tracebackhide__ = True
125127 in_dtypes = in_dtype if isinstance (in_dtype , Sequence ) and not isinstance (in_dtype , str ) else [in_dtype ]
126128 f_in_dtypes = dh .fmt_types (tuple (in_dtypes ))
127129 f_out_dtype = dh .dtype_to_name [out_dtype ]
@@ -149,6 +151,7 @@ def assert_kw_dtype(
149151 >>> assert_kw_dtype('ones', kw_dtype=kw['dtype'], out_dtype=out.dtype)
150152
151153 """
154+ __tracebackhide__ = True
152155 f_kw_dtype = dh .dtype_to_name [kw_dtype ]
153156 f_out_dtype = dh .dtype_to_name [out_dtype ]
154157 msg = (
@@ -166,6 +169,7 @@ def assert_default_float(func_name: str, out_dtype: DataType):
166169 >>> assert_default_float('ones', out.dtype)
167170
168171 """
172+ __tracebackhide__ = True
169173 f_dtype = dh .dtype_to_name [out_dtype ]
170174 f_default = dh .dtype_to_name [dh .default_float ]
171175 msg = (
@@ -183,6 +187,7 @@ def assert_default_complex(func_name: str, out_dtype: DataType):
183187 >>> assert_default_complex('asarray', out.dtype)
184188
185189 """
190+ __tracebackhide__ = True
186191 f_dtype = dh .dtype_to_name [out_dtype ]
187192 f_default = dh .dtype_to_name [dh .default_complex ]
188193 msg = (
@@ -200,6 +205,7 @@ def assert_default_int(func_name: str, out_dtype: DataType):
200205 >>> assert_default_int('full', out.dtype)
201206
202207 """
208+ __tracebackhide__ = True
203209 f_dtype = dh .dtype_to_name [out_dtype ]
204210 f_default = dh .dtype_to_name [dh .default_int ]
205211 msg = (
@@ -217,6 +223,7 @@ def assert_default_index(func_name: str, out_dtype: DataType, repr_name="out.dty
217223 >>> assert_default_int('argmax', out.dtype)
218224
219225 """
226+ __tracebackhide__ = True
220227 f_dtype = dh .dtype_to_name [out_dtype ]
221228 msg = (
222229 f"{ repr_name } ={ f_dtype } , should be the default index dtype, "
@@ -240,6 +247,7 @@ def assert_shape(
240247 >>> assert_shape('ones', out_shape=out.shape, expected=(3, 3, 3))
241248
242249 """
250+ __tracebackhide__ = True
243251 if isinstance (out_shape , int ):
244252 out_shape = (out_shape ,)
245253 if isinstance (expected , int ):
@@ -273,6 +281,7 @@ def assert_result_shape(
273281 >>> assert out.shape == (3, 3)
274282
275283 """
284+ __tracebackhide__ = True
276285 if expected is None :
277286 expected = sh .broadcast_shapes (* in_shapes )
278287 f_in_shapes = " . " .join (str (s ) for s in in_shapes )
@@ -307,6 +316,7 @@ def assert_keepdimable_shape(
307316 >>> assert out2.shape == (1, 1)
308317
309318 """
319+ __tracebackhide__ = True
310320 if keepdims :
311321 shape = tuple (1 if axis in axes else side for axis , side in enumerate (in_shape ))
312322 else :
@@ -337,6 +347,7 @@ def assert_0d_equals(
337347 >>> assert res[0] == x[0]
338348
339349 """
350+ __tracebackhide__ = True
340351 msg = (
341352 f"{ out_repr } ={ out_val } , but should be { x_repr } ={ x_val } "
342353 f"[{ func_name } ({ fmt_kw (kw )} )]"
@@ -369,6 +380,7 @@ def assert_scalar_equals(
369380 >>> assert int(out) == 5
370381
371382 """
383+ __tracebackhide__ = True
372384 repr_name = repr_name if idx == () else f"{ repr_name } [{ idx } ]"
373385 f_func = f"{ func_name } ({ fmt_kw (kw )} )"
374386 if type_ in [bool , int ]:
@@ -401,6 +413,7 @@ def assert_fill(
401413 >>> assert xp.all(out == 42)
402414
403415 """
416+ __tracebackhide__ = True
404417 msg = f"out not filled with { fill_value } [{ func_name } ({ fmt_kw (kw )} )]\n { out = } "
405418 if cmath .isnan (fill_value ):
406419 assert xp .all (xp .isnan (out )), msg
@@ -443,6 +456,7 @@ def assert_array_elements(
443456 >>> assert xp.all(out == x)
444457
445458 """
459+ __tracebackhide__ = True
446460 dh .result_type (out .dtype , expected .dtype ) # sanity check
447461 assert_shape (func_name , out_shape = out .shape , expected = expected .shape , kw = kw ) # sanity check
448462 f_func = f"[{ func_name } ({ fmt_kw (kw )} )]"
@@ -470,3 +484,18 @@ def assert_array_elements(
470484 assert xp .all (
471485 out == expected
472486 ), f"{ out_repr } not as expected { f_func } \n { out_repr } ={ out !r} \n { expected = } "
487+
488+
489+ def _make_wrapped_assert_helper (assert_helper : Callable ) -> Callable :
490+ @wraps (assert_helper )
491+ def wrapped_assert_helper (* args , ** kwargs ):
492+ __tracebackhide__ = True
493+ assert_helper (* args , ** kwargs )
494+
495+ return wrapped_assert_helper
496+
497+
498+ for func_name in __all__ :
499+ if func_name .startswith ("assert" ):
500+ assert_helper = globals ()[func_name ]
501+ globals ()[func_name ] = _make_wrapped_assert_helper (assert_helper )
0 commit comments