Commit 0621230

[lwm] Add draft of Learning without Memorizing.
1 parent 14938b5 commit 0621230

3 files changed: +266, -18 lines

inclearn/lib/losses/distillation.py

Lines changed: 34 additions & 2 deletions
@@ -1,6 +1,8 @@
 import torch
 from torch.nn import functional as F

+from inclearn.lib import vizualization
+

 def mer_loss(new_logits, old_logits):
     """Distillation loss that is less important if the new model is unconfident.
@@ -145,7 +147,7 @@ def perceptual_features_reconstruction(list_attentions_a, list_attentions_b, fac
         a = F.normalize(a, p=2, dim=-1)
         b = F.normalize(b, p=2, dim=-1)

-        layer_loss = (F.pairwise_distance(a, b, p=2) ** 2) / (c * w * h)
+        layer_loss = (F.pairwise_distance(a, b, p=2)**2) / (c * w * h)
         loss += torch.mean(layer_loss)

     return factor * (loss / len(list_attentions_a))
@@ -163,7 +165,37 @@ def perceptual_style_reconstruction(list_attentions_a, list_attentions_b, factor
         gram_a = torch.bmm(a, a.transpose(2, 1)) / (c * w * h)
         gram_b = torch.bmm(b, b.transpose(2, 1)) / (c * w * h)

-        layer_loss = torch.frobenius_norm(gram_a - gram_b, dim=(1, 2)) ** 2
+        layer_loss = torch.frobenius_norm(gram_a - gram_b, dim=(1, 2))**2
         loss += layer_loss.mean()

     return factor * (loss / len(list_attentions_a))
+
+
+def gradcam_distillation(gradients_a, gradients_b, activations_a, activations_b, factor=1):
+    """Distillation loss between gradcam-generated attentions of two models.
+
+    References:
+        * Dhar et al.
+          Learning without Memorizing
+          CVPR 2019
+
+    :param gradients_a: Grad-CAM gradients of the new model's last conv layer.
+    :param gradients_b: Grad-CAM gradients of the old model's last conv layer.
+    :param activations_a: Activations of the new model's last conv layer.
+    :param activations_b: Activations of the old model's last conv layer.
+    :return: A scalar attention-distillation loss, scaled by factor (defaults to 1).
+    """
+    attentions_a = gradients_a * activations_a
+    attentions_b = gradients_b * activations_b
+
+    assert len(attentions_a.shape) == len(attentions_b.shape) == 4
+    assert attentions_a.shape == attentions_b.shape
+
+    batch_size = attentions_a.shape[0]
+
+    flat_attention_a = F.normalize(attentions_a.view(batch_size, -1), p=2, dim=-1)
+    flat_attention_b = F.normalize(attentions_b.view(batch_size, -1), p=2, dim=-1)
+
+    distances = F.pairwise_distance(flat_attention_a, flat_attention_b, p=1)
+
+    return factor * torch.mean(distances)
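
To make the tensor contract of the new loss concrete, here is a small, self-contained usage sketch. The random tensors and the helper name `gradcam_distillation_sketch` are illustrative stand-ins, not the repository's API; in the trainer, the gradients and activations come from the Grad-CAM hooks on the last convolutional layer.

```python
import torch
from torch.nn import functional as F

# Toy Grad-CAM gradients/activations for a batch of 4 samples, 16 channels, 7x7 maps.
# In the trainer these come from the hooks on `last_conv`; here they are random stand-ins.
gradients_new = torch.randn(4, 16, 7, 7)
gradients_old = torch.randn(4, 16, 7, 7)
activations_new = torch.randn(4, 16, 7, 7)
activations_old = torch.randn(4, 16, 7, 7)


def gradcam_distillation_sketch(grad_a, grad_b, act_a, act_b, factor=1):
    """Mirrors the committed loss: L1 distance between L2-normalized attention maps."""
    att_a = (grad_a * act_a).view(grad_a.shape[0], -1)   # attention = gradient * activation
    att_b = (grad_b * act_b).view(grad_b.shape[0], -1)
    att_a = F.normalize(att_a, p=2, dim=-1)
    att_b = F.normalize(att_b, p=2, dim=-1)
    return factor * F.pairwise_distance(att_a, att_b, p=1).mean()


loss = gradcam_distillation_sketch(gradients_new, gradients_old, activations_new, activations_old)
print(loss)  # a scalar tensor, differentiable w.r.t. the new model's tensors
```

The attention maps are formed as gradient times activation, flattened per sample, L2-normalized, and compared with an L1 distance averaged over the batch, matching the function added above.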

inclearn/lib/network/basenet.py

Lines changed: 45 additions & 16 deletions
@@ -26,6 +26,7 @@ def __init__(
         classifier_no_act=False,
         attention_hook=False,
         rotations_predictor=False,
+        gradcam_hook=False,
         dropout=0.
     ):
         super(BasicNet, self).__init__()
@@ -74,12 +75,17 @@ def __init__(
         self.extract_no_act = extract_no_act
         self.classifier_no_act = classifier_no_act
         self.attention_hook = attention_hook
+        self.gradcam_hook = gradcam_hook
         self.device = device

+        if self.gradcam_hook:
+            self._hooks = [None, None]
+            logger.info("Setting gradcam hook for gradients + activations of last conv.")
+            self.set_gradcam_hook()
         if self.extract_no_act:
-            print("Features will be extracted without the last ReLU.")
+            logger.info("Features will be extracted without the last ReLU.")
         if self.classifier_no_act:
-            print("No ReLU will be applied on features before feeding the classifier.")
+            logger.info("No ReLU will be applied on features before feeding the classifier.")

         self.to(self.device)

@@ -96,20 +102,19 @@ def on_epoch_end(self):
         self.post_processor.on_epoch_end()

     def forward(self, x):
-        outputs = self.convnet(x, attention_hook=self.attention_hook)
-        selected_outputs = outputs[0] if self.classifier_no_act else outputs[1]
-        logits = self.classifier(self.dropout(selected_outputs))
+        outputs = self.convnet(x)

-        outputs = {"logits": logits}
+        if self.classifier_no_act:
+            selected_features = outputs["raw_features"]
+        else:
+            selected_features = outputs["features"]
+        logits = self.classifier(self.dropout(selected_features))

-        if self.return_features:
-            if self.extract_no_act:
-                outputs["features"] = outputs[0]
-            else:
-                outputs["features"] = outputs[1]
+        outputs["logits"] = logits

-        if self.attention_hook:
-            outputs["attention_maps"] = outputs[2]
+        if self.gradcam_hook:
+            outputs["gradcam_gradients"] = self._gradcam_gradients
+            outputs["gradcam_activations"] = self._gradcam_activations

         return outputs

@@ -135,10 +140,10 @@ def add_custom_weights(self, weights):
         self.classifier.add_custom_weights(weights)

     def extract(self, x):
-        raw_features, features = self.convnet(x)
+        outputs = self.convnet(x)
         if self.extract_no_act:
-            return raw_features
-        return features
+            return outputs["raw_features"]
+        return outputs["features"]

     def predict_rotations(self, inputs):
         if self.rotations_predictor is None:
@@ -160,6 +165,9 @@ def freeze(self, trainable=False, model="all"):

         for param in model.parameters():
             param.requires_grad = trainable
+        if self.gradcam_hook and model == "convnet":
+            for param in self.convnet.last_conv.parameters():
+                param.requires_grad = True

         if not trainable:
             model.eval()
@@ -185,3 +193,24 @@ def copy(self):
     @property
     def n_classes(self):
         return self.classifier.n_classes
+
+    def unset_gradcam_hook(self):
+        self._hooks[0].remove()
+        self._hooks[1].remove()
+        self._hooks[0] = None
+        self._hooks[1] = None
+        self._gradcam_gradients, self._gradcam_activations = [None], [None]
+
+    def set_gradcam_hook(self):
+        self._gradcam_gradients, self._gradcam_activations = [None], [None]
+
+        def backward_hook(module, grad_input, grad_output):
+            self._gradcam_gradients[0] = grad_output[0]
+            return None
+
+        def forward_hook(module, input, output):
+            self._gradcam_activations[0] = output
+            return None
+
+        self._hooks[0] = self.convnet.last_conv.register_backward_hook(backward_hook)
+        self._hooks[1] = self.convnet.last_conv.register_forward_hook(forward_hook)
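
For readers unfamiliar with the hook mechanism `set_gradcam_hook` relies on, the sketch below shows the idea in isolation: a forward hook caches the last conv layer's output and a backward hook caches the gradient flowing into it, so gradient times activation yields a Grad-CAM style attention map. `TinyConvNet` and every name in it are toy stand-ins, not the repo's classes; newer PyTorch versions would prefer `register_full_backward_hook`, but the older `register_backward_hook` used in the commit still works for a single conv layer.

```python
import torch
from torch import nn


class TinyConvNet(nn.Module):
    """Toy stand-in for the repo's convnet: anything exposing a `last_conv` layer works."""

    def __init__(self):
        super().__init__()
        self.last_conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, 5)

    def forward(self, x):
        x = torch.relu(self.last_conv(x))
        return self.fc(self.pool(x).flatten(1))


net = TinyConvNet()
gradients, activations = [None], [None]


def backward_hook(module, grad_input, grad_output):
    # Called during backward: grad_output[0] is the gradient w.r.t. last_conv's output.
    gradients[0] = grad_output[0]


def forward_hook(module, inputs, output):
    # Called during forward: output is the last conv layer's feature maps.
    activations[0] = output


handles = [
    net.last_conv.register_backward_hook(backward_hook),
    net.last_conv.register_forward_hook(forward_hook),
]

logits = net(torch.randn(2, 3, 32, 32))
# Backpropagate a one-hot "gradient" for the top class, as LwM does for the old classes.
one_hot = torch.zeros_like(logits).scatter_(1, logits.argmax(dim=1, keepdim=True), 1.0)
logits.backward(gradient=one_hot)

attention = gradients[0] * activations[0]  # Grad-CAM style maps, shape (2, 8, 32, 32)
for handle in handles:
    handle.remove()
```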

inclearn/models/lwm.py

Lines changed: 187 additions & 0 deletions
@@ -0,0 +1,187 @@
+import logging
+import pdb
+
+import numpy as np
+import torch
+from torch.nn import functional as F
+
+from inclearn.lib import factory, loops, losses, network, utils
+from inclearn.models import IncrementalLearner
+
+EPSILON = 1e-8
+
+logger = logging.getLogger(__name__)
+
+
+class LwM(IncrementalLearner):
+
+    def __init__(self, args):
+        self._device = args["device"][0]
+        self._multiple_devices = args["device"]
+
+        self._opt_name = args["optimizer"]
+        self._lr = args["lr"]
+        self._lr_decay = args["lr_decay"]
+        self._weight_decay = args["weight_decay"]
+        self._n_epochs = args["epochs"]
+        self._scheduling = args["scheduling"]
+
+        self._distillation_config = args["distillation_config"]
+        self._attention_config = args.get("attention_config", {})
+
+        logger.info("Initializing LwM")
+
+        self._network = network.BasicNet(
+            args["convnet"],
+            convnet_kwargs=args.get("convnet_config", {}),
+            classifier_kwargs=args.get("classifier_config", {
+                "type": "fc",
+                "use_bias": True
+            }),
+            device=self._device,
+            gradcam_hook=True
+        )
+
+        self._n_classes = 0
+        self._old_model = None
+
+    @property
+    def network(self):
+        return self._network
+
+    @network.setter
+    def network(self, network_path):
+        if self._network is not None:
+            del self._network
+
+    def eval(self):
+        self._network.eval()
+
+    def train(self):
+        self._network.train()
+
+    def _before_task(self, data_loader, val_loader):
+        self._n_classes += self._task_size
+        self._network.add_classes(self._task_size)
+
+        self._optimizer = factory.get_optimizer(
+            self._network.parameters(), self._opt_name, self._lr, self._weight_decay
+        )
+        if self._scheduling is None:
+            self._scheduler = None
+        else:
+            self._scheduler = torch.optim.lr_scheduler.MultiStepLR(
+                self._optimizer, self._scheduling, gamma=self._lr_decay
+            )
+
+    def _train_task(self, train_loader, val_loader):
+        loops.single_loop(
+            train_loader,
+            val_loader,
+            self._multiple_devices,
+            self._network,
+            self._n_epochs,
+            self._optimizer,
+            scheduler=self._scheduler,
+            train_function=self._forward_loss,
+            eval_function=self._accuracy,
+            task=self._task,
+            n_tasks=self._n_tasks
+        )
+
+    def _after_task(self, inc_dataset):
+        self._network.zero_grad()
+        self._network.unset_gradcam_hook()
+        self._old_model = self._network.copy().eval().to(self._device)
+        self._network.on_task_end()
+
+        self._network.set_gradcam_hook()
+        self._old_model.set_gradcam_hook()
+
+    def _eval_task(self, loader):
+        ypred, ytrue = [], []
+
+        for input_dict in loader:
+            with torch.no_grad():
+                logits = self._network(input_dict["inputs"].to(self._device))["logits"]
+
+            ytrue.append(input_dict["targets"].numpy())
+            ypred.append(torch.softmax(logits, dim=1).cpu().numpy())
+
+        ytrue = np.concatenate(ytrue)
+        ypred = np.concatenate(ypred)
+
+        return ypred, ytrue
+
+    def _accuracy(self, loader):
+        ypred, ytrue = self._eval_task(loader)
+        ypred = ypred.argmax(axis=1)
+
+        return 100 * round(np.mean(ypred == ytrue), 3)
+
+    def _forward_loss(self, training_network, inputs, targets, memory_flags, metrics):
+        inputs, targets = inputs.to(self._device), targets.to(self._device)
+        onehot_targets = utils.to_onehot(targets, self._n_classes).to(self._device)
+
+        outputs = training_network(inputs)
+
+        loss = self._compute_loss(inputs, outputs, targets, onehot_targets, memory_flags, metrics)
+
+        if not utils.check_loss(loss):
+            raise ValueError("Loss became invalid ({}).".format(loss))
+
+        metrics["loss"] += loss.item()
+
+        return loss
+
+    def _compute_loss(self, inputs, outputs, targets, onehot_targets, memory_flags, metrics):
+        logits = outputs["logits"]
+
+        if self._old_model is None:
+            # Classification loss
+            loss = F.cross_entropy(logits, targets)
+            metrics["clf"] += loss.item()
+        else:
+            self._old_model.zero_grad()
+            old_outputs = self._old_model(inputs)
+            old_logits = old_outputs["logits"]
+
+            # Classification loss
+            loss = F.cross_entropy(
+                logits[..., -self._task_size:], (targets - self._n_classes + self._task_size)
+            )
+            metrics["clf"] += loss.item()
+
+            # Distillation on probabilities
+            distill_loss = self._distillation_config["factor"] * F.binary_cross_entropy_with_logits(
+                logits[..., :-self._task_size], torch.sigmoid(old_logits.detach())
+            )
+            metrics["dis"] += distill_loss.item()
+            loss += distill_loss
+
+            # Distillation on gradcam-generated attentions
+            if self._attention_config:
+                top_logits_indexes = logits[..., :-self._task_size].argmax(dim=1)
+                onehot_top_logits = utils.to_onehot(
+                    top_logits_indexes, self._n_classes - self._task_size
+                ).to(self._device)
+
+                logits[..., :-self._task_size].backward(
+                    gradient=onehot_top_logits, retain_graph=True, create_graph=True
+                )
+                old_logits.backward(
+                    gradient=onehot_top_logits, retain_graph=True, create_graph=True
+                )
+
+                attention_loss = losses.gradcam_distillation(
+                    outputs["gradcam_gradients"][0], old_outputs["gradcam_gradients"][0].detach(),
+                    outputs["gradcam_activations"][0],
+                    old_outputs["gradcam_activations"][0].detach(), **self._attention_config
+                )
+                metrics["ad"] += attention_loss.item()
+                loss += attention_loss
+
+                self._old_model.zero_grad()
+                self._network.zero_grad()
+
+        return loss
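
Finally, a hedged sketch of the kind of `args` dictionary this draft class appears to expect, inferred from the keys read in `__init__` and `_compute_loss`. The values and the convnet name are placeholders, not the repository's official configuration, and the snippet assumes the `inclearn` package from this repository is importable.

```python
import torch

from inclearn.models.lwm import LwM  # module added by this commit

# Keys mirror what LwM.__init__ and _compute_loss look up; the values are illustrative.
args = {
    "device": [torch.device("cpu")],          # first entry is used as the main device
    "optimizer": "sgd",
    "lr": 0.1,
    "lr_decay": 0.1,                          # gamma for MultiStepLR
    "weight_decay": 5e-4,
    "epochs": 70,
    "scheduling": [50, 60],                   # MultiStepLR milestones, or None to disable
    "convnet": "resnet18",                    # placeholder convnet name
    "convnet_config": {},
    "classifier_config": {"type": "fc", "use_bias": True},
    "distillation_config": {"factor": 1.0},   # weight of the logit-distillation term
    "attention_config": {"factor": 1.0},      # weight of the Grad-CAM attention term
}

model = LwM(args)
```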
