@@ -49,7 +49,11 @@ def is_array_api_obj(x):
4949 or _is_torch_array (x ) \
5050 or hasattr (x , '__array_namespace__' )
5151
52- def get_namespace (* xs , _use_compat = True ):
52+ def _check_api_version (api_version ):
53+ if api_version is not None and api_version != '2021.12' :
54+ raise ValueError ("Only the 2021.12 version of the array API specification is currently supported" )
55+
56+ def get_namespace (* xs , api_version = None , _use_compat = True ):
5357 """
5458 Get the array API compatible namespace for the arrays `xs`.
5559
@@ -61,28 +65,34 @@ def your_function(x, y):
6165 xp = array_api_compat.get_namespace(x, y)
6266 # Now use xp as the array library namespace
6367 return xp.mean(x, axis=0) + 2*xp.std(y, axis=0)
68+
69+ api_version should be the newest version of the spec that you need support
70+ for (currently the compat library wrapped APIs only support v2021.12).
6471 """
6572 namespaces = set ()
6673 for x in xs :
6774 if isinstance (x , (tuple , list )):
6875 namespaces .add (get_namespace (* x , _use_compat = _use_compat ))
6976 elif hasattr (x , '__array_namespace__' ):
70- namespaces .add (x .__array_namespace__ ())
77+ namespaces .add (x .__array_namespace__ (api_version = api_version ))
7178 elif _is_numpy_array (x ):
79+ _check_api_version (api_version )
7280 if _use_compat :
7381 from .. import numpy as numpy_namespace
7482 namespaces .add (numpy_namespace )
7583 else :
7684 import numpy as np
7785 namespaces .add (np )
7886 elif _is_cupy_array (x ):
87+ _check_api_version (api_version )
7988 if _use_compat :
8089 from .. import cupy as cupy_namespace
8190 namespaces .add (cupy_namespace )
8291 else :
8392 import cupy as cp
8493 namespaces .add (cp )
8594 elif _is_torch_array (x ):
95+ _check_api_version (api_version )
8696 if _use_compat :
8797 from .. import torch as torch_namespace
8898 namespaces .add (torch_namespace )
0 commit comments