66Hacked together by Ross Wightman (https://github.com/rwightman)
77"""
88import argparse
9- import os
109import csv
1110import json
12- import time
1311import logging
14- import torch
15- import torch .nn as nn
16- import torch .nn .parallel
12+ import time
1713from collections import OrderedDict
1814from contextlib import suppress
1915from functools import partial
2016
17+ import torch
18+ import torch .nn as nn
19+ import torch .nn .parallel
20+
21+ from timm .data import resolve_data_config
2122from timm .models import create_model , is_model , list_models
2223from timm .optim import create_optimizer_v2
23- from timm .data import resolve_data_config
2424from timm .utils import setup_default_logging , set_jit_fuser
2525
26-
2726has_apex = False
2827try :
2928 from apex import amp
7170 help = "Benchmark mode. One of 'inference', 'train', 'both'. Defaults to 'both'" )
7271parser .add_argument ('--detail' , action = 'store_true' , default = False ,
7372 help = 'Provide train fwd/bwd/opt breakdown detail if True. Defaults to False' )
73+ parser .add_argument ('--no-retry' , action = 'store_true' , default = False ,
74+ help = 'Do not decay batch size and retry on error.' )
7475parser .add_argument ('--results-file' , default = '' , type = str , metavar = 'FILENAME' ,
7576 help = 'Output csv file for validation results (summary)' )
7677parser .add_argument ('--num-warm-iter' , default = 10 , type = int ,
@@ -169,10 +170,9 @@ def resolve_precision(precision: str):
169170
170171
171172def profile_deepspeed (model , input_size = (3 , 224 , 224 ), batch_size = 1 , detailed = False ):
172- macs , _ = get_model_profile (
173+ _ , macs , _ = get_model_profile (
173174 model = model ,
174- input_res = (batch_size ,) + input_size , # input shape or input to the input_constructor
175- input_constructor = None , # if specified, a constructor taking input_res is used as input to the model
175+ input_shape = (batch_size ,) + input_size , # input shape/resolution
176176 print_profile = detailed , # prints the model graph with the measured profile attached to each module
177177 detailed = detailed , # print the detailed profile
178178 warm_up = 10 , # the number of warm-ups before measuring the time of each module
@@ -197,8 +197,19 @@ def profile_fvcore(model, input_size=(3, 224, 224), batch_size=1, detailed=False
197197
198198class BenchmarkRunner :
199199 def __init__ (
200- self , model_name , detail = False , device = 'cuda' , torchscript = False , aot_autograd = False , precision = 'float32' ,
201- fuser = '' , num_warm_iter = 10 , num_bench_iter = 50 , use_train_size = False , ** kwargs ):
200+ self ,
201+ model_name ,
202+ detail = False ,
203+ device = 'cuda' ,
204+ torchscript = False ,
205+ aot_autograd = False ,
206+ precision = 'float32' ,
207+ fuser = '' ,
208+ num_warm_iter = 10 ,
209+ num_bench_iter = 50 ,
210+ use_train_size = False ,
211+ ** kwargs
212+ ):
202213 self .model_name = model_name
203214 self .detail = detail
204215 self .device = device
@@ -225,11 +236,12 @@ def __init__(
225236 self .num_classes = self .model .num_classes
226237 self .param_count = count_params (self .model )
227238 _logger .info ('Model %s created, param count: %d' % (model_name , self .param_count ))
239+
240+ data_config = resolve_data_config (kwargs , model = self .model , use_test_size = not use_train_size )
228241 self .scripted = False
229242 if torchscript :
230243 self .model = torch .jit .script (self .model )
231244 self .scripted = True
232- data_config = resolve_data_config (kwargs , model = self .model , use_test_size = not use_train_size )
233245 self .input_size = data_config ['input_size' ]
234246 self .batch_size = kwargs .pop ('batch_size' , 256 )
235247
@@ -255,7 +267,13 @@ def _init_input(self):
255267
256268class InferenceBenchmarkRunner (BenchmarkRunner ):
257269
258- def __init__ (self , model_name , device = 'cuda' , torchscript = False , ** kwargs ):
270+ def __init__ (
271+ self ,
272+ model_name ,
273+ device = 'cuda' ,
274+ torchscript = False ,
275+ ** kwargs
276+ ):
259277 super ().__init__ (model_name = model_name , device = device , torchscript = torchscript , ** kwargs )
260278 self .model .eval ()
261279
@@ -324,7 +342,13 @@ def _step():
324342
325343class TrainBenchmarkRunner (BenchmarkRunner ):
326344
327- def __init__ (self , model_name , device = 'cuda' , torchscript = False , ** kwargs ):
345+ def __init__ (
346+ self ,
347+ model_name ,
348+ device = 'cuda' ,
349+ torchscript = False ,
350+ ** kwargs
351+ ):
328352 super ().__init__ (model_name = model_name , device = device , torchscript = torchscript , ** kwargs )
329353 self .model .train ()
330354
@@ -491,7 +515,7 @@ def decay_batch_exp(batch_size, factor=0.5, divisor=16):
491515 return max (0 , int (out_batch_size ))
492516
493517
494- def _try_run (model_name , bench_fn , initial_batch_size , bench_kwargs ):
518+ def _try_run (model_name , bench_fn , bench_kwargs , initial_batch_size , no_batch_size_retry = False ):
495519 batch_size = initial_batch_size
496520 results = dict ()
497521 error_str = 'Unknown'
@@ -506,8 +530,11 @@ def _try_run(model_name, bench_fn, initial_batch_size, bench_kwargs):
506530 if 'channels_last' in error_str :
507531 _logger .error (f'{ model_name } not supported in channels_last, skipping.' )
508532 break
509- _logger .warning (f'"{ error_str } " while running benchmark. Reducing batch size to { batch_size } for retry.' )
533+ _logger .error (f'"{ error_str } " while running benchmark.' )
534+ if no_batch_size_retry :
535+ break
510536 batch_size = decay_batch_exp (batch_size )
537+ _logger .warning (f'Reducing batch size to { batch_size } for retry.' )
511538 results ['error' ] = error_str
512539 return results
513540
@@ -549,7 +576,13 @@ def benchmark(args):
549576
550577 model_results = OrderedDict (model = model )
551578 for prefix , bench_fn in zip (prefixes , bench_fns ):
552- run_results = _try_run (model , bench_fn , initial_batch_size = batch_size , bench_kwargs = bench_kwargs )
579+ run_results = _try_run (
580+ model ,
581+ bench_fn ,
582+ bench_kwargs = bench_kwargs ,
583+ initial_batch_size = batch_size ,
584+ no_batch_size_retry = args .no_retry ,
585+ )
553586 if prefix and 'error' not in run_results :
554587 run_results = {'_' .join ([prefix , k ]): v for k , v in run_results .items ()}
555588 model_results .update (run_results )
0 commit comments