@@ -39,28 +39,21 @@ def __init__(self, args):
         self._herding_selection = args.get("herding_selection", "icarl")
         self._n_classes = 0
 
-        self._use_mimic_score = args.get("mimic_score", False)
         self._less_forget_config = args.get("less_forget", {})
         assert isinstance(self._less_forget_config, dict)
 
         self._lambda_schedule = args.get("lambda_schedule", False)
-        self._ranking_loss = args.get("ranking_loss", {})
-
-        self._relative_teachers_config = args.get("relative_teachers", {})
 
         self._gor_config = args.get("gor_config", {})
 
         self._ams_config = args.get("adaptative_margin_softmax", {})
         self._softmax_ce = args.get("softmax_ce", False)
 
         self._attention_residual_config = args.get("attention_residual", {})
-        assert isinstance(self._attention_residual_config, dict), "ra need to be dict"
 
         self._perceptual_features = args.get("perceptual_features")
         self._perceptual_style = args.get("perceptual_style")
 
-        self._use_teacher_confidence = args.get("teacher_confidence", False)
-
         self._groupwise_factors = args.get("groupwise_factors", {})
         self._softtriple_config = args.get("softriple_regularizer", {})
 
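Every optional feature above is toggled by reading `args` with a falsy default, so an absent key disables the feature. A minimal sketch of the kind of config dict this constructor consumes, using only key names that appear in this diff; the values are illustrative assumptions, not the project's defaults:

```python
# Illustrative config for the constructor above. Key names come from the
# diff; every value here is an assumption chosen for the example.
args = {
    "herding_selection": "icarl",        # exemplar-selection strategy
    "less_forget": {"lambda": 5.0},      # hypothetical inner field
    "lambda_schedule": True,
    "adaptative_margin_softmax": {},     # empty dict == feature disabled
    "softmax_ce": True,
}

# The .get(key, {}) pattern makes each feature opt-in:
less_forget_config = args.get("less_forget", {})
if less_forget_config:                   # truthy only when configured
    print("less-forget enabled:", less_forget_config)
```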
@@ -78,24 +71,15 @@ def __init__(self, args):
         self._evaluation_type = args.get("eval_type", "icarl")
         self._evaluation_config = args.get("evaluation_config", {})
 
-        self._weights_orthogonality = args.get("weights_orthogonality")
-        self._orthoreg_config = args.get("orthoreg_config", {})
-        self._dso_config = args.get("dso_config", {})
-        self._mc_config = args.get("mc_config", {})
-        self._srip_config = args.get("srip_config", {})
         self._double_margin_reg = args.get("double_margin_reg", {})
 
         self._save_model = args["save_model"]
 
-        self._harmonic_embeddings = args.get("harmonic_embeddings", {})
-
         self._rotations_config = args.get("rotations_config", {})
 
         self._eval_every_x_epochs = args.get("eval_every_x_epochs")
         self._early_stopping = args.get("early_stopping", {})
 
-        self._random_noise_config = args.get("random_noise_config", {})
-
         classifier_kwargs = args.get("classifier_config", {})
         self._network = network.BasicNet(
             args["convnet"],
@@ -127,12 +111,6 @@ def __init__(self, args):
         self._herding_compressed_indexes = []
 
         self._weight_generation = args.get("weight_generation")
-        self._compressed_memory = args.get("compressed_memory")
-        self._alternate_training_config = args.get("alternate_training")
-
-        self._compressed_data = {}
-        self._compressed_targets = {}
-        self._compressed_means = []
 
         self._saved_network = None
         self._post_processing_type = None
@@ -143,9 +121,7 @@ def __init__(self, args):
     @property
     def _memory_per_class(self):
         """Returns the number of examplars per class."""
-        if self._compressed_memory:
-            return self._compressed_memory["quantity_images"]
-        elif self._fixed_memory:
+        if self._fixed_memory:
             return self._memory_size // self._total_n_classes
         return self._memory_size // self._n_classes
 
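The surviving property has two budgeting modes, which a worked example makes concrete; the 2,000-exemplar budget and class counts below are assumptions for illustration:

```python
# Worked example of _memory_per_class after this change.
memory_size = 2000
total_n_classes = 100

# fixed_memory=True: constant quota from the very first task.
print(memory_size // total_n_classes)                # 20 per class, always

# fixed_memory=False: the quota shrinks as classes accumulate.
for n_classes in (10, 50, 100):
    print(n_classes, "->", memory_size // n_classes)  # 200, 40, 20
```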
@@ -154,9 +130,6 @@ def _train_task(self, train_loader, val_loader):
             if p.requires_grad:
                 p.register_hook(lambda grad: torch.clamp(grad, -5., 5.))
 
-        if self._alternate_training_config and self._task != 0:
-            return self._alternate_training(train_loader, val_loader)
-
         logger.debug("nb {}.".format(len(train_loader.dataset)))
         self._training_step(train_loader, val_loader, 0, self._n_epochs)
 
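The hook kept above clamps every parameter gradient into [-5, 5] as autograd produces it. A self-contained sketch of the same mechanism:

```python
# Standalone sketch of the gradient clamping kept in _train_task: a backward
# hook rewrites each parameter's gradient the moment it is computed.
import torch

model = torch.nn.Linear(4, 2)
for p in model.parameters():
    if p.requires_grad:
        p.register_hook(lambda grad: torch.clamp(grad, -5., 5.))

loss = 1e3 * model(torch.randn(8, 4)).pow(2).sum()   # deliberately large loss
loss.backward()
assert all(p.grad.abs().max() <= 5.0 for p in model.parameters())
```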
@@ -224,93 +197,10 @@ def weight_decay(self):
             )
         )
 
-    def _alternate_training(self, train_loader, val_loader):
-        for phase in self._alternate_training_config:
-            if phase["update_theta"]:
-                logger.info("Updating theta")
-                for class_index in range(self._n_classes - self._task_size, self._n_classes):
-                    _, loader = self.inc_dataset.get_custom_loader([class_index])
-                    features, _ = utils.extract_features(self._network, loader)
-                    features = F.normalize(torch.from_numpy(features), p=2, dim=1)
-                    mean = torch.mean(features, dim=0)
-                    mean = F.normalize(mean, dim=0, p=2)
-
-                    self._network.classifier.weights.data[class_index] = mean.to(self._device)
-
-            self._network.freeze(trainable=phase["train_f"], model="convnet")
-            self._network.freeze(trainable=phase["train_theta"], model="classifier")
-            logger.info("Freeze convnet=" + str(phase["train_f"]))
-            logger.info("Freeze classifier=" + str(phase["train_theta"]))
-
-            self._optimizer = factory.get_optimizer(
-                self._network.parameters(), self._opt_name, self._lr, self._weight_decay
-            )
-            self._training_step(train_loader, val_loader, 0, phase["nb_epochs"])
-
     def _after_task(self, inc_dataset):
         self._monitor_scale()
         super()._after_task(inc_dataset)
 
-        if self._compressed_memory:
-            self.add_compressed_memory()
-
-    def add_compressed_memory(self):
-        _, _, self._herding_compressed_indexes, _ = self.build_examplars(
-            self.inc_dataset, self._herding_compressed_indexes, self.quantity_compressed_embeddings
-        )
-
-        # Computing the embeddings of only the current task images:
-        for class_index in range(self._n_classes - self._task_size, self._n_classes):
-            _, loader = self.inc_dataset.get_custom_loader([class_index])
-            features, targets = utils.extract_features(self._network, loader)
-
-            selected_features = features[self._herding_compressed_indexes[class_index]]
-            selected_targets = targets[self._herding_compressed_indexes[class_index]]
-
-            self._compressed_means.append(np.mean(selected_features, axis=0))
-
-            self._compressed_data[class_index] = selected_features
-            self._compressed_targets[class_index] = selected_targets
-
-        logger.info(
-            "{} compressed memory, or {} per class.".format(
-                sum(len(x) for x in self._compressed_data.values()),
-                self.quantity_compressed_embeddings
-            )
-        )
-
-        # Taking in account the mean shift of the class:
-        if self._compressed_memory["mean_shift"]:
-            logger.info("Computing mean shift")
-            for class_index in range(self._n_classes - self._task_size):
-                class_memory, class_targets = utils.select_class_samples(
-                    self._data_memory, self._targets_memory, class_index
-                )
-
-                _, loader = self.inc_dataset.get_custom_loader(
-                    [], memory=((class_memory, class_targets))
-                )
-                features, _ = utils.extract_features(self._network, loader)
-                features_mean = np.mean(features, axis=0)
-
-                diff_mean = features_mean - self._compressed_means[class_index]
-
-                self._compressed_data[class_index] += diff_mean
-
-        for class_index in range(self._n_classes):
-            indexes = np.random.permutation(self.quantity_compressed_embeddings)
-            self._compressed_data[class_index] = self._compressed_data[class_index][indexes]
-
-    @property
-    def quantity_compressed_embeddings(self):
-        assert self._compressed_memory
-
-        embed_size = 64 * 16
-        image_size = 32 * 32 * 3 * 8
-        total_mem = image_size * 20
-
-        return (total_mem - image_size * self._compressed_memory["quantity_images"]) // embed_size
-
     def _monitor_scale(self):
         if "scale" not in self._args["_logs"]:
             self._args["_logs"]["scale"] = []
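Two of the deleted pieces are worth unpacking. The `quantity_compressed_embeddings` budget counts in bits: one 32x32x3 image at 8 bits per channel costs 32*32*3*8 = 24,576 bits, one 64-d 16-bit embedding costs 64*16 = 1,024 bits, and the total budget is 20 images per class, so the property returns 480 - 24 * quantity_images embeddings per class. The deleted `_alternate_training` also imprinted new classifier weights with normalized class means; a sketch of that step, with assumed shapes:

```python
# Sketch of the weight imprinting done by the deleted _alternate_training:
# a new class's weight vector becomes the L2-normalized mean of its
# L2-normalized features. The 64-d features / 10-class head are assumptions.
import torch
import torch.nn.functional as F

features = torch.randn(500, 64)                        # extracted features
features = F.normalize(features, p=2, dim=1)           # unit-norm each row
mean = F.normalize(features.mean(dim=0), p=2, dim=0)   # unit-norm class mean

classifier_weights = torch.zeros(10, 64)               # hypothetical head
classifier_weights[3] = mean                           # imprint class 3
```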
@@ -454,49 +344,16 @@ def _before_task(self, train_loader, val_loader):
             task=self._task
         )
 
-        if self._compressed_memory:
-            self._compressed_iterator = 0
-            self._compressed_step = self.quantity_compressed_embeddings // len(train_loader)
-
         if self._class_weights_config:
             self._class_weights = torch.tensor(
                 data.get_class_weights(train_loader.dataset, **self._class_weights_config)
             ).to(self._device)
         else:
             self._class_weights = None
 
-    def _sample_compressed(self):
-        features, logits, targets = [], [], []
-
-        low_index = self._compressed_iterator * self._compressed_step
-        self._compressed_iterator += 1
-        high_index = self._compressed_iterator * self._compressed_step
-
-        for class_index in self._compressed_data.keys():
-            f = self._compressed_data[class_index][low_index:high_index]
-            t = self._compressed_targets[class_index][low_index:high_index]
-
-            f = torch.tensor(f).to(self._device)
-            t = torch.tensor(t).to(self._device)
-
-            logits.append(self._network.classifier(f))
-            features.append(f)
-            targets.append(t)
-
-        return torch.cat(features), torch.cat(logits), torch.cat(targets)
-
     def _compute_loss(self, inputs, features_logits, targets, onehot_targets, memory_flags):
         features, logits, atts = features_logits
 
-        if self._random_noise_config:
-            logits = logits[:-self._random_noise_config["nb_per_batch"]]
-
-        if self._compressed_memory and len(self._compressed_data) > 0:
-            c_f, c_l, c_t = self._sample_compressed()
-            features = torch.cat((features, c_f))
-            logits = torch.cat((logits, c_l))
-            targets = torch.cat((targets, c_t))
-
         if self._post_processing_type is None:
             scaled_logits = self._network.post_process(logits)
         else:
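The deleted `_sample_compressed` replayed stored embeddings window by window: each training batch consumed the next `_compressed_step` rows of every class's embedding bank. A sketch of that windowing, with assumed sizes:

```python
# Sketch of the windowed replay from the deleted _sample_compressed; the
# bank sizes and batch count are assumptions for illustration.
import numpy as np

compressed_data = {c: np.random.randn(480, 64) for c in range(5)}  # 5 classes
n_batches = 120
step = 480 // n_batches        # rows consumed per class per batch

iterator = 0
low = iterator * step
iterator += 1
high = iterator * step

batch = np.concatenate([bank[low:high] for bank in compressed_data.values()])
print(batch.shape)             # (5 * step, 64) embeddings replayed this batch
```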
@@ -506,13 +363,6 @@ def _compute_loss(self, inputs, features_logits, targets, onehot_targets, memory_flags):
             with torch.no_grad():
                 old_features, old_logits, old_atts = self._old_model(inputs)
 
-                if self._compressed_memory and len(self._compressed_data) > 0:
-                    old_features = torch.cat((old_features, c_f))
-                    old_logits = torch.cat((old_logits, self._old_model.classifier(c_f)))
-
-                if self._random_noise_config:
-                    old_logits = old_logits[:-self._random_noise_config["nb_per_batch"]]
-
         if self._ams_config:
             ams_config = copy.deepcopy(self._ams_config)
             if self._network.post_processor:
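The previous-task model kept above is only ever queried under `torch.no_grad()`, so its outputs act as fixed distillation targets rather than trainable tensors. A minimal sketch, with a stand-in teacher:

```python
# Minimal sketch of the frozen-teacher forward pass kept in this hunk; the
# Linear teacher is a stand-in assumption for the real old network.
import torch

old_model = torch.nn.Linear(16, 8).eval()   # frozen previous-task model
inputs = torch.randn(4, 16)

with torch.no_grad():                       # no autograd graph for the teacher
    old_logits = old_model(inputs)          # fixed targets for distillation

assert not old_logits.requires_grad
```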
@@ -526,90 +376,21 @@ def _compute_loss(self, inputs, features_logits, targets, onehot_targets, memory
                 **ams_config
             )
             self._metrics["ams"] += loss.item()
-        elif self._use_npair:
-            loss = losses.n_pair_loss(logits, targets)
-            self._metrics["npair"] += loss.item()
-        elif self._proxy_nca_config:
-            if self._network.post_processor:
-                self._proxy_nca_config["s"] = self._network.post_processor.factor
-
-            loss = losses.proxy_nca_github(
-                scaled_logits, targets, self._n_classes, **self._proxy_nca_config
-            )
-            self._metrics["nca"] += loss.item()
-        elif self._triplet_config:
-            loss, percent_violated = losses.triplet_loss(
-                features,
-                targets,
-                **self._triplet_config,
-                harmonic_embeddings=self._harmonic_embeddings,
-                old_features=old_features if self._old_model else None,
-                memory_flags=memory_flags,
-                epoch_percent=self._epoch_percent
-            )
-
-            self._metrics["tri"] += loss.item()
-            self._metrics["violated"] += percent_violated
         elif self._softmax_ce:
             loss = F.cross_entropy(scaled_logits, targets)
             self._metrics["cce"] += loss.item()
-        else:
-            if self._use_teacher_confidence and self._old_model is not None:
-                loss = losses.cross_entropy_teacher_confidence(
-                    scaled_logits, targets, F.softmax(old_logits, dim=1), memory_flags
-                )
-                self._metrics["clf_conf"] += loss.item()
-            else:
-                loss = F.cross_entropy(scaled_logits, targets)
-                self._metrics["clf"] += loss.item()
 
         # ----------------------
         # Regularization losses:
         # ----------------------
 
-        if self._weights_orthogonality is not None:
-            margin = self._weights_orthogonality.get("margin")
-            ortho_loss = losses.weights_orthogonality(
-                self._network.classifier.weights, margin=margin
-            )
-            loss += ortho_loss
-            self._metrics["ortho"] += ortho_loss.item()
-
         if self._gor_config:
             gor_loss = losses.global_orthogonal_regularization(
                 features, targets, self._n_classes - self._task_size, **self._gor_config
             )
             self._metrics["gor"] += gor_loss.item()
             loss += gor_loss
 
-        if self._orthoreg_config:
-            orthoreg_loss = losses.ortho_reg(
-                self._network.classifier.weights, self._orthoreg_config
-            )
-            self._metrics["orthoreg"] += orthoreg_loss.item()
-            loss += orthoreg_loss
-
-        if self._dso_config:
-            dso_loss = losses.double_soft_orthoreg(
-                self._network.classifier.weights, self._dso_config
-            )
-            self._metrics["dso"] += dso_loss.item()
-            loss += dso_loss
-
-        if self._mc_config:
-            mc_loss = losses.mutual_coherence_regularization(
-                self._network.classifier.weights, self._mc_config
-            )
-            self._metrics["mc"] += mc_loss.item()
-            loss += mc_loss
-
-        if self._srip_config:
-            srip_loss = losses.spectral_restricted_isometry_property_regularization(
-                self._network.classifier.weights, self._srip_config
-            )
-            self._metrics["srip"] += srip_loss.item()
-            loss += srip_loss
-
         if self._softtriple_config:
             st_reg = losses.softriple_regularizer(
                 self._network.classifier.weights, self._softtriple_config
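After this cleanup the classification loss is either the adaptive-margin softmax or plain cross-entropy on `scaled_logits`. A sketch of the surviving `softmax_ce` branch, assuming (as an illustration, not a claim about the repo) that the post-processor is a learned multiplicative factor on the logits:

```python
# Sketch of the surviving softmax_ce branch; treating post_process as a
# learned multiplicative factor is an assumption for this example.
import torch
import torch.nn.functional as F

logits = torch.randn(8, 10, requires_grad=True)    # raw classifier outputs
targets = torch.randint(0, 10, (8,))
factor = torch.nn.Parameter(torch.tensor(4.0))     # stand-in post-processor

scaled_logits = factor * logits                    # sharpen before softmax
loss = F.cross_entropy(scaled_logits, targets)
loss.backward()                                    # the factor is trained too
```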
@@ -644,41 +425,6 @@ def _compute_loss(self, inputs, features_logits, targets, onehot_targets, memory_flags):
             distil_loss = factor * losses.embeddings_similarity(old_features, features)
             loss += distil_loss
             self._metrics["lf"] += distil_loss.item()
-        elif self._use_mimic_score:
-            old_class_logits = logits[..., :self._n_classes - self._task_size]
-            old_class_old_logits = old_logits[..., :self._n_classes - self._task_size]
-
-            mimic_loss = F.mse_loss(old_class_logits, old_class_old_logits)
-            mimic_loss *= (self._n_classes - self._task_size)
-            loss += mimic_loss
-            self._metrics["mimic"] += mimic_loss.item()
-
-        if self._ranking_loss:
-            ranking_loss = self._ranking_loss["factor"] * losses.ucir_ranking(
-                logits,
-                targets,
-                self._n_classes,
-                self._task_size,
-                nb_negatives=self._ranking_loss["nb_negatives"],
-                margin=self._ranking_loss["margin"]
-            )
-            loss += ranking_loss
-            self._metrics["rank"] += ranking_loss.item()
-
-        if self._relative_teachers_config:
-            if self._relative_teachers_config["select"] == "old":
-                indexes_old = memory_flags.eq(1.)
-                old_features_memory = old_features[indexes_old]
-                new_features_memory = features[indexes_old]
-            else:
-                old_features_memory = old_features
-                new_features_memory = features
-
-            relative_t_loss = losses.relative_teacher_distances(
-                old_features_memory, new_features_memory, **self._relative_teachers_config
-            )
-            loss += self._relative_teachers_config["factor"] * relative_t_loss
-            self._metrics["rel"] += relative_t_loss.item()
 
         if self._attention_residual_config:
             if self._attention_residual_config.get("scheduled_factor", False):
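The less-forget branch kept above scales `losses.embeddings_similarity` by a (possibly scheduled) factor. A sketch assuming, as in UCIR-style less-forget, that the similarity term is the mean 1 - cosine distance between old and new embeddings; that definition is an assumption about the losses module:

```python
# Sketch of the kept less-forget term; defining embeddings_similarity as a
# mean (1 - cosine) distance is an assumption for this example.
import torch
import torch.nn.functional as F

def embeddings_similarity(old_features, features):
    """Mean cosine distance between teacher and student embeddings."""
    return (1 - F.cosine_similarity(old_features, features, dim=1)).mean()

old_features = torch.randn(8, 64)                 # frozen teacher embeddings
features = torch.randn(8, 64, requires_grad=True)
factor = 5.0                                      # lambda, maybe scheduled

distil_loss = factor * embeddings_similarity(old_features, features)
distil_loss.backward()
```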