Skip to content

Commit 41a0703

Browse files
[lib] Update factory and __init__s.
1 parent de5096a commit 41a0703

File tree

3 files changed

+28
-26
lines changed

3 files changed

+28
-26
lines changed

inclearn/lib/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# flake8: noqa
22
from . import (
3-
calibration, callbacks, data, factory, herding, losses, metrics, network, pooling,
4-
results_utils, schedulers, utils
3+
calibration, callbacks, data, factory, herding, loops, losses, metrics, network, pooling,
4+
results_utils, schedulers, utils, vizualization
55
)

inclearn/lib/factory.py

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ def get_optimizer(params, optimizer, lr, weight_decay=0.0):
1616
return optim.AdamW(params, lr=lr, weight_decay=weight_decay)
1717
elif optimizer == "sgd":
1818
return optim.SGD(params, lr=lr, weight_decay=weight_decay, momentum=0.9)
19+
elif optimizer == "sgd_nesterov":
20+
return optim.SGD(params, lr=lr, weight_decay=weight_decay, momentum=0.9, nesterov=True)
1921

2022
raise NotImplementedError
2123

@@ -44,30 +46,27 @@ def get_convnet(convnet_type, **kwargs):
4446

4547

4648
def get_model(args):
47-
if args["model"] == "icarl":
48-
return models.ICarl(args)
49-
elif args["model"] == "lwf":
50-
return models.LwF(args)
51-
elif args["model"] == "e2e":
52-
return models.End2End(args)
53-
elif args["model"] == "medic":
54-
return models.Medic(args)
55-
elif args["model"] == "focusforget":
56-
return models.FocusForget(args)
57-
elif args["model"] == "fixed":
58-
return models.FixedRepresentation(args)
59-
elif args["model"] == "bic":
60-
return models.BiC(args)
61-
elif args["model"] == "icarlmixup":
62-
return models.ICarlMixUp(args)
63-
elif args["model"] == "ucir":
64-
return models.UCIR(args)
65-
elif args["model"] == "test":
66-
return models.Test(args)
67-
elif args["model"] == "still":
68-
return models.STILL(args)
69-
70-
raise NotImplementedError("Unknown model {}.".format(args["model"]))
49+
dict_models = {
50+
"icarl": models.ICarl,
51+
#"lwf": models.LwF,
52+
"e2e": models.End2End,
53+
#"medic": models.Medic,
54+
#"fixed": models.FixedRepresentation,
55+
"oracle": None,
56+
"bic": models.BiC,
57+
"ucir": models.UCIR,
58+
"still": models.STILL,
59+
"lwm": models.LwM
60+
}
61+
62+
model = args["model"].lower()
63+
64+
if model not in dict_models:
65+
raise NotImplementedError(
66+
"Unknown model {}, must be among {}.".format(args["model"], list(dict_models.keys()))
67+
)
68+
69+
return dict_models[model](args)
7170

7271

7372
def get_data(args, class_order=None):

inclearn/models/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
# flake8: noqa
2+
from .base import IncrementalLearner
13
from .bic import BiC
24
from .e2e import End2End
35
from .icarl import ICarl
6+
from .lwm import LwM
47
from .still import STILL
58
from .test import Test
69
from .ucir import UCIR

0 commit comments

Comments
 (0)