@@ -153,7 +153,7 @@ def _check_device(xp, device):
         if device not in ["cpu", None]:
             raise ValueError(f"Unsupported device for NumPy: {device!r}")
 
-# device() is not on numpy.ndarray and and to_device() is not on numpy.ndarray
+# device() is not on numpy.ndarray and to_device() is not on numpy.ndarray
 # or cupy.ndarray. They are not included in array objects of this library
 # because this library just reuses the respective ndarray classes without
 # wrapping or subclassing them. These helper functions can be used instead of
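Usage note: a minimal sketch of the helper functions the comment above describes, assuming they are exposed at the package level as device() and to_device() (the import path below is an assumption, not shown in this diff).

import numpy as np
from array_api_compat import device, to_device  # assumed public entry point

x = np.arange(3)
print(device(x))         # reports the device of a plain NumPy array
y = to_device(x, "cpu")  # per _check_device above, NumPy accepts only "cpu" or None; other values raise ValueError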
@@ -230,12 +230,6 @@ def _torch_to_device(x, device, /, stream=None):
         raise NotImplementedError
     return x.to(device)
 
-def _jax_to_device(x, device, /, stream=None):
-    import jax
-    if stream is not None:
-        raise NotImplementedError
-    return jax.device_put(x, device)
-
 def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, Any]]" = None) -> "Array":
     """
     Copy the array from the device on which it currently resides to the specified ``device``.
@@ -276,7 +270,9 @@ def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, A
             return x
         raise ValueError(f"Unsupported device {device!r}")
     elif is_jax_array(x):
-        return _jax_to_device(x, device, stream=stream)
+        # This import adds to_device to x
+        import jax.experimental.array_api
+        return x.to_device(device, stream=stream)
     return x.to_device(device, stream=stream)
 
 def size(x):
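Usage note: a minimal sketch of the new JAX code path introduced above, assuming a JAX version that still ships jax.experimental.array_api; per the added comment, importing that module is what attaches to_device() to JAX arrays.

import jax
import jax.numpy as jnp
import jax.experimental.array_api  # noqa: F401  (imported only for its side effect)

x = jnp.arange(4)
y = x.to_device(jax.devices()[0])  # copy (or no-op) onto the first available device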