
Commit 8452868

[fixedrepresentation] Update model with simpler interface.
1 parent 7d874e3 commit 8452868

File tree

1 file changed: +87, -85 lines

inclearn/models/fixedrepresentation.py

Lines changed: 87 additions & 85 deletions
@@ -2,114 +2,116 @@
 
 import numpy as np
 import torch
-from torch import nn
 from torch.nn import functional as F
-from tqdm import trange
 
-from inclearn import factory, utils
-from inclearn.models.base import IncrementalLearner
+from inclearn.lib import factory, loops, network, utils
+from inclearn.models import IncrementalLearner
 
-LOGGER = logging.Logger("IncLearn", level="INFO")
+logger = logging.getLogger(__name__)
 
 
 class FixedRepresentation(IncrementalLearner):
-    """Base incremental learner.
-
-    Methods are called in this order (& repeated for each new task):
-
-    1. set_task_info
-    2. before_task
-    3. train_task
-    4. after_task
-    5. eval_task
-    """
 
     def __init__(self, args):
-        super().__init__()
-
-        self._epochs = 70
-
-        self._n_classes = args["increment"]
-        self._device = args["device"]
+        self._device = args["device"][0]
+        self._multiple_devices = args["device"]
+
+        self._opt_name = args["optimizer"]
+        self._lr = args["lr"]
+        self._lr_decay = args["lr_decay"]
+        self._weight_decay = args["weight_decay"]
+        self._n_epochs = args["epochs"]
+        self._scheduling = args["scheduling"]
+
+        logger.info("Initializing FixedRepresentation")
+
+        self._network = network.BasicNet(
+            args["convnet"],
+            convnet_kwargs=args.get("convnet_config", {}),
+            classifier_kwargs=args.get(
+                "classifier_config", {
+                    "type": "fc",
+                    "use_bias": True,
+                    "use_multi_fc": True
+                }
+            ),
+            device=self._device
+        )
+
+        self._n_classes = 0
+        self._old_model = None
 
-        self._features_extractor = factory.get_resnet(args["convnet"], nf=64,
-                                                      zero_init_residual=True)
+    def _before_task(self, data_loader, val_loader):
+        self._n_classes += self._task_size
+        self._network.add_classes(self._task_size)
 
-        self._classifiers = [nn.Linear(self._features_extractor.out_dim, self._n_classes, bias=False).to(self._device)]
-        torch.nn.init.kaiming_normal_(self._classifiers[0].weight)
-        self.add_module("clf_" + str(self._n_classes), self._classifiers[0])
+        self._optimizer = factory.get_optimizer(
+            self._network.classifier.classifier[-1].parameters(), self._opt_name, self._lr,
+            self._weight_decay
+        )
+        if self._scheduling is None:
+            self._scheduler = None
+        else:
+            self._scheduler = torch.optim.lr_scheduler.MultiStepLR(
+                self._optimizer, self._scheduling, gamma=self._lr_decay
+            )
 
-        self.to(self._device)
+    def _train_task(self, train_loader, val_loader):
+        loops.single_loop(
+            train_loader,
+            val_loader,
+            self._multiple_devices,
+            self._network,
+            self._n_epochs,
+            self._optimizer,
+            scheduler=self._scheduler,
+            train_function=self._forward_loss,
+            eval_function=self._accuracy,
+            task=self._task,
+            n_tasks=self._n_tasks
+        )
+
+    def _after_task(self, inc_dataset):
+        self._old_model = self._network.copy().freeze().eval().to(self._device)
+        self._network.on_task_end()
 
-    def forward(self, x):
-        feats = self._features_extractor(x)
+    def _eval_task(self, loader):
+        ypred, ytrue = [], []
 
-        logits = []
-        for clf in self._classifiers:
-            logits.append(clf(feats))
+        for input_dict in loader:
+            with torch.no_grad():
+                logits = self._network(input_dict["inputs"].to(self._device))["logits"]
 
-        return torch.cat(logits, dim=1)
+            ytrue.append(input_dict["targets"].numpy())
+            ypred.append(torch.softmax(logits, dim=1).cpu().numpy())
 
-    def _before_task(self, data_loader, val_loader):
-        if self._task != 0:
-            self._add_n_classes(self._task_size)
+        ytrue = np.concatenate(ytrue)
+        ypred = np.concatenate(ypred)
 
-        self._optimizer = factory.get_optimizer(
-            filter(lambda x: x.requires_grad, self.parameters()),
-            "sgd", 0.1)
-        self._scheduler = torch.optim.lr_scheduler.MultiStepLR(self._optimizer, [50, 60], gamma=0.2)
+        return ypred, ytrue
 
-    def _get_params(self):
-        return [self._features_extractor.parameters()]
+    def _accuracy(self, loader):
+        ypred, ytrue = self._eval_task(loader)
+        ypred = ypred.argmax(axis=1)
 
-    def _train_task(self, train_loader, val_loader):
-        for _ in trange(self._epochs):
-            self._scheduler.step()
-            for _, inputs, targets in train_loader:
-                self._optimizer.zero_grad()
+        return 100 * round(np.mean(ypred == ytrue), 3)
 
-                inputs, targets = inputs.to(self._device), targets.to(self._device)
+    def _forward_loss(self, training_network, inputs, targets, memory_flags, metrics):
+        inputs, targets = inputs.to(self._device), targets.to(self._device)
+        onehot_targets = utils.to_onehot(targets, self._n_classes).to(self._device)
 
-                logits = self.forward(inputs)
-                loss = F.cross_entropy(logits, targets)
-                loss.backward()
-                self._optimizer.step()
+        outputs = training_network(inputs)
 
-    def _after_task(self, data_loader):
-        pass
+        loss = self._compute_loss(inputs, outputs, targets, onehot_targets, memory_flags, metrics)
 
-    def _eval_task(self, loader):
-        ypred = []
-        ytrue = []
+        if not utils.check_loss(loss):
+            raise ValueError("Loss became invalid ({}).".format(loss))
 
-        for _, inputs, targets in loader:
-            inputs = inputs.to(self._device)
-            logits = self.forward(inputs)
-            preds = logits.argmax(dim=1).cpu().numpy()
+        metrics["loss"] += loss.item()
 
-            ypred.extend(preds)
-            ytrue.extend(targets)
+        return loss
 
-        ypred, ytrue = np.array(ypred), np.array(ytrue)
-        print(np.bincount(ypred))
-        return ypred, ytrue
+    def _compute_loss(self, inputs, outputs, targets, onehot_targets, memory_flags, metrics):
+        logits = outputs["logits"]
 
-    def _add_n_classes(self, n):
-        self._n_classes += n
-
-        self._classifiers.append(nn.Linear(
-            self._features_extractor.out_dim, self._task_size,
-            bias=False
-        ).to(self._device))
-        nn.init.kaiming_normal_(self._classifiers[-1].weight)
-        self.add_module("clf_" + str(self._n_classes), self._classifiers[-1])
-
-        for param in self._features_extractor.parameters():
-            param.requires_grad = False
-
-        for clf in self._classifiers[:-1]:
-            for param in clf.parameters():
-                param.requires_grad = False
-        for param in self._classifiers[-1].parameters():
-            for param in clf.parameters():
-                param.requires_grad = True
+        return F.cross_entropy(logits, targets)
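
For context, a minimal driver sketch of how these hooks fit together. It assumes the base IncrementalLearner exposes public before_task/train_task/after_task/eval_task wrappers around the underscore hooks, in the order the removed docstring lists; n_tasks, the loaders, and inc_dataset.new_task() are hypothetical placeholders, and the args values mirror the hyperparameters that were hard-coded in the old version (lr 0.1, milestones [50, 60], decay 0.2, 70 epochs).

    # Hypothetical usage sketch -- not part of this commit. The loaders and
    # inc_dataset.new_task() are placeholders; the hook order follows the
    # removed base-class docstring.
    args = {
        "device": [0],            # first entry becomes self._device
        "optimizer": "sgd",
        "lr": 0.1,                # old hard-coded defaults
        "lr_decay": 0.2,
        "weight_decay": 5e-4,     # placeholder value
        "epochs": 70,
        "scheduling": [50, 60],
        "convnet": "resnet18",    # placeholder backbone name
    }

    model = FixedRepresentation(args)

    for task in range(n_tasks):  # n_tasks is a placeholder
        train_loader, val_loader, test_loader = inc_dataset.new_task()  # hypothetical API

        model.before_task(train_loader, val_loader)  # grow the multi-FC head, rebuild optimizer
        model.train_task(train_loader, val_loader)   # trains only the newest head via loops.single_loop
        model.after_task(inc_dataset)                # snapshot and freeze the network

        ypred, ytrue = model.eval_task(test_loader)  # softmax scores and labels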
