|
| 1 | +from __future__ import annotations |
| 2 | + |
1 | 3 | from ..common import _aliases |
| 4 | +from ..common._helpers import _check_device |
2 | 5 |
|
3 | 6 | from .._internal import get_xp |
4 | 7 |
|
|
30 | 33 | result_type, |
31 | 34 | ) |
32 | 35 |
|
| 36 | +from typing import TYPE_CHECKING |
| 37 | +if TYPE_CHECKING: |
| 38 | + from typing import Optional, Union |
| 39 | + from ..common._typing import ndarray, Device, Dtype |
| 40 | + |
33 | 41 | import dask.array as da |
34 | 42 |
|
35 | 43 | isdtype = get_xp(np)(_aliases.isdtype) |
36 | 44 | astype = _aliases.astype |
37 | 45 |
|
38 | 46 | # Common aliases |
39 | | -arange = get_xp(da)(_aliases.arange) |
| 47 | + |
| 48 | +# This arange func is modified from the common one to |
| 49 | +# not pass stop/step as keyword arguments, which will cause |
| 50 | +# an error with dask |
| 51 | +def dask_arange( |
| 52 | + start: Union[int, float], |
| 53 | + /, |
| 54 | + stop: Optional[Union[int, float]] = None, |
| 55 | + step: Union[int, float] = 1, |
| 56 | + *, |
| 57 | + xp, |
| 58 | + dtype: Optional[Dtype] = None, |
| 59 | + device: Optional[Device] = None, |
| 60 | + **kwargs |
| 61 | +) -> ndarray: |
| 62 | + _check_device(xp, device) |
| 63 | + args = [start] |
| 64 | + if stop is not None: |
| 65 | + args.append(stop) |
| 66 | + else: |
| 67 | + # stop is None, so start is actually stop |
| 68 | + # prepend the default value for start which is 0 |
| 69 | + args.insert(0, 0) |
| 70 | + args.append(step) |
| 71 | + return xp.arange(*args, dtype=dtype, **kwargs) |
| 72 | + |
| 73 | +arange = get_xp(da)(dask_arange) |
40 | 74 | eye = get_xp(da)(_aliases.eye) |
41 | 75 |
|
42 | 76 | from functools import partial |
|
0 commit comments