
Commit b0b3eb9

[models] Basemodel returns a dict instead of a list.
1 parent d59a44d commit b0b3eb9
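
In short: BasicNet.forward previously returned either a bare logits tensor or a positional list whose contents depended on the return_features and attention_hook flags, so every caller had to know the ordering. After this commit it always returns a dict keyed by "logits", plus "features" and "attention_maps" when the corresponding flags are set, and callers unpack by name. A minimal sketch of the new calling convention (the stand-in forward and all tensor shapes below are illustrative assumptions, not the repo's real implementation):

    import torch

    # Hypothetical stand-in for BasicNet.forward after this commit:
    # the return value is always a dict, never a bare tensor or a list.
    def forward(x, return_features=True, attention_hook=False):
        outputs = {"logits": torch.randn(x.shape[0], 10)}   # always present
        if return_features:
            outputs["features"] = torch.randn(x.shape[0], 64)
        if attention_hook:
            outputs["attention_maps"] = [torch.randn(x.shape[0], 16, 8, 8)]
        return outputs

    outputs = forward(torch.randn(4, 3, 32, 32))
    logits = outputs["logits"]
    features = outputs.get("features")        # None unless return_features is set
    atts = outputs.get("attention_maps")      # None unless attention_hook is set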

4 files changed (+23, -16 lines)

inclearn/lib/network/basenet.py (6 additions & 7 deletions)

@@ -100,19 +100,18 @@ def forward(self, x):
         selected_outputs = outputs[0] if self.classifier_no_act else outputs[1]
         logits = self.classifier(self.dropout(selected_outputs))
 
+        outputs = {"logits": logits}
+
         if self.return_features:
-            to_return = []
             if self.extract_no_act:
-                to_return.append(outputs[0])
+                outputs["features"] = outputs[0]
             else:
-                to_return.append(outputs[1])
+                outputs["features"] = outputs[1]
 
-            to_return.append(logits)
             if self.attention_hook:
-                to_return.append(outputs[2])
+                outputs["attention_maps"] = outputs[2]
 
-            return to_return
-        return logits
+        return outputs
 
     def post_process(self, x):
         if self.post_processor is None:
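
For context: on entry to this hunk, outputs is the backbone's positional result, where index 0 holds pre-activation features, index 1 post-activation features, and index 2 attention maps (an assumption inferred from the extract_no_act and attention_hook flag names). A toy sketch of the selection logic with made-up tensors, using a separate name for the positional result so it stays readable once the dict is introduced:

    import torch

    backbone_outputs = (
        torch.randn(4, 64),            # [0] features before the final activation
        torch.randn(4, 64),            # [1] features after the final activation
        [torch.randn(4, 16, 8, 8)],    # [2] attention maps
    )
    extract_no_act = True

    outputs = {"logits": torch.randn(4, 10)}
    # Expose pre- or post-activation features by key, mirroring the diff above.
    outputs["features"] = backbone_outputs[0] if extract_no_act else backbone_outputs[1]
    outputs["attention_maps"] = backbone_outputs[2]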

inclearn/models/bic.py (5 additions & 2 deletions)

@@ -82,11 +82,14 @@ def _eval_task(self, loader):
 
         return ypred, ytrue
 
-    def _compute_loss(self, inputs, logits, targets, onehot_targets, memory_flags):
+    def _compute_loss(self, inputs, outputs, targets, onehot_targets, memory_flags):
+        logits = outputs["logits"]
+
         loss = F.cross_entropy(logits, targets)
 
         if self._old_model is not None:
-            old_targets = self._old_model.post_process(self._old_model(inputs)).detach()
+            with torch.no_grad():
+                old_targets = self._old_model.post_process(self._old_model(inputs)["logits"])
 
             loss += F.binary_cross_entropy_with_logits(
                 logits[..., :-self._task_size] / self._temperature,
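
The distillation term touched here penalizes divergence between the current model's temperature-scaled old-class logits and the frozen old model's post-processed outputs. A self-contained toy version with made-up shapes; treating post_process as a temperature-scaled sigmoid is an assumption for illustration only:

    import torch
    import torch.nn.functional as F

    n_classes, task_size, temperature = 10, 2, 2.0
    logits = torch.randn(4, n_classes, requires_grad=True)   # current model, all classes
    old_logits = torch.randn(4, n_classes - task_size)       # frozen old model, old classes

    with torch.no_grad():                                    # replaces .detach() in the diff
        old_targets = torch.sigmoid(old_logits / temperature)  # assumed post-processing

    # Match the old model's soft targets on the old classes only.
    distill_loss = F.binary_cross_entropy_with_logits(
        logits[..., :-task_size] / temperature, old_targets)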

inclearn/models/icarl.py (7 additions & 4 deletions)

@@ -233,9 +233,9 @@ def _forward_loss(self, training_network, inputs, targets, memory_flags):
         inputs, targets = inputs.to(self._device), targets.to(self._device)
         onehot_targets = utils.to_onehot(targets, self._n_classes).to(self._device)
 
-        logits = training_network(inputs)
+        outputs = training_network(inputs)
 
-        loss = self._compute_loss(inputs, logits, targets, onehot_targets, memory_flags)
+        loss = self._compute_loss(inputs, outputs, targets, onehot_targets, memory_flags)
 
         if not utils._check_loss(loss):
             pdb.set_trace()
@@ -270,11 +270,14 @@ def _eval_task(self, data_loader):
     # Private API
     # -----------
 
-    def _compute_loss(self, inputs, logits, targets, onehot_targets, memory_flags):
+    def _compute_loss(self, inputs, outputs, targets, onehot_targets, memory_flags):
+        logits = outputs["logits"]
+
         if self._old_model is None:
             loss = F.binary_cross_entropy_with_logits(logits, onehot_targets)
         else:
-            old_targets = torch.sigmoid(self._old_model(inputs).detach())
+            with torch.no_grad():
+                old_targets = torch.sigmoid(self._old_model(inputs)["outputs"])
 
             new_targets = onehot_targets.clone()
             new_targets[..., :-self._task_size] = old_targets
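
The iCaRL-style loss updated here keeps one-hot labels for the new classes but overwrites the old-class entries of the target with the previous model's sigmoid outputs. A toy, self-contained version (the stand-in old-model probabilities are random for illustration):

    import torch
    import torch.nn.functional as F

    n_classes, task_size = 10, 2
    logits = torch.randn(4, n_classes, requires_grad=True)
    onehot_targets = F.one_hot(torch.randint(0, n_classes, (4,)), n_classes).float()

    with torch.no_grad():                       # no graph is built for the frozen model
        old_probs = torch.sigmoid(torch.randn(4, n_classes - task_size))

    new_targets = onehot_targets.clone()
    new_targets[..., :-task_size] = old_probs   # old classes: distillation targets
    loss = F.binary_cross_entropy_with_logits(logits, new_targets)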

inclearn/models/still.py (5 additions & 3 deletions)

@@ -351,8 +351,8 @@ def _before_task(self, train_loader, val_loader):
         else:
             self._class_weights = None
 
-    def _compute_loss(self, inputs, features_logits, targets, onehot_targets, memory_flags):
-        features, logits, atts = features_logits
+    def _compute_loss(self, inputs, outputs, targets, onehot_targets, memory_flags):
+        features, logits, atts = outputs["features"], outputs["logits"], outputs["attention_maps"]
 
         if self._post_processing_type is None:
             scaled_logits = self._network.post_process(logits)
@@ -361,7 +361,9 @@ def _compute_loss(self, inputs, outputs, targets, onehot_targets, memory_flags):
 
         if self._old_model is not None:
             with torch.no_grad():
-                old_features, old_logits, old_atts = self._old_model(inputs)
+                old_outputs = self._old_model(inputs)
+                old_features = old_outputs["features"]
+                old_atts = old_outputs["attention_maps"]
 
             if self._ams_config:
                 ams_config = copy.deepcopy(self._ams_config)
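
A pattern repeated across these files: old-model targets are now computed under torch.no_grad() instead of calling .detach() on the result. The values are the same either way, but under no_grad no autograd graph is built for the frozen model in the first place, which saves memory during training. A minimal illustration:

    import torch

    model = torch.nn.Linear(8, 4)
    x = torch.randn(2, 8)

    y_detached = model(x).detach()    # graph is built, then thrown away
    with torch.no_grad():
        y_nograd = model(x)           # no graph is ever built

    assert not y_detached.requires_grad
    assert not y_nograd.requires_grad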
