Skip to content

Commit 9ccab70

Browse files
committed
fix code
1 parent 5316964 commit 9ccab70

File tree

5 files changed

+15
-7
lines changed

5 files changed

+15
-7
lines changed

_unittests/ut_translate_api/test_translate_classic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,7 @@ def _run(cls, code):
406406
import onnx.helper
407407
import onnx.numpy_helper
408408
import onnx_array_api.translate_api.make_helper
409-
import onnx.reference.custom_element_types
409+
import ml_dtypes
410410

411411
def from_array_extended(tensor, name=None):
412412
dt = tensor.dtype
@@ -433,7 +433,7 @@ def from_array_extended(tensor, name=None):
433433
globs.update(onnx.helper.__dict__)
434434
globs.update(onnx.numpy_helper.__dict__)
435435
globs.update(onnx_array_api.translate_api.make_helper.__dict__)
436-
globs.update(onnx.reference.custom_element_types.__dict__)
436+
globs.update(ml_dtypes.__dict__)
437437
globs["from_array_extended"] = from_array_extended
438438
locs = {}
439439
try:

onnx_array_api/translate_api/base_emitter.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,9 @@ def render_attribute_value(self, value: Any) -> Tuple[List[str], str]:
129129
if value[0].type == AttributeProto.TENSOR:
130130
repl = {"bool": "bool_", "object": "object_", "str": "str_"}
131131
sdtype = repl.get(str(v.dtype), str(str(v.dtype)))
132+
package = "np" if hasattr(np, sdtype) else "ml_dtypes"
132133
return [], (
133-
f"from_array(np.array({v.tolist()}, dtype=np.{sdtype}), "
134+
f"from_array(np.array({v.tolist()}, dtype={package}.{sdtype}), "
134135
f"name={value[0].name!r})"
135136
)
136137
if isinstance(v, (int, float, list)):

onnx_array_api/translate_api/builder_emitter.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import Any, Dict, List
2+
import numpy as np
23
from onnx import TensorProto
34
from onnx.numpy_helper import to_array
45
from .base_emitter import BaseEmitter
@@ -135,7 +136,10 @@ def _emit_end_signature(self, **kwargs: Dict[str, Any]) -> List[str]:
135136
val = to_array(init)
136137
stype = str(val.dtype).split(".")[-1]
137138
name = self._clean_result_name(init.name)
138-
rows.append(f" {name} = np.array({val.tolist()}, dtype=np.{stype})")
139+
package = "np" if hasattr(np, stype) else "ml_dtypes"
140+
rows.append(
141+
f" {name} = np.array({val.tolist()}, dtype={package}.{stype})"
142+
)
139143
return rows
140144

141145
def _emit_begin_return(self, **kwargs: Dict[str, Any]) -> List[str]:

onnx_array_api/translate_api/inner_emitter.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import Any, Dict, List, Optional, Tuple
2+
import numpy as np
23
from onnx import AttributeProto
34
from ..annotations import ELEMENT_TYPE_NAME
45
from .base_emitter import BaseEmitter
@@ -105,7 +106,7 @@ def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]:
105106
else:
106107
raise NotImplementedError(f"Unexpected dtype={sdtype}.")
107108
else:
108-
sdtype = f"np.{sdtype}"
109+
sdtype = f"np.{sdtype}" if hasattr(np, sdtype) else f"ml_dtypes.{sdtype}"
109110

110111
return [
111112
"initializers.append(",
@@ -233,7 +234,7 @@ def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]:
233234
else:
234235
raise NotImplementedError(f"Unexpected dtype={sdtype}.")
235236
else:
236-
sdtype = f"np.{sdtype}"
237+
sdtype = f"np.{sdtype}" if hasattr(np, sdtype) else f"ml_dtypes.{sdtype}"
237238
if value.size <= 16:
238239
return [
239240
"initializers.append(",

onnx_array_api/translate_api/light_emitter.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import Any, Dict, List
2+
import numpy as np
23
from ..annotations import ELEMENT_TYPE_NAME
34
from .base_emitter import BaseEmitter
45

@@ -43,8 +44,9 @@ def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]:
4344
value = kwargs["value"]
4445
repl = {"bool": "bool_", "object": "object_", "str": "str_"}
4546
sdtype = repl.get(str(value.dtype), str(str(value.dtype)))
47+
package = "np" if hasattr(np, sdtype) else "ml_dtypes"
4648
return [
47-
f"cst(np.array({value.tolist()}, dtype=np.{sdtype}))",
49+
f"cst(np.array({value.tolist()}, dtype={package}.{sdtype}))",
4850
f"rename({name!r})",
4951
]
5052

0 commit comments

Comments
 (0)