Skip to content

Commit 7ee120d

Browse files
Add resolve-categories keyword argument to loadgraphs (#6)
1 parent b4cd684 commit 7ee120d

File tree

3 files changed

+99
-23
lines changed

3 files changed

+99
-23
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ TUDatasets.BZRDataset
2323

2424
# Load QM9 from TUDatasets. This dataset contains 129433 molecules represented as graphs.
2525
# The resulting ValGraphCollection is an immutable collection of graphs.
26-
julia> qm9 = loadgraphs(TUDatasets.QM9Dataset())
26+
julia> qm9 = loadgraphs(TUDatasets.QM9Dataset(); resolve_categories=true)
2727
129433-element ValGraphCollection of graphs with
2828
eltype: Int8
2929
vertex value types: (Bool, Bool, Bool, Bool, Bool, Int8, Bool, Bool, Bool, Bool, Bool, Bool, Int64, Float64, Float64, Float64)

src/TUDatasets.jl

Lines changed: 91 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,53 @@ function loadreadme(ds::TUDataset)
100100
return Text(read(path, String))
101101
end
102102

103-
full_vertexval_types(ds::TUDataset) = cat_tuple_types(node_labels_type(ds), node_attributes_type(ds))
104-
full_edgeval_types(ds::TUDataset) = cat_tuple_types(edge_labels_type(ds), edge_attributes_type(ds))
105-
full_graphval_types(ds::TUDataset) = cat_tuple_types(graph_labels_type(ds), graph_attributes_type(ds))
103+
"""
    full_vertexval_types(ds::TUDataset; resolve_categories::Bool)

Return the vertex value type of `ds`, obtained by passing a node label type and
`node_attributes_type(ds)` to `cat_tuple_types`.

# Keywords
- `resolve_categories`: if `true`, the label part is `node_labels_type(ds)` unchanged.
  If `false`, every label field is typed `Int8` instead; field names are kept when the
  label type is a `NamedTuple`.
"""
function full_vertexval_types(ds::TUDataset; resolve_categories::Bool)
    label_type = node_labels_type(ds)

    if !resolve_categories
        # Same number of fields (and same names, for named tuples), each typed Int8.
        names = fieldnames(label_type)
        raw_fields = NTuple{length(names), Int8}
        label_type = label_type <: NamedTuple ? NamedTuple{names, raw_fields} : raw_fields
    end

    return cat_tuple_types(label_type, node_attributes_type(ds))
end
118+
119+
"""
    full_edgeval_types(ds::TUDataset; resolve_categories::Bool)

Return the edge value type of `ds`, obtained by passing an edge label type and
`edge_attributes_type(ds)` to `cat_tuple_types`.

# Keywords
- `resolve_categories`: if `true`, the label part is `edge_labels_type(ds)` unchanged.
  If `false`, every label field is typed `Int8` instead; field names are kept when the
  label type is a `NamedTuple`.
"""
function full_edgeval_types(ds::TUDataset; resolve_categories::Bool)
    label_type = edge_labels_type(ds)

    if !resolve_categories
        # Same number of fields (and same names, for named tuples), each typed Int8.
        names = fieldnames(label_type)
        raw_fields = NTuple{length(names), Int8}
        label_type = label_type <: NamedTuple ? NamedTuple{names, raw_fields} : raw_fields
    end

    return cat_tuple_types(label_type, edge_attributes_type(ds))
end
134+
135+
"""
    full_graphval_types(ds::TUDataset; resolve_categories::Bool)

Return the graph value type of `ds`, obtained by passing a graph label type and
`graph_attributes_type(ds)` to `cat_tuple_types`.

# Keywords
- `resolve_categories`: if `true`, the label part is `graph_labels_type(ds)` unchanged.
  If `false`, every label field is typed `Int8` instead; field names are kept when the
  label type is a `NamedTuple`.
"""
function full_graphval_types(ds::TUDataset; resolve_categories::Bool)
    label_type = graph_labels_type(ds)

    if !resolve_categories
        # Same number of fields (and same names, for named tuples), each typed Int8.
        names = fieldnames(label_type)
        raw_fields = NTuple{length(names), Int8}
        label_type = label_type <: NamedTuple ? NamedTuple{names, raw_fields} : raw_fields
    end

    return cat_tuple_types(label_type, graph_attributes_type(ds))
end
106150

107151
prefix(ds::TUDataset) = dataset_name(ds) * '_'
108152

@@ -133,17 +177,23 @@ function load_node_attributes(ds::TUDataset)
133177
return CSV.File(node_attributes_path, header=false, strict=true, types=[node_attributes_type(ds).types...])
134178
end
135179

136-
function load_full_vertexvals(ds::TUDataset, n)
180+
function load_full_vertexvals(ds::TUDataset, n; resolve_categories::Bool)
137181

138182
node_labels = load_node_labels(ds::TUDataset)
139183
node_attributes = load_node_attributes(ds::TUDataset)
140184

141-
V_VALS = full_vertexval_types(ds)
185+
V_VALS = full_vertexval_types(ds; resolve_categories=resolve_categories)
142186

143187
vertexvals = Vector{V_VALS}(undef, n)
144188

145189
for i ∈ 1:n
146-
label_i = node_labels == nothing ? () : tuple(node_labels_map(ds, node_labels[i][1]))
190+
label_i = if node_labels == nothing
191+
()
192+
elseif resolve_categories
193+
tuple(node_labels_map(ds, node_labels[i][1]))
194+
else
195+
tuple(node_labels[i][1])
196+
end
147197
attr_i = node_attributes == nothing ? () : (node_attributes[i]...,)
148198
vertexvals[i] = V_VALS((label_i..., attr_i...))
149199
end
@@ -165,17 +215,23 @@ function load_edge_attributes(ds::TUDataset)
165215
return CSV.File(path, header=false, strict=true, types=[edge_attributes_type(ds).types...])
166216
end
167217

168-
function load_full_edgevals(ds::TUDataset, m)
218+
function load_full_edgevals(ds::TUDataset, m; resolve_categories::Bool)
169219

170220
edge_labels = load_edge_labels(ds::TUDataset)
171221
edge_attributes = load_edge_attributes(ds::TUDataset)
172222

173-
E_VALS = full_edgeval_types(ds)
223+
E_VALS = full_edgeval_types(ds; resolve_categories=resolve_categories)
174224

175225
edgevals = Vector{E_VALS}(undef, m)
176226

177227
for i ∈ 1:m
178-
label_i = edge_labels == nothing ? () : tuple(edge_labels_map(ds, edge_labels[i][1]))
228+
label_i = if edge_labels == nothing
229+
()
230+
elseif resolve_categories
231+
tuple(edge_labels_map(ds, edge_labels[i][1]))
232+
else
233+
tuple(edge_labels[i][1])
234+
end
179235
attr_i = edge_attributes == nothing ? () : (edge_attributes[i]...,)
180236
edgevals[i] = E_VALS((label_i..., attr_i...))
181237
end
@@ -204,17 +260,23 @@ function load_graph_attributes(ds::TUDataset)
204260
return CSV.File(path, header=false, strict=true, delim=',', types=[graph_attributes_type(ds).types...])
205261
end
206262

207-
function load_full_graphvals(ds::TUDataset, N)
263+
function load_full_graphvals(ds::TUDataset, N; resolve_categories::Bool)
208264

209265
graph_labels = load_graph_labels(ds::TUDataset)
210266
graph_attributes = load_graph_attributes(ds::TUDataset)
211267

212-
G_VALS = full_graphval_types(ds)
268+
G_VALS = full_graphval_types(ds; resolve_categories=resolve_categories)
213269

214270
graphvals = Vector{G_VALS}(undef, N)
215271

216272
for i ∈ 1:N
217-
label_i = graph_labels == nothing ? () : tuple(graph_labels_map(ds, graph_labels[i][1]))
273+
label_i = if graph_labels == nothing
274+
()
275+
elseif resolve_categories
276+
tuple(graph_labels_map(ds, graph_labels[i][1]))
277+
else
278+
tuple(graph_labels[i][1])
279+
end
218280
attr_i = graph_attributes == nothing ? () : (graph_attributes[i]...,)
219281
graphvals[i] = G_VALS((label_i..., attr_i...))
220282
end
@@ -882,7 +944,7 @@ node_attributes_type(::TRIANGLESDataset) = Tuple{Int8}
882944
# TODO this is quite ugly
883945
# TODO fix type instabilities
884946
# TODO check data for inconsistencies
885-
function loadgraphs(ds::TUDataset)
947+
function loadgraphs(ds::TUDataset; resolve_categories::Bool=false)
886948

887949
edgelist = load_edgelist(ds)
888950
graph_indicator = load_graphindicator(ds)
@@ -891,18 +953,27 @@ function loadgraphs(ds::TUDataset)
891953
m = length(edgelist) # number of edges
892954
N = maximum(row -> row[1], graph_indicator) # number of graphs
893955

894-
vertex_ids = zeros(Int, n + 1)
895-
edge_ids = zeros(graph_eltype(ds), m)
896-
graph_ids = zeros(Int, N + 1)
897-
898-
vertexvals = load_full_vertexvals(ds, n)
899-
edgevals = load_full_edgevals(ds, m)
900-
graphvals = load_full_graphvals(ds, N)
956+
vertexvals = load_full_vertexvals(ds, n; resolve_categories=resolve_categories)
957+
edgevals = load_full_edgevals(ds, m; resolve_categories=resolve_categories)
958+
graphvals = load_full_graphvals(ds, N; resolve_categories=resolve_categories)
901959

902960
@assert length(vertexvals) == n
903961
@assert length(edgevals) == m
904962
@assert length(graphvals) == N
905963

964+
return _to_ValGraphCollection(graph_eltype(ds), edgelist, graph_indicator, vertexvals, edgevals, graphvals)
965+
end
966+
967+
function _to_ValGraphCollection(V, edgelist, graph_indicator, vertexvals, edgevals, graphvals)
968+
969+
n = length(graph_indicator) # number of vertices
970+
m = length(edgelist) # number of edges
971+
N = maximum(row -> row[1], graph_indicator) # number of graphs
972+
973+
vertex_ids = zeros(Int, n + 1)
974+
edge_ids = zeros(V, m)
975+
graph_ids = zeros(Int, N + 1)
976+
906977
k = 0
907978
for (v_id, row) in enumerate(graph_indicator)
908979

src/graphdataset.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,16 +36,21 @@ function dataset_message(ds::GraphDataset)
3636
end
3737

3838
"""
39-
loadgraphs(ds::GraphDataset)
39+
loadgraphs(ds::GraphDataset; resolve_categories::Bool=false)
4040
4141
Loads multiple graphs from a dataset `ds`.
4242
4343
Either this method or `loadgraph` must be implemented for new subtypes of `GraphDataset`.
4444
45+
# Keywords
46+
- `resolve_categories`: Some graph metadata might be of categorical form (i.e., strings).
47+
If this argument is true, try to resolve that category instead of keeping a numerical
48+
value.
49+
4550
### See also
4651
[`loadgraph`](@ref), [`GraphDataset`](@ref)
4752
"""
48-
loadgraphs(ds::GraphDataset)
53+
loadgraphs(ds::GraphDataset; resolve_categories::Bool)
4954

5055
## ----------------------------------------
5156
## optional methods to implement

0 commit comments

Comments
 (0)