@@ -19,8 +19,7 @@ class IncrementalLearner(abc.ABC):
1919 def __init__ (self , * args , ** kwargs ):
2020 pass
2121
22- def set_task_info (self , task , total_n_classes , increment , n_train_data , n_test_data ,
23- n_tasks ):
22+ def set_task_info (self , task , total_n_classes , increment , n_train_data , n_test_data , n_tasks ):
2423 self ._task = task
2524 self ._task_size = increment
2625 self ._total_n_classes = total_n_classes
@@ -60,7 +59,7 @@ def eval(self):
6059 def train (self ):
6160 raise NotImplementedError
6261
63- def _before_task (self , data_loader ):
62+ def _before_task (self , data_loader , val_loader ):
6463 pass
6564
6665 def _train_task (self , train_loader , val_loader ):
@@ -72,6 +71,12 @@ def _after_task(self, data_loader):
7271 def _eval_task (self , data_loader ):
7372 raise NotImplementedError
7473
74+ def save_metadata (self , path ):
75+ pass
76+
77+ def load_metadata (self , path ):
78+ pass
79+
7580 @property
7681 def _new_task_index (self ):
7782 return self ._task * self ._task_size
0 commit comments