|
22 | 22 | from timm.layers import set_fast_norm |
23 | 23 | from timm.models import create_model, is_model, list_models |
24 | 24 | from timm.optim import create_optimizer_v2 |
25 | | -from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry, ParseKwargs |
| 25 | +from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry, ParseKwargs,\ |
| 26 | + reparameterize_model |
26 | 27 |
|
27 | 28 | has_apex = False |
28 | 29 | try: |
|
116 | 117 | help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')") |
117 | 118 | parser.add_argument('--fast-norm', default=False, action='store_true', |
118 | 119 | help='enable experimental fast-norm') |
| 120 | +parser.add_argument('--reparam', default=False, action='store_true', |
| 121 | + help='Reparameterize model') |
119 | 122 | parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs) |
120 | 123 |
|
121 | 124 | # codegen (model compilation) options |
@@ -222,6 +225,7 @@ def __init__( |
222 | 225 | torchscript=False, |
223 | 226 | torchcompile=None, |
224 | 227 | aot_autograd=False, |
| 228 | + reparam=False, |
225 | 229 | precision='float32', |
226 | 230 | fuser='', |
227 | 231 | num_warm_iter=10, |
@@ -252,10 +256,13 @@ def __init__( |
252 | 256 | drop_block_rate=kwargs.pop('drop_block', None), |
253 | 257 | **kwargs.pop('model_kwargs', {}), |
254 | 258 | ) |
| 259 | + if reparam: |
| 260 | + self.model = reparameterize_model(self.model) |
255 | 261 | self.model.to( |
256 | 262 | device=self.device, |
257 | 263 | dtype=self.model_dtype, |
258 | | - memory_format=torch.channels_last if self.channels_last else None) |
| 264 | + memory_format=torch.channels_last if self.channels_last else None, |
| 265 | + ) |
259 | 266 | self.num_classes = self.model.num_classes |
260 | 267 | self.param_count = count_params(self.model) |
261 | 268 | _logger.info('Model %s created, param count: %d' % (model_name, self.param_count)) |
|
0 commit comments