Skip to content

Commit 5bb585f

Browse files
[loops] Add snippet of re-usable loop code.
1 parent 41a0703 commit 5bb585f

File tree

2 files changed

+89
-0
lines changed

2 files changed

+89
-0
lines changed

inclearn/lib/loops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .loops import *

inclearn/lib/loops/loops.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import collections
2+
import logging
3+
4+
from torch import nn
5+
from tqdm import tqdm
6+
7+
logger = logging.getLogger(__name__)
8+
9+
10+
def single_loop(
    train_loader,
    val_loader,
    devices,
    network,
    n_epochs,
    optimizer,
    train_function,
    eval_function,
    task,
    n_tasks,
    scheduler=None,
    disable_progressbar=False,
    eval_every_x_epochs=None,
    early_stopping=None
):
    """Run a generic training loop with optional validation and early stopping.

    For every epoch, iterates over `train_loader`, computes the loss through
    `train_function`, back-propagates it and steps the optimizer, while
    refreshing a tqdm progress bar with the running metrics.

    Args:
        train_loader: Iterable of batches. Each batch must be a mapping with
            the keys "inputs", "targets" and "memory_flags".
        val_loader: Loader handed as-is to `eval_function`.
        devices: Sized collection of devices. When it holds more than one
            element, the network is wrapped in `nn.DataParallel`.
        network: The model to train.
        n_epochs: Maximum number of epochs.
        optimizer: Object exposing `zero_grad()` and `step()`.
        train_function: Callable invoked as
            `train_function(network, inputs, targets, memory_flags, metrics)`
            and returning a loss that supports `.backward()`. `metrics` is a
            per-epoch `defaultdict(float)`; presumably the callable
            accumulates its per-batch values in it (they are divided by the
            batch count when displayed).
        eval_function: Callable invoked as `eval_function(network, val_loader)`
            and returning an accuracy comparable with `>`.
        task: Zero-based index of the current task (display only).
        n_tasks: Total number of tasks (display only).
        scheduler: Optional learning-rate scheduler, stepped once per epoch.
        disable_progressbar: If true, the tqdm progress bar is disabled.
        eval_every_x_epochs: If set, evaluate every that many epochs
            (epoch 0 is always skipped).
        early_stopping: Optional mapping with a "patience" key: the number of
            consecutive evaluations without improvement after which training
            stops. Note that it is counted in evaluations, not in epochs.

    Returns:
        None.
    """
    best_epoch, best_acc = -1, -1.
    wait = 0  # Consecutive evaluations without any accuracy improvement.

    if len(devices) > 1:
        logger.info("Duplicating model on {} gpus.".format(len(devices)))
        training_network = nn.DataParallel(network, devices)
    else:
        training_network = network

    for epoch in range(n_epochs):
        # Fresh accumulator each epoch so displayed values are per-epoch.
        metrics = collections.defaultdict(float)

        prog_bar = tqdm(
            train_loader,
            disable=disable_progressbar,
            ascii=True,
            bar_format="{desc}: {percentage:3.0f}% | {n_fmt}/{total_fmt} | {rate_fmt}{postfix}"
        )
        # Start at 1 so that `batch_index` is the number of batches seen so
        # far, which `_print_metrics` uses as the averaging denominator.
        for batch_index, input_dict in enumerate(prog_bar, start=1):
            inputs, targets = input_dict["inputs"], input_dict["targets"]
            memory_flags = input_dict["memory_flags"]

            optimizer.zero_grad()
            loss = train_function(training_network, inputs, targets, memory_flags, metrics)
            loss.backward()
            optimizer.step()

            _print_metrics(metrics, prog_bar, epoch, n_epochs, batch_index, task, n_tasks)

        if scheduler:
            # NOTE(review): recent PyTorch releases deprecate passing the
            # epoch to `step()`; kept as-is to preserve the LR schedule.
            scheduler.step(epoch)

        if eval_every_x_epochs and epoch != 0 and epoch % eval_every_x_epochs == 0:
            training_network.eval()
            accuracy = eval_function(training_network, val_loader)
            training_network.train()

            logger.info("Val accuracy: {}".format(accuracy))

            if accuracy > best_acc:
                best_epoch = epoch
                best_acc = accuracy
                wait = 0
            else:
                wait += 1

                # Stop once the patience budget is exhausted. The previous
                # test (`patience > wait`) was inverted and stopped at the
                # very first evaluation. Checking only on a non-improving
                # evaluation also keeps `patience=0` from stopping on an
                # improvement.
                if early_stopping and wait >= early_stopping["patience"]:
                    logger.warning("Early stopping!")
                    break

    if eval_every_x_epochs:
        logger.info("Best accuracy reached at epoch {} with {}%.".format(best_epoch, best_acc))
78+
79+
80+
def _print_metrics(metrics, prog_bar, epoch, nb_epochs, nb_batches, task, n_tasks):
81+
pretty_metrics = ", ".join(
82+
"{}: {}".format(metric_name, round(metric_value / nb_batches, 3))
83+
for metric_name, metric_value in metrics.items()
84+
)
85+
86+
prog_bar.set_description(
87+
"T{}/{}, E{}/{} => {}".format(task + 1, n_tasks, epoch + 1, nb_epochs, pretty_metrics)
88+
)

0 commit comments

Comments
 (0)