Commit d0d4282

[still] Trim unused options to keep only the paper architecture.

1 parent: c718938

2 files changed: 1 addition, 259 deletions

inclearn/models/icarl.py

Lines changed: 0 additions & 4 deletions
@@ -233,10 +233,6 @@ def _forward_loss(self, training_network, inputs, targets, memory_flags):
         inputs, targets = inputs.to(self._device), targets.to(self._device)
         onehot_targets = utils.to_onehot(targets, self._n_classes).to(self._device)

-        if self._random_noise_config:
-            random_noise = torch.randn(self._random_noise_config["nb_per_batch"], *inputs.shape[1:])
-            inputs = torch.cat((inputs, random_noise.to(self._device)))
-
         logits = training_network(inputs)

         loss = self._compute_loss(inputs, logits, targets, onehot_targets, memory_flags)
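
For context, the deleted branch padded each batch with pure Gaussian-noise images before the forward pass; the matching logits were stripped again in _compute_loss (see the still.py hunks below). A standalone sketch of the shape mechanics, with hypothetical batch sizes:

    import torch

    inputs = torch.randn(4, 3, 32, 32)  # hypothetical batch of 4 CIFAR-sized images
    nb_per_batch = 2                    # the value read from random_noise_config["nb_per_batch"]

    # Noise samples share the per-image shape and are appended along the batch axis:
    random_noise = torch.randn(nb_per_batch, *inputs.shape[1:])
    inputs = torch.cat((inputs, random_noise))

    assert inputs.shape == (6, 3, 32, 32)  # 4 real images + 2 noise images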

inclearn/models/still.py

Lines changed: 1 addition & 255 deletions
@@ -39,28 +39,21 @@ def __init__(self, args):
         self._herding_selection = args.get("herding_selection", "icarl")
         self._n_classes = 0

-        self._use_mimic_score = args.get("mimic_score", False)
         self._less_forget_config = args.get("less_forget", {})
         assert isinstance(self._less_forget_config, dict)

         self._lambda_schedule = args.get("lambda_schedule", False)
-        self._ranking_loss = args.get("ranking_loss", {})
-
-        self._relative_teachers_config = args.get("relative_teachers", {})

         self._gor_config = args.get("gor_config", {})

         self._ams_config = args.get("adaptative_margin_softmax", {})
         self._softmax_ce = args.get("softmax_ce", False)

         self._attention_residual_config = args.get("attention_residual", {})
-        assert isinstance(self._attention_residual_config, dict), "ra need to be dict"

         self._perceptual_features = args.get("perceptual_features")
         self._perceptual_style = args.get("perceptual_style")

-        self._use_teacher_confidence = args.get("teacher_confidence", False)
-
         self._groupwise_factors = args.get("groupwise_factors", {})
         self._softtriple_config = args.get("softriple_regularizer", {})

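After this trim the constructor reads only the options used by the paper architecture. A hypothetical minimal args dict exercising the surviving keys from this hunk (all values illustrative):

    args = {
        "herding_selection": "icarl",
        "less_forget": {},              # must stay a dict (asserted above)
        "lambda_schedule": False,
        "gor_config": {},
        "adaptative_margin_softmax": {},
        "softmax_ce": True,
        "attention_residual": {},
        "groupwise_factors": {},
        "softriple_regularizer": {},
    }

    # Removed keys such as "mimic_score" or "ranking_loss" are now never read:
    # older configs that set them still load, but those options silently do nothing.
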
@@ -78,24 +71,15 @@ def __init__(self, args):
         self._evaluation_type = args.get("eval_type", "icarl")
         self._evaluation_config = args.get("evaluation_config", {})

-        self._weights_orthogonality = args.get("weights_orthogonality")
-        self._orthoreg_config = args.get("orthoreg_config", {})
-        self._dso_config = args.get("dso_config", {})
-        self._mc_config = args.get("mc_config", {})
-        self._srip_config = args.get("srip_config", {})
         self._double_margin_reg = args.get("double_margin_reg", {})

         self._save_model = args["save_model"]

-        self._harmonic_embeddings = args.get("harmonic_embeddings", {})
-
         self._rotations_config = args.get("rotations_config", {})

         self._eval_every_x_epochs = args.get("eval_every_x_epochs")
         self._early_stopping = args.get("early_stopping", {})

-        self._random_noise_config = args.get("random_noise_config", {})
-
         classifier_kwargs = args.get("classifier_config", {})
         self._network = network.BasicNet(
             args["convnet"],
@@ -127,12 +111,6 @@ def __init__(self, args):
         self._herding_compressed_indexes = []

         self._weight_generation = args.get("weight_generation")
-        self._compressed_memory = args.get("compressed_memory")
-        self._alternate_training_config = args.get("alternate_training")
-
-        self._compressed_data = {}
-        self._compressed_targets = {}
-        self._compressed_means = []

         self._saved_network = None
         self._post_processing_type = None
@@ -143,9 +121,7 @@ def __init__(self, args):
     @property
     def _memory_per_class(self):
         """Returns the number of examplars per class."""
-        if self._compressed_memory:
-            return self._compressed_memory["quantity_images"]
-        elif self._fixed_memory:
+        if self._fixed_memory:
             return self._memory_size // self._total_n_classes
         return self._memory_size // self._n_classes

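With the compressed-memory branch gone, the property reduces to the usual rehearsal budget split. A worked example with hypothetical numbers (2,000 stored images, 100 classes in total):

    memory_size, total_n_classes = 2000, 100

    # fixed_memory=True: the budget is split over all classes from the start.
    assert memory_size // total_n_classes == 20   # 20 exemplars per class

    # fixed_memory=False: the budget is split over the classes seen so far,
    # so classes keep more exemplars early on and are pruned as tasks arrive.
    n_classes_seen = 50
    assert memory_size // n_classes_seen == 40    # 40 exemplars per class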

@@ -154,9 +130,6 @@ def _train_task(self, train_loader, val_loader):
             if p.requires_grad:
                 p.register_hook(lambda grad: torch.clamp(grad, -5., 5.))

-        if self._alternate_training_config and self._task != 0:
-            return self._alternate_training(train_loader, val_loader)
-
         logger.debug("nb {}.".format(len(train_loader.dataset)))
         self._training_step(train_loader, val_loader, 0, self._n_epochs)

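The surviving context above clamps every gradient to [-5, 5] through a tensor hook. A self-contained illustration of that mechanism, on a toy parameter:

    import torch

    p = torch.ones(3, requires_grad=True)
    p.register_hook(lambda grad: torch.clamp(grad, -5., 5.))

    loss = (100 * p).sum()  # the raw gradient would be 100 per element
    loss.backward()
    print(p.grad)           # tensor([5., 5., 5.]) after the hook clamps it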

@@ -224,93 +197,10 @@ def weight_decay(self):
             )
         )

-    def _alternate_training(self, train_loader, val_loader):
-        for phase in self._alternate_training_config:
-            if phase["update_theta"]:
-                logger.info("Updating theta")
-                for class_index in range(self._n_classes - self._task_size, self._n_classes):
-                    _, loader = self.inc_dataset.get_custom_loader([class_index])
-                    features, _ = utils.extract_features(self._network, loader)
-                    features = F.normalize(torch.from_numpy(features), p=2, dim=1)
-                    mean = torch.mean(features, dim=0)
-                    mean = F.normalize(mean, dim=0, p=2)
-
-                    self._network.classifier.weights.data[class_index] = mean.to(self._device)
-
-            self._network.freeze(trainable=phase["train_f"], model="convnet")
-            self._network.freeze(trainable=phase["train_theta"], model="classifier")
-            logger.info("Freeze convnet=" + str(phase["train_f"]))
-            logger.info("Freeze classifier=" + str(phase["train_theta"]))
-
-            self._optimizer = factory.get_optimizer(
-                self._network.parameters(), self._opt_name, self._lr, self._weight_decay
-            )
-            self._training_step(train_loader, val_loader, 0, phase["nb_epochs"])
-
     def _after_task(self, inc_dataset):
         self._monitor_scale()
         super()._after_task(inc_dataset)

-        if self._compressed_memory:
-            self.add_compressed_memory()
-
-    def add_compressed_memory(self):
-        _, _, self._herding_compressed_indexes, _ = self.build_examplars(
-            self.inc_dataset, self._herding_compressed_indexes, self.quantity_compressed_embeddings
-        )
-
-        # Computing the embeddings of only the current task images:
-        for class_index in range(self._n_classes - self._task_size, self._n_classes):
-            _, loader = self.inc_dataset.get_custom_loader([class_index])
-            features, targets = utils.extract_features(self._network, loader)
-
-            selected_features = features[self._herding_compressed_indexes[class_index]]
-            selected_targets = targets[self._herding_compressed_indexes[class_index]]
-
-            self._compressed_means.append(np.mean(selected_features, axis=0))
-
-            self._compressed_data[class_index] = selected_features
-            self._compressed_targets[class_index] = selected_targets
-
-        logger.info(
-            "{} compressed memory, or {} per class.".format(
-                sum(len(x) for x in self._compressed_data.values()),
-                self.quantity_compressed_embeddings
-            )
-        )
-
-        # Taking in account the mean shift of the class:
-        if self._compressed_memory["mean_shift"]:
-            logger.info("Computing mean shift")
-            for class_index in range(self._n_classes - self._task_size):
-                class_memory, class_targets = utils.select_class_samples(
-                    self._data_memory, self._targets_memory, class_index
-                )
-
-                _, loader = self.inc_dataset.get_custom_loader(
-                    [], memory=((class_memory, class_targets))
-                )
-                features, _ = utils.extract_features(self._network, loader)
-                features_mean = np.mean(features, axis=0)
-
-                diff_mean = features_mean - self._compressed_means[class_index]
-
-                self._compressed_data[class_index] += diff_mean
-
-        for class_index in range(self._n_classes):
-            indexes = np.random.permutation(self.quantity_compressed_embeddings)
-            self._compressed_data[class_index] = self._compressed_data[class_index][indexes]
-
-    @property
-    def quantity_compressed_embeddings(self):
-        assert self._compressed_memory
-
-        embed_size = 64 * 16
-        image_size = 32 * 32 * 3 * 8
-        total_mem = image_size * 20
-
-        return (total_mem - image_size * self._compressed_memory["quantity_images"]) // embed_size
-
     def _monitor_scale(self):
         if "scale" not in self._args["_logs"]:
             self._args["_logs"]["scale"] = []
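
For the record, the deleted quantity_compressed_embeddings property traded raw images against embeddings inside a fixed bit budget. Reproducing its arithmetic with a hypothetical quantity_images of 5:

    embed_size = 64 * 16          # 1,024 bits per stored embedding
    image_size = 32 * 32 * 3 * 8  # 24,576 bits per raw 32x32 RGB image
    total_mem = image_size * 20   # overall budget: the size of 20 raw images

    quantity_images = 5  # hypothetical compressed_memory["quantity_images"]
    leftover = total_mem - image_size * quantity_images
    assert leftover // embed_size == 360  # 360 embeddings fit in the remainder
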
@@ -454,49 +344,16 @@ def _before_task(self, train_loader, val_loader):
             task=self._task
         )

-        if self._compressed_memory:
-            self._compressed_iterator = 0
-            self._compressed_step = self.quantity_compressed_embeddings // len(train_loader)
-
         if self._class_weights_config:
             self._class_weights = torch.tensor(
                 data.get_class_weights(train_loader.dataset, **self._class_weights_config)
             ).to(self._device)
         else:
             self._class_weights = None

-    def _sample_compressed(self):
-        features, logits, targets = [], [], []
-
-        low_index = self._compressed_iterator * self._compressed_step
-        self._compressed_iterator += 1
-        high_index = self._compressed_iterator * self._compressed_step
-
-        for class_index in self._compressed_data.keys():
-            f = self._compressed_data[class_index][low_index:high_index]
-            t = self._compressed_targets[class_index][low_index:high_index]
-
-            f = torch.tensor(f).to(self._device)
-            t = torch.tensor(t).to(self._device)
-
-            logits.append(self._network.classifier(f))
-            features.append(f)
-            targets.append(t)
-
-        return torch.cat(features), torch.cat(logits), torch.cat(targets)
-
     def _compute_loss(self, inputs, features_logits, targets, onehot_targets, memory_flags):
         features, logits, atts = features_logits

-        if self._random_noise_config:
-            logits = logits[:-self._random_noise_config["nb_per_batch"]]
-
-        if self._compressed_memory and len(self._compressed_data) > 0:
-            c_f, c_l, c_t = self._sample_compressed()
-            features = torch.cat((features, c_f))
-            logits = torch.cat((logits, c_l))
-            targets = torch.cat((targets, c_t))
-
         if self._post_processing_type is None:
             scaled_logits = self._network.post_process(logits)
         else:
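
The deleted _sample_compressed method consumed each class's stored embeddings in consecutive fixed-size windows, one window per training batch. A standalone sketch of that slicing pattern with toy sizes:

    import numpy as np

    # Hypothetical store: 2 classes, 6 embeddings of dimension 4 each.
    compressed_data = {0: np.random.randn(6, 4), 1: np.random.randn(6, 4)}
    step, iterator = 2, 0  # quantity_compressed_embeddings // len(train_loader)

    for _ in range(3):  # one window per batch
        low, high = iterator * step, (iterator + 1) * step
        iterator += 1
        batch = np.concatenate([f[low:high] for f in compressed_data.values()])
        assert batch.shape == (4, 4)  # 2 embeddings from each of the 2 classes
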
@@ -506,13 +363,6 @@ def _compute_loss(self, inputs, features_logits, targets, onehot_targets, memory_flags):
             with torch.no_grad():
                 old_features, old_logits, old_atts = self._old_model(inputs)

-            if self._compressed_memory and len(self._compressed_data) > 0:
-                old_features = torch.cat((old_features, c_f))
-                old_logits = torch.cat((old_logits, self._old_model.classifier(c_f)))
-
-            if self._random_noise_config:
-                old_logits = old_logits[:-self._random_noise_config["nb_per_batch"]]
-
         if self._ams_config:
             ams_config = copy.deepcopy(self._ams_config)
             if self._network.post_processor:
@@ -526,90 +376,21 @@ def _compute_loss(self, inputs, features_logits, targets, onehot_targets, memory_flags):
                 **ams_config
             )
             self._metrics["ams"] += loss.item()
-        elif self._use_npair:
-            loss = losses.n_pair_loss(logits, targets)
-            self._metrics["npair"] += loss.item()
-        elif self._proxy_nca_config:
-            if self._network.post_processor:
-                self._proxy_nca_config["s"] = self._network.post_processor.factor
-
-            loss = losses.proxy_nca_github(
-                scaled_logits, targets, self._n_classes, **self._proxy_nca_config
-            )
-            self._metrics["nca"] += loss.item()
-        elif self._triplet_config:
-            loss, percent_violated = losses.triplet_loss(
-                features,
-                targets,
-                **self._triplet_config,
-                harmonic_embeddings=self._harmonic_embeddings,
-                old_features=old_features if self._old_model else None,
-                memory_flags=memory_flags,
-                epoch_percent=self._epoch_percent
-            )
-
-            self._metrics["tri"] += loss.item()
-            self._metrics["violated"] += percent_violated
         elif self._softmax_ce:
             loss = F.cross_entropy(scaled_logits, targets)
             self._metrics["cce"] += loss.item()
-        else:
-            if self._use_teacher_confidence and self._old_model is not None:
-                loss = losses.cross_entropy_teacher_confidence(
-                    scaled_logits, targets, F.softmax(old_logits, dim=1), memory_flags
-                )
-                self._metrics["clf_conf"] += loss.item()
-            else:
-                loss = F.cross_entropy(scaled_logits, targets)
-                self._metrics["clf"] += loss.item()

         # ----------------------
         # Regularization losses:
         # ----------------------

-        if self._weights_orthogonality is not None:
-            margin = self._weights_orthogonality.get("margin")
-            ortho_loss = losses.weights_orthogonality(
-                self._network.classifier.weights, margin=margin
-            )
-            loss += ortho_loss
-            self._metrics["ortho"] += ortho_loss.item()
-
         if self._gor_config:
             gor_loss = losses.global_orthogonal_regularization(
                 features, targets, self._n_classes - self._task_size, **self._gor_config
             )
             self._metrics["gor"] += gor_loss.item()
             loss += gor_loss

-        if self._orthoreg_config:
-            orthoreg_loss = losses.ortho_reg(
-                self._network.classifier.weights, self._orthoreg_config
-            )
-            self._metrics["orthoreg"] += orthoreg_loss.item()
-            loss += orthoreg_loss
-
-        if self._dso_config:
-            dso_loss = losses.double_soft_orthoreg(
-                self._network.classifier.weights, self._dso_config
-            )
-            self._metrics["dso"] += dso_loss.item()
-            loss += dso_loss
-
-        if self._mc_config:
-            mc_loss = losses.mutual_coherence_regularization(
-                self._network.classifier.weights, self._mc_config
-            )
-            self._metrics["mc"] += mc_loss.item()
-            loss += mc_loss
-
-        if self._srip_config:
-            srip_loss = losses.spectral_restricted_isometry_property_regularization(
-                self._network.classifier.weights, self._srip_config
-            )
-            self._metrics["srip"] += srip_loss.item()
-            loss += srip_loss
-
         if self._softtriple_config:
             st_reg = losses.softriple_regularizer(
                 self._network.classifier.weights, self._softtriple_config
@@ -644,41 +425,6 @@ def _compute_loss(self, inputs, features_logits, targets, onehot_targets, memory_flags):
                 distil_loss = factor * losses.embeddings_similarity(old_features, features)
                 loss += distil_loss
                 self._metrics["lf"] += distil_loss.item()
-            elif self._use_mimic_score:
-                old_class_logits = logits[..., :self._n_classes - self._task_size]
-                old_class_old_logits = old_logits[..., :self._n_classes - self._task_size]
-
-                mimic_loss = F.mse_loss(old_class_logits, old_class_old_logits)
-                mimic_loss *= (self._n_classes - self._task_size)
-                loss += mimic_loss
-                self._metrics["mimic"] += mimic_loss.item()
-
-            if self._ranking_loss:
-                ranking_loss = self._ranking_loss["factor"] * losses.ucir_ranking(
-                    logits,
-                    targets,
-                    self._n_classes,
-                    self._task_size,
-                    nb_negatives=self._ranking_loss["nb_negatives"],
-                    margin=self._ranking_loss["margin"]
-                )
-                loss += ranking_loss
-                self._metrics["rank"] += ranking_loss.item()
-
-            if self._relative_teachers_config:
-                if self._relative_teachers_config["select"] == "old":
-                    indexes_old = memory_flags.eq(1.)
-                    old_features_memory = old_features[indexes_old]
-                    new_features_memory = features[indexes_old]
-                else:
-                    old_features_memory = old_features
-                    new_features_memory = features
-
-                relative_t_loss = losses.relative_teacher_distances(
-                    old_features_memory, new_features_memory, **self._relative_teachers_config
-                )
-                loss += self._relative_teachers_config["factor"] * relative_t_loss
-                self._metrics["rel"] += relative_t_loss.item()

         if self._attention_residual_config:
             if self._attention_residual_config.get("scheduled_factor", False):
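
Among the distillation branches removed above, losses.relative_teacher_distances constrained pairwise feature distances to stay similar between the old and new networks. A minimal sketch of that idea (my paraphrase under an L2-distance assumption, not the repo's exact implementation):

    import torch

    def relative_teacher_distances(old_features, new_features):
        # Match the pairwise-distance structure of the teacher's feature space.
        old_d = torch.cdist(old_features, old_features, p=2)
        new_d = torch.cdist(new_features, new_features, p=2)
        return torch.mean(torch.abs(old_d - new_d))

    old = torch.randn(8, 64)  # hypothetical teacher (old model) features
    new = torch.randn(8, 64)  # hypothetical student (current model) features
    print(relative_teacher_distances(old, new))  # scalar loss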
