@@ -16,6 +16,8 @@ def get_optimizer(params, optimizer, lr, weight_decay=0.0):
1616 return optim .AdamW (params , lr = lr , weight_decay = weight_decay )
1717 elif optimizer == "sgd" :
1818 return optim .SGD (params , lr = lr , weight_decay = weight_decay , momentum = 0.9 )
19+ elif optimizer == "sgd_nesterov" :
20+ return optim .SGD (params , lr = lr , weight_decay = weight_decay , momentum = 0.9 , nesterov = True )
1921
2022 raise NotImplementedError
2123
@@ -44,30 +46,27 @@ def get_convnet(convnet_type, **kwargs):
4446
4547
4648def get_model (args ):
47- if args ["model" ] == "icarl" :
48- return models .ICarl (args )
49- elif args ["model" ] == "lwf" :
50- return models .LwF (args )
51- elif args ["model" ] == "e2e" :
52- return models .End2End (args )
53- elif args ["model" ] == "medic" :
54- return models .Medic (args )
55- elif args ["model" ] == "focusforget" :
56- return models .FocusForget (args )
57- elif args ["model" ] == "fixed" :
58- return models .FixedRepresentation (args )
59- elif args ["model" ] == "bic" :
60- return models .BiC (args )
61- elif args ["model" ] == "icarlmixup" :
62- return models .ICarlMixUp (args )
63- elif args ["model" ] == "ucir" :
64- return models .UCIR (args )
65- elif args ["model" ] == "test" :
66- return models .Test (args )
67- elif args ["model" ] == "still" :
68- return models .STILL (args )
69-
70- raise NotImplementedError ("Unknown model {}." .format (args ["model" ]))
49+ dict_models = {
50+ "icarl" : models .ICarl ,
51+ #"lwf": models.LwF,
52+ "e2e" : models .End2End ,
53+ #"medic": models.Medic,
54+ #"fixed": models.FixedRepresentation,
55+ "oracle" : None ,
56+ "bic" : models .BiC ,
57+ "ucir" : models .UCIR ,
58+ "still" : models .STILL ,
59+ "lwm" : models .LwM
60+ }
61+
62+ model = args ["model" ].lower ()
63+
64+ if model not in dict_models :
65+ raise NotImplementedError (
66+ "Unknown model {}, must be among {}." .format (args ["model" ], list (dict_models .keys ()))
67+ )
68+
69+ return dict_models [model ](args )
7170
7271
7372def get_data (args , class_order = None ):
0 commit comments