
Commit 5242ba6

MobileOne and FastViT weights on HF hub, more code cleanup and tweaks, features_only working. Add reparam flag to validate and benchmark, support reparam of all models with fuse(), reparameterize(), or switch_to_deploy() methods on modules
1 parent 40dbaaf commit 5242ba6
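
The commit message describes a generic reparameterization pass: any submodule that exposes a fuse(), reparameterize(), or switch_to_deploy() method gets collapsed into its inference-time form. The helper used in the diff below is timm.utils.reparameterize_model; the snippet here is only a minimal sketch of that module-walking idea written against plain torch.nn. The function name, arguments, and comments are assumptions for illustration, not the library's actual code.

    import copy
    import torch.nn as nn

    def reparameterize_sketch(model: nn.Module, inplace: bool = False) -> nn.Module:
        # Work on a copy by default so the training-time model is left intact (assumed default).
        if not inplace:
            model = copy.deepcopy(model)

        def _convert(module: nn.Module):
            for name, child in module.named_children():
                if hasattr(child, 'fuse'):
                    # e.g. conv+bn blocks whose fuse() returns a new fused module
                    child = child.fuse()
                    setattr(module, name, child)
                elif hasattr(child, 'reparameterize'):
                    # e.g. MobileOne/FastViT-style blocks that collapse branches in place
                    child.reparameterize()
                elif hasattr(child, 'switch_to_deploy'):
                    # e.g. RepVGG-style blocks with a deploy-mode switch
                    child.switch_to_deploy()
                _convert(child)

        _convert(model)
        return model

Check timm/utils/model.py for the authoritative implementation; the actual helper may differ in how it handles modules that replace themselves.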

8 files changed: +447, -304 lines changed

benchmark.py

Lines changed: 9 additions & 2 deletions
@@ -22,7 +22,8 @@
 from timm.layers import set_fast_norm
 from timm.models import create_model, is_model, list_models
 from timm.optim import create_optimizer_v2
-from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry, ParseKwargs
+from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry, ParseKwargs,\
+    reparameterize_model
 
 has_apex = False
 try:
@@ -116,6 +117,8 @@
                     help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
 parser.add_argument('--fast-norm', default=False, action='store_true',
                     help='enable experimental fast-norm')
+parser.add_argument('--reparam', default=False, action='store_true',
+                    help='Reparameterize model')
 parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)
 
 # codegen (model compilation) options
@@ -222,6 +225,7 @@ def __init__(
             torchscript=False,
             torchcompile=None,
             aot_autograd=False,
+            reparam=False,
             precision='float32',
             fuser='',
             num_warm_iter=10,
@@ -252,10 +256,13 @@ def __init__
             drop_block_rate=kwargs.pop('drop_block', None),
             **kwargs.pop('model_kwargs', {}),
         )
+        if reparam:
+            self.model = reparameterize_model(self.model)
         self.model.to(
             device=self.device,
             dtype=self.model_dtype,
-            memory_format=torch.channels_last if self.channels_last else None)
+            memory_format=torch.channels_last if self.channels_last else None,
+        )
         self.num_classes = self.model.num_classes
         self.param_count = count_params(self.model)
         _logger.info('Model %s created, param count: %d' % (model_name, self.param_count))
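
With this change, a reparameterized benchmark run could be invoked roughly like this; the model name is just an example, and only --model and the new --reparam flag are taken from this diff:

    python benchmark.py --model mobileone_s1 --reparam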
