@@ -51,6 +51,12 @@
51 | 51 | FlopCountAnalysis = None |
52 | 52 | has_fvcore_profiling = False |
53 | 53 |
| 54 | +try: |
| 55 | + from functorch.compile import memory_efficient_fusion |
| 56 | + has_functorch = True |
| 57 | +except ImportError:
| 58 | + has_functorch = False |
| 59 | + |
54 | 60 |
55 | 61 | torch.backends.cudnn.benchmark = True |
56 | 62 | _logger = logging.getLogger('validate') |
@@ -95,10 +101,13 @@
95 | 101 | help='use PyTorch Native AMP for mixed precision training. Overrides --precision arg.') |
96 | 102 | parser.add_argument('--precision', default='float32', type=str, |
97 | 103 | help='Numeric precision. One of (amp, float32, float16, bfloat16, tf32)') |
98 | | -parser.add_argument('--torchscript', dest='torchscript', action='store_true', |
99 | | - help='convert model to torchscript for inference')
100 | 104 | parser.add_argument('--fuser', default='', type=str, |
101 | 105 | help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')") |
| 106 | +scripting_group = parser.add_mutually_exclusive_group() |
| 107 | +scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true', |
| 108 | + help='convert model to torchscript for inference')
| 109 | +scripting_group.add_argument('--aot-autograd', default=False, action='store_true', |
| 110 | + help="Enable AOT Autograd support (it's recommended to use this together with `--fuser nvfuser`)")
102 | 111 |
103 | 112 |
104 | 113 | # train optimizer parameters |
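
A note on the mutually exclusive group added above: argparse rejects any invocation that passes both `--torchscript` and `--aot-autograd`, so the two graph-compilation paths cannot be combined by accident. A minimal, standalone sketch of the same pattern (the prog name and flag subset here are illustrative, not from this PR):

```python
import argparse

parser = argparse.ArgumentParser(prog='benchmark')
scripting_group = parser.add_mutually_exclusive_group()
scripting_group.add_argument('--torchscript', action='store_true',
                             help='convert model to torchscript for inference')
scripting_group.add_argument('--aot-autograd', action='store_true',
                             help='enable AOT Autograd support')

# One flag at a time is fine:
print(parser.parse_args(['--aot-autograd']))  # Namespace(torchscript=False, aot_autograd=True)

# Passing both exits with:
#   error: argument --aot-autograd: not allowed with argument --torchscript
```
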
@@ -188,7 +197,7 @@ def profile_fvcore(model, input_size=(3, 224, 224), batch_size=1, detailed=False |
188 | 197 |
189 | 198 | class BenchmarkRunner: |
190 | 199 | def __init__( |
191 | | - self, model_name, detail=False, device='cuda', torchscript=False, precision='float32', |
| 200 | + self, model_name, detail=False, device='cuda', torchscript=False, aot_autograd=False, precision='float32', |
192 | 201 | fuser='', num_warm_iter=10, num_bench_iter=50, use_train_size=False, **kwargs): |
193 | 202 | self.model_name = model_name |
194 | 203 | self.detail = detail |
@@ -220,11 +229,14 @@ def __init__( |
220 | 229 | if torchscript: |
221 | 230 | self.model = torch.jit.script(self.model) |
222 | 231 | self.scripted = True |
223 | | - |
224 | 232 | data_config = resolve_data_config(kwargs, model=self.model, use_test_size=not use_train_size) |
225 | 233 | self.input_size = data_config['input_size'] |
226 | 234 | self.batch_size = kwargs.pop('batch_size', 256) |
227 | 235 |
| 236 | + if aot_autograd: |
| 237 | + assert has_functorch, "functorch is needed for --aot-autograd" |
| 238 | + self.model = memory_efficient_fusion(self.model) |
| 239 | + |
228 | 240 | self.example_inputs = None |
229 | 241 | self.num_warm_iter = num_warm_iter |
230 | 242 | self.num_bench_iter = num_bench_iter |
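
Taken together, the changes wire `--aot-autograd` through to functorch's `memory_efficient_fusion`, which wraps the model with AOT Autograd after the data config is resolved. A minimal sketch of the resulting behavior, assuming functorch is installed and a CUDA device is available (the toy model and shapes are placeholders, not from this PR):

```python
import torch
import torch.nn as nn

# Same guarded import pattern as the diff's first hunk.
try:
    from functorch.compile import memory_efficient_fusion
    has_functorch = True
except ImportError:
    has_functorch = False

model = nn.Sequential(nn.Linear(128, 256), nn.GELU(), nn.Linear(256, 128)).cuda()

if has_functorch:
    # AOT Autograd traces the forward and backward graphs once and hands
    # them to the active fuser (e.g. nvFuser when run with --fuser nvfuser).
    model = memory_efficient_fusion(model)

x = torch.randn(32, 128, device='cuda')
out = model(x)
out.sum().backward()  # the compiled backward graph runs here
```

On the command line this corresponds to something like `python benchmark.py --model resnet50 --aot-autograd --fuser nvfuser` (script name assumed from context); since the two flags share a mutually exclusive group, combining `--aot-autograd` with `--torchscript` is rejected.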