@@ -2,114 +2,116 @@
 
 import numpy as np
 import torch
-from torch import nn
 from torch.nn import functional as F
-from tqdm import trange
 
-from inclearn import factory, utils
-from inclearn.models.base import IncrementalLearner
+from inclearn.lib import factory, loops, network, utils
+from inclearn.models import IncrementalLearner
 
-LOGGER = logging.Logger("IncLearn", level="INFO")
+logger = logging.getLogger(__name__)
 
 
 class FixedRepresentation(IncrementalLearner):
-    """Base incremental learner.
-
-    Methods are called in this order (& repeated for each new task):
-
-    1. set_task_info
-    2. before_task
-    3. train_task
-    4. after_task
-    5. eval_task
-    """
 
     def __init__(self, args):
-        super().__init__()
-
-        self._epochs = 70
-
-        self._n_classes = args["increment"]
-        self._device = args["device"]
+        self._device = args["device"][0]
+        self._multiple_devices = args["device"]
+
+        self._opt_name = args["optimizer"]
+        self._lr = args["lr"]
+        self._lr_decay = args["lr_decay"]
+        self._weight_decay = args["weight_decay"]
+        self._n_epochs = args["epochs"]
+        self._scheduling = args["scheduling"]
+
+        logger.info("Initializing FixedRepresentation")
+
+        self._network = network.BasicNet(
+            args["convnet"],
+            convnet_kwargs=args.get("convnet_config", {}),
+            classifier_kwargs=args.get(
+                "classifier_config", {
+                    "type": "fc",
+                    "use_bias": True,
+                    "use_multi_fc": True
+                }
+            ),
+            device=self._device
+        )
+
+        self._n_classes = 0
+        self._old_model = None
 
-        self._features_extractor = factory.get_resnet(args["convnet"], nf=64,
-                                                      zero_init_residual=True)
+    def _before_task(self, data_loader, val_loader):
+        self._n_classes += self._task_size
+        self._network.add_classes(self._task_size)
 
-        self._classifiers = [nn.Linear(self._features_extractor.out_dim, self._n_classes, bias=False).to(self._device)]
-        torch.nn.init.kaiming_normal_(self._classifiers[0].weight)
-        self.add_module("clf_" + str(self._n_classes), self._classifiers[0])
+        self._optimizer = factory.get_optimizer(
+            self._network.classifier.classifier[-1].parameters(), self._opt_name, self._lr,
+            self._weight_decay
+        )
+        if self._scheduling is None:
+            self._scheduler = None
+        else:
+            self._scheduler = torch.optim.lr_scheduler.MultiStepLR(
+                self._optimizer, self._scheduling, gamma=self._lr_decay
+            )
 
-        self.to(self._device)
+    def _train_task(self, train_loader, val_loader):
+        loops.single_loop(
+            train_loader,
+            val_loader,
+            self._multiple_devices,
+            self._network,
+            self._n_epochs,
+            self._optimizer,
+            scheduler=self._scheduler,
+            train_function=self._forward_loss,
+            eval_function=self._accuracy,
+            task=self._task,
+            n_tasks=self._n_tasks
+        )
+
+    def _after_task(self, inc_dataset):
+        self._old_model = self._network.copy().freeze().eval().to(self._device)
+        self._network.on_task_end()
 
-    def forward(self, x):
-        feats = self._features_extractor(x)
+    def _eval_task(self, loader):
+        ypred, ytrue = [], []
 
-        logits = []
-        for clf in self._classifiers:
-            logits.append(clf(feats))
+        for input_dict in loader:
+            with torch.no_grad():
+                logits = self._network(input_dict["inputs"].to(self._device))["logits"]
 
-        return torch.cat(logits, dim=1)
+            ytrue.append(input_dict["targets"].numpy())
+            ypred.append(torch.softmax(logits, dim=1).cpu().numpy())
 
-    def _before_task(self, data_loader, val_loader):
-        if self._task != 0:
-            self._add_n_classes(self._task_size)
+        ytrue = np.concatenate(ytrue)
+        ypred = np.concatenate(ypred)
 
-        self._optimizer = factory.get_optimizer(
-            filter(lambda x: x.requires_grad, self.parameters()),
-            "sgd", 0.1)
-        self._scheduler = torch.optim.lr_scheduler.MultiStepLR(self._optimizer, [50, 60], gamma=0.2)
+        return ypred, ytrue
 
-    def _get_params(self):
-        return [self._features_extractor.parameters()]
+    def _accuracy(self, loader):
+        ypred, ytrue = self._eval_task(loader)
+        ypred = ypred.argmax(axis=1)
 
-    def _train_task(self, train_loader, val_loader):
-        for _ in trange(self._epochs):
-            self._scheduler.step()
-            for _, inputs, targets in train_loader:
-                self._optimizer.zero_grad()
+        return 100 * round(np.mean(ypred == ytrue), 3)
 
-                inputs, targets = inputs.to(self._device), targets.to(self._device)
+    def _forward_loss(self, training_network, inputs, targets, memory_flags, metrics):
+        inputs, targets = inputs.to(self._device), targets.to(self._device)
+        onehot_targets = utils.to_onehot(targets, self._n_classes).to(self._device)
 
-                logits = self.forward(inputs)
-                loss = F.cross_entropy(logits, targets)
-                loss.backward()
-                self._optimizer.step()
+        outputs = training_network(inputs)
 
-    def _after_task(self, data_loader):
-        pass
+        loss = self._compute_loss(inputs, outputs, targets, onehot_targets, memory_flags, metrics)
 
-    def _eval_task(self, loader):
-        ypred = []
-        ytrue = []
+        if not utils.check_loss(loss):
+            raise ValueError("Loss became invalid ({}).".format(loss))
 
-        for _, inputs, targets in loader:
-            inputs = inputs.to(self._device)
-            logits = self.forward(inputs)
-            preds = logits.argmax(dim=1).cpu().numpy()
+        metrics["loss"] += loss.item()
 
-            ypred.extend(preds)
-            ytrue.extend(targets)
+        return loss
 
-        ypred, ytrue = np.array(ypred), np.array(ytrue)
-        print(np.bincount(ypred))
-        return ypred, ytrue
+    def _compute_loss(self, inputs, outputs, targets, onehot_targets, memory_flags, metrics):
+        logits = outputs["logits"]
 
-    def _add_n_classes(self, n):
-        self._n_classes += n
-
-        self._classifiers.append(nn.Linear(
-            self._features_extractor.out_dim, self._task_size,
-            bias=False
-        ).to(self._device))
-        nn.init.kaiming_normal_(self._classifiers[-1].weight)
-        self.add_module("clf_" + str(self._n_classes), self._classifiers[-1])
-
-        for param in self._features_extractor.parameters():
-            param.requires_grad = False
-
-        for clf in self._classifiers[:-1]:
-            for param in clf.parameters():
-                param.requires_grad = False
-        for param in self._classifiers[-1].parameters():
-            for param in clf.parameters():
-                param.requires_grad = True
+        return F.cross_entropy(logits, targets)
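
Note: the docstring deleted above spells out the hook order the `IncrementalLearner` base class drives for every new task: set_task_info, before_task, train_task, after_task, eval_task. A minimal sketch of such a driver loop, with `learner`, `inc_dataset`, and `new_task()` assumed for illustration since the actual trainer is outside this diff:

```python
# Hypothetical driver; only the hook order comes from the removed docstring.
for _ in range(n_tasks):
    # Assumed API: yields task metadata plus train/val/test loaders.
    task_info, train_loader, val_loader, test_loader = inc_dataset.new_task()

    learner.set_task_info(task_info)
    learner.before_task(train_loader, val_loader)  # grows classifier, builds optimizer
    learner.train_task(train_loader, val_loader)   # delegates to loops.single_loop
    learner.after_task(inc_dataset)                # snapshots and freezes the network
    ypred, ytrue = learner.eval_task(test_loader)  # softmax probabilities + labels
```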
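Note: `_forward_loss` builds one-hot targets even though the default `_compute_loss` only needs integer labels for `F.cross_entropy`; presumably they are kept so that subclasses overriding `_compute_loss` (e.g. with a distillation or BCE term) can use them. `utils.to_onehot` itself is not part of this diff; a plausible sketch of such a helper:

```python
import torch

def to_onehot(targets, n_classes):
    """Sketch of a one-hot encoder; the real inclearn helper may differ."""
    # One row per sample, one column per class, 1.0 at the target index.
    onehot = torch.zeros(targets.shape[0], n_classes, device=targets.device)
    onehot.scatter_(1, targets.long().view(-1, 1), 1.0)
    return onehot
```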