33"""
44from typing import Any , Optional
55import numpy as np
6- from onnx import TensorProto
76from ..npx .npx_functions import (
87 all ,
98 abs ,
1615 reshape ,
1716 take ,
1817)
18+ from ..npx .npx_functions import full as generic_full
1919from ..npx .npx_functions import ones as generic_ones
2020from ..npx .npx_functions import zeros as generic_zeros
2121from ..npx .npx_numpy_tensors import EagerNumpyTensor
22- from ..npx .npx_types import DType , ElemType , TensorType , OptParType
22+ from ..npx .npx_types import DType , ElemType , TensorType , OptParType , ParType , Scalar
2323from ._onnx_common import template_asarray
2424from . import _finalize_array_api
2525
3131 "astype" ,
3232 "empty" ,
3333 "equal" ,
34+ "full" ,
3435 "isdtype" ,
3536 "isfinite" ,
3637 "isnan" ,
@@ -58,7 +59,7 @@ def asarray(
5859
5960def ones (
6061 shape : TensorType [ElemType .int64 , "I" , (None ,)],
61- dtype : OptParType [DType ] = DType ( TensorProto . FLOAT ) ,
62+ dtype : OptParType [DType ] = None ,
6263 order : OptParType [str ] = "C" ,
6364) -> TensorType [ElemType .numerics , "T" ]:
6465 if isinstance (shape , tuple ):
@@ -76,7 +77,7 @@ def ones(
7677
7778def empty (
7879 shape : TensorType [ElemType .int64 , "I" , (None ,)],
79- dtype : OptParType [DType ] = DType ( TensorProto . FLOAT ) ,
80+ dtype : OptParType [DType ] = None ,
8081 order : OptParType [str ] = "C" ,
8182) -> TensorType [ElemType .numerics , "T" ]:
8283 raise RuntimeError (
@@ -87,7 +88,7 @@ def empty(
8788
8889def zeros (
8990 shape : TensorType [ElemType .int64 , "I" , (None ,)],
90- dtype : OptParType [DType ] = DType ( TensorProto . FLOAT ) ,
91+ dtype : OptParType [DType ] = None ,
9192 order : OptParType [str ] = "C" ,
9293) -> TensorType [ElemType .numerics , "T" ]:
9394 if isinstance (shape , tuple ):
@@ -103,6 +104,32 @@ def zeros(
103104 return generic_zeros (shape , dtype = dtype , order = order )
104105
105106
107+ def full (
108+ shape : TensorType [ElemType .int64 , "I" , (None ,)],
109+ fill_value : ParType [Scalar ] = None ,
110+ dtype : OptParType [DType ] = None ,
111+ order : OptParType [str ] = "C" ,
112+ ) -> TensorType [ElemType .numerics , "T" ]:
113+ if fill_value is None :
114+ raise TypeError ("fill_value cannot be None" )
115+ value = fill_value
116+ if isinstance (shape , tuple ):
117+ return generic_full (
118+ EagerNumpyTensor (np .array (shape , dtype = np .int64 )),
119+ fill_value = value ,
120+ dtype = dtype ,
121+ order = order ,
122+ )
123+ if isinstance (shape , int ):
124+ return generic_full (
125+ EagerNumpyTensor (np .array ([shape ], dtype = np .int64 )),
126+ fill_value = value ,
127+ dtype = dtype ,
128+ order = order ,
129+ )
130+ return generic_full (shape , fill_value = value , dtype = dtype , order = order )
131+
132+
106133def _finalize ():
107134 """
108135 Adds common attributes to Array API defined in this modules
0 commit comments