@@ -100,9 +100,53 @@ function loadreadme(ds::TUDataset)
100100 return Text (read (path, String))
101101end
102102
103- full_vertexval_types (ds:: TUDataset ) = cat_tuple_types (node_labels_type (ds), node_attributes_type (ds))
104- full_edgeval_types (ds:: TUDataset ) = cat_tuple_types (edge_labels_type (ds), edge_attributes_type (ds))
105- full_graphval_types (ds:: TUDataset ) = cat_tuple_types (graph_labels_type (ds), graph_attributes_type (ds))
103+ function full_vertexval_types (ds:: TUDataset ; resolve_categories:: Bool )
104+
105+ if resolve_categories
106+ return cat_tuple_types (node_labels_type (ds), node_attributes_type (ds))
107+ end
108+
109+ NLT = node_labels_type (ds)
110+ RNLT = if NLT <: NamedTuple
111+ NamedTuple{fieldnames (NLT), NTuple{length (fieldnames (NLT)), Int8}}
112+ else
113+ NTuple{length (fieldnames (NLT)), Int8}
114+ end
115+
116+ return cat_tuple_types (RNLT, node_attributes_type (ds))
117+ end
118+
119+ function full_edgeval_types (ds:: TUDataset ; resolve_categories:: Bool )
120+
121+ if resolve_categories
122+ return cat_tuple_types (edge_labels_type (ds), edge_attributes_type (ds))
123+ end
124+
125+ ELT = edge_labels_type (ds)
126+ RELT = if ELT <: NamedTuple
127+ NamedTuple{fieldnames (ELT), NTuple{length (fieldnames (ELT)), Int8}}
128+ else
129+ NTuple{length (fieldnames (ELT)), Int8}
130+ end
131+
132+ return cat_tuple_types (RELT, edge_attributes_type (ds))
133+ end
134+
135+ function full_graphval_types (ds:: TUDataset ; resolve_categories:: Bool )
136+
137+ if resolve_categories
138+ return cat_tuple_types (graph_labels_type (ds), graph_attributes_type (ds))
139+ end
140+
141+ GLT = graph_labels_type (ds)
142+ RGLT = if GLT <: NamedTuple
143+ NamedTuple{fieldnames (GLT), NTuple{length (fieldnames (GLT)), Int8}}
144+ else
145+ NTuple{length (fieldnames (GLT)), Int8}
146+ end
147+
148+ return cat_tuple_types (RGLT, graph_attributes_type (ds))
149+ end
106150
107151prefix (ds:: TUDataset ) = dataset_name (ds) * ' _'
108152
@@ -133,17 +177,23 @@ function load_node_attributes(ds::TUDataset)
133177 return CSV. File (node_attributes_path, header= false , strict= true , types= [node_attributes_type (ds). types... ])
134178end
135179
136- function load_full_vertexvals (ds:: TUDataset , n)
180+ function load_full_vertexvals (ds:: TUDataset , n; resolve_categories :: Bool )
137181
138182 node_labels = load_node_labels (ds:: TUDataset )
139183 node_attributes = load_node_attributes (ds:: TUDataset )
140184
141- V_VALS = full_vertexval_types (ds)
185+ V_VALS = full_vertexval_types (ds; resolve_categories = resolve_categories )
142186
143187 vertexvals = Vector {V_VALS} (undef, n)
144188
145189 for i ∈ 1 : n
146- label_i = node_labels == nothing ? () : tuple (node_labels_map (ds, node_labels[i][1 ]))
190+ label_i = if node_labels == nothing
191+ ()
192+ elseif resolve_categories
193+ tuple (node_labels_map (ds, node_labels[i][1 ]))
194+ else
195+ tuple (node_labels[i][1 ])
196+ end
147197 attr_i = node_attributes == nothing ? () : (node_attributes[i]. .. ,)
148198 vertexvals[i] = V_VALS ((label_i... , attr_i... ))
149199 end
@@ -165,17 +215,23 @@ function load_edge_attributes(ds::TUDataset)
165215 return CSV. File (path, header= false , strict= true , types= [edge_attributes_type (ds). types... ])
166216end
167217
168- function load_full_edgevals (ds:: TUDataset , m)
218+ function load_full_edgevals (ds:: TUDataset , m; resolve_categories :: Bool )
169219
170220 edge_labels = load_edge_labels (ds:: TUDataset )
171221 edge_attributes = load_edge_attributes (ds:: TUDataset )
172222
173- E_VALS = full_edgeval_types (ds)
223+ E_VALS = full_edgeval_types (ds; resolve_categories = resolve_categories )
174224
175225 edgevals = Vector {E_VALS} (undef, m)
176226
177227 for i ∈ 1 : m
178- label_i = edge_labels == nothing ? () : tuple (edge_labels_map (ds, edge_labels[i][1 ]))
228+ label_i = if edge_labels == nothing
229+ ()
230+ elseif resolve_categories
231+ tuple (edge_labels_map (ds, edge_labels[i][1 ]))
232+ else
233+ tuple (edge_labels[i][1 ])
234+ end
179235 attr_i = edge_attributes == nothing ? () : (edge_attributes[i]. .. ,)
180236 edgevals[i] = E_VALS ((label_i... , attr_i... ))
181237 end
@@ -204,17 +260,23 @@ function load_graph_attributes(ds::TUDataset)
204260 return CSV. File (path, header= false , strict= true , delim= ' ,' , types= [graph_attributes_type (ds). types... ])
205261end
206262
207- function load_full_graphvals (ds:: TUDataset , N)
263+ function load_full_graphvals (ds:: TUDataset , N; resolve_categories :: Bool )
208264
209265 graph_labels = load_graph_labels (ds:: TUDataset )
210266 graph_attributes = load_graph_attributes (ds:: TUDataset )
211267
212- G_VALS = full_graphval_types (ds)
268+ G_VALS = full_graphval_types (ds; resolve_categories = resolve_categories )
213269
214270 graphvals = Vector {G_VALS} (undef, N)
215271
216272 for i ∈ 1 : N
217- label_i = graph_labels == nothing ? () : tuple (graph_labels_map (ds, graph_labels[i][1 ]))
273+ label_i = if graph_labels == nothing
274+ ()
275+ elseif resolve_categories
276+ tuple (graph_labels_map (ds, graph_labels[i][1 ]))
277+ else
278+ tuple (graph_labels[i][1 ])
279+ end
218280 attr_i = graph_attributes == nothing ? () : (graph_attributes[i]. .. ,)
219281 graphvals[i] = G_VALS ((label_i... , attr_i... ))
220282 end
@@ -882,7 +944,7 @@ node_attributes_type(::TRIANGLESDataset) = Tuple{Int8}
882944# TODO this is quite ugly
883945# TODO fix type instabilities
884946# TODO check data for inconsistencies
885- function loadgraphs (ds:: TUDataset )
947+ function loadgraphs (ds:: TUDataset ; resolve_categories :: Bool = false )
886948
887949 edgelist = load_edgelist (ds)
888950 graph_indicator = load_graphindicator (ds)
@@ -891,18 +953,27 @@ function loadgraphs(ds::TUDataset)
891953 m = length (edgelist) # number of edges
892954 N = maximum (row -> row[1 ], graph_indicator) # number of graphs
893955
894- vertex_ids = zeros (Int, n + 1 )
895- edge_ids = zeros (graph_eltype (ds), m)
896- graph_ids = zeros (Int, N + 1 )
897-
898- vertexvals = load_full_vertexvals (ds, n)
899- edgevals = load_full_edgevals (ds, m)
900- graphvals = load_full_graphvals (ds, N)
956+ vertexvals = load_full_vertexvals (ds, n; resolve_categories= resolve_categories)
957+ edgevals = load_full_edgevals (ds, m; resolve_categories= resolve_categories)
958+ graphvals = load_full_graphvals (ds, N; resolve_categories= resolve_categories)
901959
902960 @assert length (vertexvals) == n
903961 @assert length (edgevals) == m
904962 @assert length (graphvals) == N
905963
964+ return _to_ValGraphCollection (graph_eltype (ds), edgelist, graph_indicator, vertexvals, edgevals, graphvals)
965+ end
966+
967+ function _to_ValGraphCollection (V, edgelist, graph_indicator, vertexvals, edgevals, graphvals)
968+
969+ n = length (graph_indicator) # number of vertices
970+ m = length (edgelist) # number of edges
971+ N = maximum (row -> row[1 ], graph_indicator) # number of graphs
972+
973+ vertex_ids = zeros (Int, n + 1 )
974+ edge_ids = zeros (V, m)
975+ graph_ids = zeros (Int, N + 1 )
976+
906977 k = 0
907978 for (v_id, row) in enumerate (graph_indicator)
908979
0 commit comments