Skip to content

Commit e349d63

Browse files
Add IMDb datasets
1 parent 7ee120d commit e349d63

File tree

1 file changed

+34
-2
lines changed

1 file changed

+34
-2
lines changed

src/TUDatasets.jl

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,8 @@ function __init__()
314314
FingerprintDataset(),
315315
COLLABDataset(),
316316
DBLP_v1Dataset(),
317+
IMDB_BINARYDataset(),
318+
IMDB_MULTIDataset(),
317319
REDDIT_BINARYDataset(),
318320
COLORS_3Dataset(),
319321
QM9Dataset(),
@@ -822,6 +824,36 @@ edge_labels_map(::DBLP_v1Dataset, i) = ("P2P", "P2W", "W2W")[i + 1]
822824
# TODO there is actually a node map with over 40000 entries defined in readme.txt
823825
node_labels_type(::DBLP_v1Dataset) = Tuple{UInt16}
824826

827+
## --------------------------------------
828+
## IMDB-BINARY
829+
## --------------------------------------
830+
831+
struct IMDB_BINARYDataset <: TUDataset end
832+
833+
dataset_name(::IMDB_BINARYDataset) = "IMDB-BINARY"
834+
835+
dataset_hash(::IMDB_BINARYDataset) = "b291ec8b26d85c70faa2ba0a2433e1f407ed2ef5d0fc072d36b9a95e49a1bb27"
836+
837+
dataset_references(::IMDB_BINARYDataset) = [14]
838+
839+
graph_eltype(::IMDB_BINARYDataset) = Int16
840+
841+
graph_labels_type(::IMDB_BINARYDataset) = @NamedTuple{class::Int8}
842+
843+
## --------------------------------------
844+
## IMDB-MULTI
845+
## --------------------------------------
846+
847+
struct IMDB_MULTIDataset <: TUDataset end
848+
849+
dataset_name(::IMDB_MULTIDataset) = "IMDB-MULTI"
850+
851+
dataset_hash(::IMDB_MULTIDataset) = "a4a302149ebf4c76fa1f0fb108baff89fcbf9d35de306b18f27a8419b9a1a690"
852+
853+
dataset_references(::IMDB_MULTIDataset) = [14]
854+
855+
graph_labels_type(::IMDB_MULTIDataset) = @NamedTuple{class::Int8}
856+
825857
## --------------------------------------
826858
## REDDIT-BINARY
827859
## --------------------------------------
@@ -961,10 +993,10 @@ function loadgraphs(ds::TUDataset; resolve_categories::Bool=false)
961993
@assert length(edgevals) == m
962994
@assert length(graphvals) == N
963995

964-
return _to_ValGraphCollection(graph_eltype(ds), edgelist, graph_indicator, vertexvals, edgevals, graphvals)
996+
return _to_ValGraphCollection(ds, graph_eltype(ds), edgelist, graph_indicator, vertexvals, edgevals, graphvals)
965997
end
966998

967-
function _to_ValGraphCollection(V, edgelist, graph_indicator, vertexvals, edgevals, graphvals)
999+
function _to_ValGraphCollection(ds::TUDataset, V, edgelist, graph_indicator, vertexvals, edgevals, graphvals)
9681000

9691001
n = length(graph_indicator) # number of vertices
9701002
m = length(edgelist) # number of edges

0 commit comments

Comments
 (0)