Skip to content

Commit 8dc4850

Browse files
committed
correctly test graph_builder.py
1 parent 0ba1dbf commit 8dc4850

File tree

2 files changed

+43
-53
lines changed

2 files changed

+43
-53
lines changed

tests/test_graph_builder.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import glob
2+
import os
3+
import pathlib
4+
import sys
5+
6+
import neat
7+
import pytest
8+
import torch
9+
10+
# allow imports from repo root
11+
sys.path.insert(0, str(pathlib.Path(__file__).resolve().parents[1]))
12+
13+
from compare_encoders import optimizer_to_data
14+
from genome import OptimizerGenome
15+
from graph_builder import rebuild_and_script
16+
from reproduction import GuidedReproduction
17+
18+
19+
def make_config():
20+
config_path = os.path.join(pathlib.Path(__file__).resolve().parents[1], "neat-config")
21+
return neat.Config(
22+
OptimizerGenome,
23+
GuidedReproduction,
24+
neat.DefaultSpeciesSet,
25+
neat.DefaultStagnation,
26+
config_path,
27+
)
28+
29+
30+
@pytest.mark.parametrize("pt_path", glob.glob(os.path.join("computation_graphs", "optimizers", "*.pt")))
31+
def test_graph_builder_rebuilds_pt(pt_path):
32+
original = torch.jit.load(pt_path)
33+
data = optimizer_to_data(original)
34+
graph_dict = {
35+
"node_types": data.node_types,
36+
"edge_index": data.edge_index,
37+
"node_attributes": data.node_attributes,
38+
}
39+
40+
config = make_config()
41+
rebuilt = rebuild_and_script(graph_dict, config.genome_config, key=0)
42+
43+
assert original == rebuilt

tests/test_graph_dict.py

Lines changed: 0 additions & 53 deletions
This file was deleted.

0 commit comments

Comments
 (0)