|
21 | 21 | from timm.models import create_model, is_model, list_models |
22 | 22 | from timm.optim import create_optimizer_v2 |
23 | 23 | from timm.data import resolve_data_config |
24 | | -from timm.utils import AverageMeter, setup_default_logging |
| 24 | +from timm.utils import setup_default_logging, set_jit_fuser |
25 | 25 |
|
26 | 26 |
|
27 | 27 | has_apex = False |
|
95 | 95 | help='Numeric precision. One of (amp, float32, float16, bfloat16, tf32)') |
96 | 96 | parser.add_argument('--torchscript', dest='torchscript', action='store_true', |
97 | 97 | help='convert model torchscript for inference') |
98 | | - |
| 98 | +parser.add_argument('--fuser', default='', type=str, |
| 99 | + help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')") |
99 | 100 |
|
100 | 101 |
|
101 | 102 | # train optimizer parameters |
@@ -186,14 +187,16 @@ def profile_fvcore(model, input_size=(3, 224, 224), batch_size=1, detailed=False |
186 | 187 | class BenchmarkRunner: |
187 | 188 | def __init__( |
188 | 189 | self, model_name, detail=False, device='cuda', torchscript=False, precision='float32', |
189 | | - num_warm_iter=10, num_bench_iter=50, use_train_size=False, **kwargs): |
| 190 | + fuser='', num_warm_iter=10, num_bench_iter=50, use_train_size=False, **kwargs): |
190 | 191 | self.model_name = model_name |
191 | 192 | self.detail = detail |
192 | 193 | self.device = device |
193 | 194 | self.use_amp, self.model_dtype, self.data_dtype = resolve_precision(precision) |
194 | 195 | self.channels_last = kwargs.pop('channels_last', False) |
195 | 196 | self.amp_autocast = torch.cuda.amp.autocast if self.use_amp else suppress |
196 | 197 |
|
| 198 | + if fuser: |
| 199 | + set_jit_fuser(fuser) |
197 | 200 | self.model = create_model( |
198 | 201 | model_name, |
199 | 202 | num_classes=kwargs.pop('num_classes', None), |
|
0 commit comments