@@ -584,37 +584,31 @@ The second part modifies it.
584584
585585 onnx.save(gs.export_onnx(graph), " modified.onnx" )
586586
587- numpy API for onnx
588- ++++++++++++++++++
587+ Graph Builder API
588+ +++++++++++++++++
589589
590- See :ref: `l-numpy-api-onnx `. This API was introduced to create graphs
591- by using numpy API. If a function is defined only with numpy,
592- it should be possible to use the exact same code to create the
593- corresponding onnx graph. That's what this API tries to achieve.
594- It works with the exception of control flow. In that case, the function
595- produces different onnx graphs depending on the execution path.
590+ See :ref: `l-graph-api `. This API is very similar to what *skl2onnx * implements.
591+ It is still about adding nodes to a graph but some tasks are automated such as
592+ naming the results or converting constants to onnx classes.
596593
597594.. runpython ::
598595 :showcode:
599596
600597 import numpy as np
601- from onnx_array_api.npx import jit_onnx
598+ from onnx_array_api.graph_api import GraphBuilder
602599 from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
603600
604- def l2_loss(x, y):
605- return ((x - y) ** 2).sum(keepdims=1 )
606-
607- jitted_myloss = jit_onnx(l2_loss )
608- dummy = np.array([0 ], dtype=np.float32 )
609-
610- # The function is executed. Only then a onnx graph is created.
611- # One is created depending on the input type.
612- jitted_myloss(dummy, dummy )
601+ g = GraphBuilder()
602+ g.make_tensor_input("X", np.float32, (None, None) )
603+ g.make_tensor_input("Y", np.float32, (None, None))
604+ r1 = g.op.Sub("X", "Y" )
605+ r2 = g.op.Pow(r1, np.array([2 ], dtype=np.int64) )
606+ g.op.ReduceSum(r2, outputs=["Z"])
607+ g.make_tensor_output("Z", np.float32, (None, None))
608+
609+ onx = g.to_onnx( )
613610
614- # get_onnx only works if it was executed once or at least with
615- # the same input type
616- model = jitted_myloss.get_onnx()
617- print(onnx_simple_text_plot(model))
611+ print(onnx_simple_text_plot(onx))
618612
619613Light API
620614+++++++++
@@ -647,3 +641,35 @@ There is no eager mode.
647641 )
648642
649643 print(onnx_simple_text_plot(model))
644+
645+ numpy API for onnx
646+ ++++++++++++++++++
647+
648+ See :ref: `l-numpy-api-onnx `. This API was introduced to create graphs
649+ by using numpy API. If a function is defined only with numpy,
650+ it should be possible to use the exact same code to create the
651+ corresponding onnx graph. That's what this API tries to achieve.
652+ It works with the exception of control flow. In that case, the function
653+ produces different onnx graphs depending on the execution path.
654+
655+ .. runpython ::
656+ :showcode:
657+
658+ import numpy as np
659+ from onnx_array_api.npx import jit_onnx
660+ from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
661+
662+ def l2_loss(x, y):
663+ return ((x - y) ** 2).sum(keepdims=1)
664+
665+ jitted_myloss = jit_onnx(l2_loss)
666+ dummy = np.array([0], dtype=np.float32)
667+
668+ # The function is executed. Only then a onnx graph is created.
669+ # One is created depending on the input type.
670+ jitted_myloss(dummy, dummy)
671+
672+ # get_onnx only works if it was executed once or at least with
673+ # the same input type
674+ model = jitted_myloss.get_onnx()
675+ print(onnx_simple_text_plot(model))
0 commit comments