
Commit 3146ffd

Author: Jeff Yang
refactor(tests): rewrite in pytest style (#52)
* refactor(tests): rewrite in pytest style
* refactor(tests): rm strict tests
* fix: tests
* fix: address code review
* fix: make iterables
* fix: setUp -> set_up
* fix: add internal tests
* fix: no relative import
* fix test
1 parent e9df949 commit 3146ffd
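The commit message describes moving the template tests from unittest-style classes to plain pytest functions (setUp -> set_up, bare asserts). A minimal before/after sketch of that style change, purely illustrative and not taken from this commit:

# Hypothetical illustration of the pytest-style rewrite; the real template
# tests are shown in the diff below and differ in detail.
import unittest

from torch import nn


# Before: unittest style, with state prepared in a setUp() method.
class TestLinearUnittest(unittest.TestCase):
    def setUp(self):
        self.model = nn.Linear(1, 1)

    def test_out_features(self):
        self.assertEqual(self.model.out_features, 1)


# After: pytest style, with a plain set_up() helper and bare asserts.
def set_up():
    return nn.Linear(1, 1)


def test_out_features():
    model = set_up()
    assert model.out_features == 1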

File tree

13 files changed: +889 / -872 lines changed


.github/run_test.sh

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 #!/bin/bash

-set -xeu
+set -xu

 if [ $1 == "generate" ]; then
     python ./tests/generate.py
@@ -14,7 +14,7 @@ elif [ $1 == "unittest" ]; then
     for dir in $(find ./tests/dist -type d -mindepth 1 -maxdepth 1)
     do
         cd $dir
-        pytest test_all.py -vra --color=yes --durations=0
+        pytest
         cd ../../../
     done
 elif [ $1 == "integration" ]; then

pyproject.toml

Lines changed: 5 additions & 0 deletions
@@ -26,3 +26,8 @@ exclude = '''
 profile = "black"
 multi_line_output = 3
 supported_extensions = "py"
+
+[tool.pytest.ini_options]
+minversion = "6.0"
+addopts = "-vra --color=yes --durations=0 --tb=short"
+python_files = "test_*.py _test_*.py"
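With these ini options in place, the bare pytest call in run_test.sh above inherits the verbosity, color, durations, and traceback settings, and the extra _test_*.py pattern lets modules such as _test_internal.py be collected. A minimal sketch of driving the suite programmatically under that configuration (hypothetical driver script, not part of this commit):

# Hypothetical driver: pytest.main() picks up [tool.pytest.ini_options] from
# pyproject.toml, so no command-line flags are needed here.
import sys

import pytest

if __name__ == "__main__":
    # Collects both test_*.py and _test_*.py per the python_files setting.
    sys.exit(pytest.main([]))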

templates/gan/_test_internal.py

Lines changed: 186 additions & 0 deletions
import logging
from argparse import ArgumentParser, Namespace
from pathlib import Path

import pytest
import torch
from config import get_default_parser
from ignite.contrib.handlers import (
    ClearMLLogger,
    MLflowLogger,
    NeptuneLogger,
    PolyaxonLogger,
    TensorboardLogger,
    VisdomLogger,
    WandBLogger,
)
from ignite.contrib.handlers.base_logger import BaseLogger
from ignite.engine import Engine
from ignite.handlers.checkpoint import Checkpoint
from ignite.handlers.early_stopping import EarlyStopping
from ignite.handlers.timing import Timer
from ignite.utils import setup_logger
from test_all import set_up
from torch import nn, optim
from trainers import create_trainers
from utils import hash_checkpoint, log_metrics, resume_from, setup_logging, get_handlers, get_logger


def test_get_handlers(tmp_path):
    train_engine = Engine(lambda e, b: b)
    config = Namespace(
        output_dir=tmp_path,
        save_every_iters=1,
        n_saved=2,
        log_every_iters=1,
        with_pbars=False,
        with_pbar_on_iters=False,
        stop_on_nan=False,
        clear_cuda_cache=False,
        with_gpu_stats=False,
        patience=1,
        limit_sec=30,
    )
    bm_handler, es_handler, timer_handler = get_handlers(
        config=config,
        model=nn.Linear(1, 1),
        train_engine=train_engine,
        eval_engine=train_engine,
        metric_name="eval_loss",
        es_metric_name="eval_loss",
    )
    assert isinstance(bm_handler, (type(None), Checkpoint)), "Should be Checkpoint or None"
    assert isinstance(es_handler, (type(None), EarlyStopping)), "Should be EarlyStopping or None"
    assert isinstance(timer_handler, (type(None), Timer)), "Should be Timer or None"


def test_get_logger(tmp_path):
    config = Namespace(output_dir=tmp_path, logger_log_every_iters=1)
    train_engine = Engine(lambda e, b: b)
    optimizer = optim.Adam(nn.Linear(1, 1).parameters())
    logger_handler = get_logger(
        config=config,
        train_engine=train_engine,
        eval_engine=train_engine,
        optimizers=optimizer,
    )
    types = (
        BaseLogger,
        ClearMLLogger,
        MLflowLogger,
        NeptuneLogger,
        PolyaxonLogger,
        TensorboardLogger,
        VisdomLogger,
        WandBLogger,
        type(None),
    )
    assert isinstance(logger_handler, types), "Should be Ignite provided loggers or None"


def test_create_trainers():
    model, optimizer, device, loss_fn, batch = set_up()
    real_labels = torch.ones(2, device=device)
    fake_labels = torch.zeros(2, device=device)
    train_engine = create_trainers(
        config=Namespace(use_amp=True),
        netD=model,
        netG=model,
        loss_fn=loss_fn,
        optimizerD=optimizer,
        optimizerG=optimizer,
        device=device,
        real_labels=real_labels,
        fake_labels=fake_labels,
    )
    assert isinstance(train_engine, Engine)


def test_get_default_parser():
    parser = get_default_parser()
    assert isinstance(parser, ArgumentParser)
    assert not parser.add_help


def test_log_metrics(capsys):
    engine = Engine(lambda e, b: None)
    engine.logger = setup_logger(format="%(message)s")
    engine.run(list(range(100)), max_epochs=2)
    log_metrics(engine, "train")
    captured = capsys.readouterr()
    assert captured.err.split("\n")[-2] == "train [2/200]: {}"


def test_setup_logging(tmp_path):
    config = Namespace(verbose=True, output_dir=tmp_path)
    logger = setup_logging(config)
    assert logger.level == logging.INFO
    assert isinstance(logger, logging.Logger)
    assert next(tmp_path.rglob("*.log")).is_file()


def test_hash_checkpoint(tmp_path):
    # download lightweight model
    model = torch.hub.load("pytorch/vision", "squeezenet1_0")
    # jit it
    scripted_model = torch.jit.script(model)
    # save jitted model : find a jitted checkpoint
    torch.jit.save(scripted_model, f"{tmp_path}/squeezenet1_0.ckptc")
    # download un-jitted model
    torch.hub.download_url_to_file(
        "https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth",
        f"{tmp_path}/squeezenet1_0.ckpt",
    )

    checkpoint = f"{tmp_path}/squeezenet1_0.ckpt"
    hashed_fp, sha_hash = hash_checkpoint(checkpoint, False, tmp_path)
    model.load_state_dict(torch.load(hashed_fp), True)
    assert sha_hash[:8] == "b66bff10"
    assert hashed_fp.name == f"squeezenet1_0-{sha_hash[:8]}.pt"

    checkpoint = f"{tmp_path}/squeezenet1_0.ckptc"
    hashed_fp, sha_hash = hash_checkpoint(checkpoint, True, tmp_path)
    scripted_model = torch.jit.load(hashed_fp)
    assert hashed_fp.name == f"squeezenet1_0-{sha_hash[:8]}.ptc"


def test_resume_from_url(tmp_path, caplog):
    logger = logging.getLogger()
    logging.basicConfig(level=logging.INFO)
    checkpoint_fp = "https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth"
    model = torch.hub.load("pytorch/vision", "squeezenet1_0")
    to_load = {"model": model}
    with caplog.at_level(logging.INFO):
        resume_from(to_load, checkpoint_fp, logger, model_dir=tmp_path)
        assert "Successfully resumed from a checkpoint" in caplog.messages[0], "checkpoint failed to load"


def test_resume_from_fp(tmp_path, caplog):
    logger = logging.getLogger()
    logging.basicConfig(level=logging.INFO)
    torch.hub.download_url_to_file(
        "https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth",
        f"{tmp_path}/squeezenet1_0.pt",
    )
    checkpoint_fp = f"{tmp_path}/squeezenet1_0.pt"
    model = torch.hub.load("pytorch/vision", "squeezenet1_0")
    to_load = {"model": model}
    with caplog.at_level(logging.INFO):
        resume_from(to_load, checkpoint_fp, logger)
        assert "Successfully resumed from a checkpoint" in caplog.messages[0], "checkpoint failed to load"

    torch.hub.download_url_to_file(
        "https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth",
        f"{tmp_path}/squeezenet1_0.pt",
    )
    checkpoint_fp = Path(f"{tmp_path}/squeezenet1_0.pt")
    model = torch.hub.load("pytorch/vision", "squeezenet1_0")
    to_load = {"model": model}
    with caplog.at_level(logging.INFO):
        resume_from(to_load, checkpoint_fp, logger)
        assert "Successfully resumed from a checkpoint" in caplog.messages[0], "checkpoint failed to load"


def test_resume_from_error():
    with pytest.raises(FileNotFoundError, match=r"Given \w+ does not exist"):
        resume_from({}, "abcdef/", None)
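test_create_trainers above imports a set_up() helper from test_all. A rough sketch of what such a helper might return (hypothetical; the actual helper lives in the generated test_all.py and may differ):

# Hypothetical stand-in for the set_up() helper used by test_create_trainers;
# names and shapes here are assumptions, not the template's real code.
import torch
from torch import nn, optim


def set_up():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = nn.Linear(1, 1).to(device)
    optimizer = optim.Adam(model.parameters())
    loss_fn = nn.BCEWithLogitsLoss().to(device)
    batch = (torch.rand(2, 1, device=device), torch.rand(2, 1, device=device))
    return model, optimizer, device, loss_fn, batch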

templates/gan/requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -3,5 +3,6 @@ pytorch-ignite>=0.4.4
 torchvision>=0.8.0
 matplotlib>=3.3.0
 pandas
+pytest
 {{ handler_deps }}
 {{ logger_deps }}
