@@ -2,107 +2,107 @@ using .tensorboard: NodeDef, AttrValue, NameAttrList
22using . tensorboard: var"AttrValue.ListValue" as AttrValue_ListValue
33
44"""
5- log_graph
5+ log_graph
66"""
77function log_graph (logger:: TBLogger , g:: AbstractGraph ; step = nothing , nodelabel:: Vector{String} = map (string, vertices (g)), nodeop:: Vector{String} = map (string, vertices (g)), nodedevice:: Vector{String} = fill (" cpu" , nv (g)), nodevalue:: Vector{Any} = fill (nothing , nv (g)))
8- nv (g) == length (nodelabel) || throw (ArgumentError (" length of nodelable must be same as number of vertices" ))
9- nv (g) == length (nodeop) || throw (ArgumentError (" length of nodeop must be same as number of vertices" ))
10- nv (g) == length (nodedevice) || throw (ArgumentError (" length of nodedevice must be same as number of vertices" ))
11- nv (g) == length (nodevalue) || throw (ArgumentError (" length of nodevalue must be same as number of vertices" ))
12- summ = SummaryCollection (graph_summary (g, nodelabel, nodeop, nodedevice, nodevalue))
8+ nv (g) == length (nodelabel) || throw (ArgumentError (" length of nodelable must be same as number of vertices" ))
9+ nv (g) == length (nodeop) || throw (ArgumentError (" length of nodeop must be same as number of vertices" ))
10+ nv (g) == length (nodedevice) || throw (ArgumentError (" length of nodedevice must be same as number of vertices" ))
11+ nv (g) == length (nodevalue) || throw (ArgumentError (" length of nodevalue must be same as number of vertices" ))
12+ summ = SummaryCollection (graph_summary (g, nodelabel, nodeop, nodedevice, nodevalue))
1313 write_event (logger. file, make_event (logger, summ, step= step))
1414end
1515
1616function graph_summary (g, nodelabel, nodeop, nodedevice, nodevalue)
17- nodes = Vector {NodeDef} ()
18- for v in vertices (g)
19- name = nodelabel[v]
20- op = nodeop[v]
21- input = [nodelabel[x] for x in inneighbors (g, v)]
22- device = nodedevice[v]
23- attr = Dict {String, AttrValue} ()
24- x = nodevalue[v]
25- if isa (x, AbstractString)
26- attr[" value" ] = AttrValue (OneOf (:s , Vector {UInt8} (x)))
27- attr[" dtype" ] = AttrValue (OneOf (:_type ,jltype2tf (typeof (x))))
28- elseif isa (x, Integer)
29- attr[" value" ] = AttrValue (OneOf (:i ,Int64 (x)))
30- attr[" dtype" ] = AttrValue (OneOf (:_type , jltype2tf (typeof (x))))
31- elseif isa (x, Real)
32- attr[" value" ] = AttrValue (OneOf (:f , Float32 (x)))
33- attr[" dtype" ] = AttrValue (OneOf (:_type , jltype2tf (typeof (x))))
34- elseif isa (x, Bool)
35- attr[" value" ] = AttrValue (OneOf (:b , x))
36- attr[" dtype" ] = AttrValue (OneOf (:_type , jltype2tf (typeof (x))))
37- elseif isa (x, AbstractArray)
38- shape = TensorShapeProto ([TensorShapeProto_Dim (d, " " ) for d in (collect (size (x)))], false )
39- t = TensorProto (dtype = jltype2tf (eltype (x)), tensor_shape = shape, tensor_content = serialize_proto (string (x)))
40- attr[" value" ] = AttrValue (OneOf (:tensor , t))
41- listvalue = AttrValue_ListValue (Vector {Vector{UInt8}} (),
42- Vector {Int64} (),
43- Vector {Float32} (),
44- Vector {Bool} (),
45- Vector {var"#DataType".T} (),
46- [shape],
47- Vector {TensorProto} (),
48- Vector {NameAttrList} ())
49- attr[" _output_shapes" ] = AttrValue (OneOf (:list , listvalue))
50- elseif isa (x, Tuple)
51- listvalue = AttrValue_ListValue ([Vector {UInt8} (repr (y)) for y in x],
52- Vector {Int64} (),
53- Vector {Float32} (),
54- Vector {Bool} (),
55- Vector {var"#DataType".T} (),
56- Vector {TensorShapeProto} (),
57- Vector {TensorProto} (),
58- Vector {NameAttrList} ())
59- attr[" value" ] = AttrValue (OneOf (:list , listvalue))
60- shape = TensorShapeProto ([TensorShapeProto_Dim (length (x), " " )], false )
61- listvalue = AttrValue_ListValue (Vector {Vector{UInt8}} (),
62- Vector {Int64} (),
63- Vector {Float32} (),
64- Vector {Bool} (),
65- Vector {var"#DataType".T} (),
66- [shape],
67- Vector {TensorProto} (),
68- Vector {NameAttrList} ())
69- attr[" _output_shapes" ] = AttrValue (OneOf (:list , listvalue))
70- elseif isa (x, Function)
71- attr[" value" ] = AttrValue (OneOf (:func , NameAttrList (name = repr (x))))
72- else
73- # donothing
74- end
75- node = NodeDef (name, op, input, device, attr, nothing , nothing )
76- push! (nodes, node)
77- end
78- GraphDef (nodes, nothing , 0 , nothing , nothing )
17+ nodes = Vector {NodeDef} ()
18+ for v in vertices (g)
19+ name = nodelabel[v]
20+ op = nodeop[v]
21+ input = [nodelabel[x] for x in inneighbors (g, v)]
22+ device = nodedevice[v]
23+ attr = Dict {String, AttrValue} ()
24+ x = nodevalue[v]
25+ if isa (x, AbstractString)
26+ attr[" value" ] = AttrValue (OneOf (:s , Vector {UInt8} (x)))
27+ attr[" dtype" ] = AttrValue (OneOf (:_type ,jltype2tf (typeof (x))))
28+ elseif isa (x, Integer)
29+ attr[" value" ] = AttrValue (OneOf (:i ,Int64 (x)))
30+ attr[" dtype" ] = AttrValue (OneOf (:_type , jltype2tf (typeof (x))))
31+ elseif isa (x, Real)
32+ attr[" value" ] = AttrValue (OneOf (:f , Float32 (x)))
33+ attr[" dtype" ] = AttrValue (OneOf (:_type , jltype2tf (typeof (x))))
34+ elseif isa (x, Bool)
35+ attr[" value" ] = AttrValue (OneOf (:b , x))
36+ attr[" dtype" ] = AttrValue (OneOf (:_type , jltype2tf (typeof (x))))
37+ elseif isa (x, AbstractArray)
38+ shape = TensorShapeProto ([TensorShapeProto_Dim (d, " " ) for d in (collect (size (x)))], false )
39+ t = TensorProto (dtype = jltype2tf (eltype (x)), tensor_shape = shape, tensor_content = serialize_proto (string (x)))
40+ attr[" value" ] = AttrValue (OneOf (:tensor , t))
41+ listvalue = AttrValue_ListValue (Vector {Vector{UInt8}} (),
42+ Vector {Int64} (),
43+ Vector {Float32} (),
44+ Vector {Bool} (),
45+ Vector {var"#DataType".T} (),
46+ [shape],
47+ Vector {TensorProto} (),
48+ Vector {NameAttrList} ())
49+ attr[" _output_shapes" ] = AttrValue (OneOf (:list , listvalue))
50+ elseif isa (x, Tuple)
51+ listvalue = AttrValue_ListValue ([Vector {UInt8} (repr (y)) for y in x],
52+ Vector {Int64} (),
53+ Vector {Float32} (),
54+ Vector {Bool} (),
55+ Vector {var"#DataType".T} (),
56+ Vector {TensorShapeProto} (),
57+ Vector {TensorProto} (),
58+ Vector {NameAttrList} ())
59+ attr[" value" ] = AttrValue (OneOf (:list , listvalue))
60+ shape = TensorShapeProto ([TensorShapeProto_Dim (length (x), " " )], false )
61+ listvalue = AttrValue_ListValue (Vector {Vector{UInt8}} (),
62+ Vector {Int64} (),
63+ Vector {Float32} (),
64+ Vector {Bool} (),
65+ Vector {var"#DataType".T} (),
66+ [shape],
67+ Vector {TensorProto} (),
68+ Vector {NameAttrList} ())
69+ attr[" _output_shapes" ] = AttrValue (OneOf (:list , listvalue))
70+ elseif isa (x, Function)
71+ attr[" value" ] = AttrValue (OneOf (:func , NameAttrList (name = repr (x))))
72+ else
73+ # donothing
74+ end
75+ node = NodeDef (name, op, input, device, attr, nothing , nothing )
76+ push! (nodes, node)
77+ end
78+ GraphDef (nodes, nothing , 0 , nothing , nothing )
7979end
8080
8181function jltype2tf (dtype:: DataType )
82- nodetype =
83- dtype == UInt8 ? _DataType. DT_UINT8 :
84- dtype == UInt16 ? _DataType. DT_UINT16 :
85- dtype == UInt32 ? _DataType. DT_UINT32 :
86- dtype == UInt64 ? _DataType. DT_UINT64 :
87- dtype == Int8 ? _DataType. DT_INT8 :
88- dtype == Int16 ? _DataType. DT_INT16 :
89- dtype == Int32 ? _DataType. DT_INT32 :
90- dtype == Int64 ? _DataType. DT_INT64 :
91- dtype == Float16 ? _DataType. DT_BFLOAT16 :
92- dtype == Float32 ? _DataType. DT_FLOAT :
93- dtype == Float64 ? _DataType. DT_DOUBLE :
94- dtype <: AbstractString ? _DataType. DT_STRING :
95- dtype == Bool ? _DataType. DT_BOOL :
96- dtype ∈ [Complex{Float32},
97- Complex{Float16},
98- Complex{UInt8},
99- Complex{UInt16},
100- Complex{UInt32},
101- Complex{Int8},
102- Complex{Int16},
103- Complex{Int32}] ? _DataType. DT_COMPLEX64 :
104- dtype ∈ [Complex{Float64},
105- Complex{UInt64},
106- Complex{Int64}] ? _DataType. DT_COMPLEX128 :
107- _DataType. DT_INVALID
82+ nodetype =
83+ dtype == UInt8 ? _DataType. DT_UINT8 :
84+ dtype == UInt16 ? _DataType. DT_UINT16 :
85+ dtype == UInt32 ? _DataType. DT_UINT32 :
86+ dtype == UInt64 ? _DataType. DT_UINT64 :
87+ dtype == Int8 ? _DataType. DT_INT8 :
88+ dtype == Int16 ? _DataType. DT_INT16 :
89+ dtype == Int32 ? _DataType. DT_INT32 :
90+ dtype == Int64 ? _DataType. DT_INT64 :
91+ dtype == Float16 ? _DataType. DT_BFLOAT16 :
92+ dtype == Float32 ? _DataType. DT_FLOAT :
93+ dtype == Float64 ? _DataType. DT_DOUBLE :
94+ dtype <: AbstractString ? _DataType. DT_STRING :
95+ dtype == Bool ? _DataType. DT_BOOL :
96+ dtype ∈ [Complex{Float32},
97+ Complex{Float16},
98+ Complex{UInt8},
99+ Complex{UInt16},
100+ Complex{UInt32},
101+ Complex{Int8},
102+ Complex{Int16},
103+ Complex{Int32}] ? _DataType. DT_COMPLEX64 :
104+ dtype ∈ [Complex{Float64},
105+ Complex{UInt64},
106+ Complex{Int64}] ? _DataType. DT_COMPLEX128 :
107+ _DataType. DT_INVALID
108108end
0 commit comments