|
38 | 38 | except AttributeError: |
39 | 39 | pass |
40 | 40 |
|
| 41 | +try: |
| 42 | + from functorch.compile import memory_efficient_fusion |
| 43 | + has_functorch = True |
| 44 | +except ImportError:
| 45 | + has_functorch = False |
| 46 | + |
41 | 47 | torch.backends.cudnn.benchmark = True |
42 | 48 | _logger = logging.getLogger('validate') |
43 | 49 |
|
|
101 | 107 | help='Use Tensorflow preprocessing pipeline (requires CPU TF installed)')
102 | 108 | parser.add_argument('--use-ema', dest='use_ema', action='store_true', |
103 | 109 | help='use ema version of weights if present') |
104 | | -parser.add_argument('--torchscript', dest='torchscript', action='store_true', |
105 | | - help='convert model torchscript for inference') |
| 110 | +scripting_group = parser.add_mutually_exclusive_group() |
| 111 | +scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true', |
| 112 | + help='torch.jit.trace the model for inference')
| 113 | +scripting_group.add_argument('--aot-autograd', default=False, action='store_true', |
| 114 | + help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)") |
106 | 115 | parser.add_argument('--fuser', default='', type=str, |
107 | 116 | help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')") |
108 | 117 | parser.add_argument('--results-file', default='', type=str, metavar='FILENAME', |
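
Note: the mutually exclusive group added above makes argparse reject a command line that passes both `--torchscript` and `--aot-autograd`. A minimal standalone sketch of that behaviour (the flag names mirror the diff; everything else is illustrative):

    import argparse

    parser = argparse.ArgumentParser()
    scripting_group = parser.add_mutually_exclusive_group()
    scripting_group.add_argument('--torchscript', action='store_true',
                                 help='torch.jit.trace the model for inference')
    scripting_group.add_argument('--aot-autograd', action='store_true',
                                 help='Enable AOT Autograd support')

    args = parser.parse_args(['--aot-autograd'])                # accepted
    # parser.parse_args(['--torchscript', '--aot-autograd'])    # argparse exits with
    #                                                           # "not allowed with argument"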
@@ -162,7 +171,10 @@ def validate(args): |
162 | 171 |
|
163 | 172 | if args.torchscript: |
164 | 173 | torch.jit.optimized_execution(True) |
165 | | - model = torch.jit.script(model) |
| 174 | + model = torch.jit.trace(model, example_inputs=torch.randn((args.batch_size,) + data_config['input_size'])) |
| 175 | + if args.aot_autograd: |
| 176 | + assert has_functorch, "functorch is needed for --aot-autograd" |
| 177 | + model = memory_efficient_fusion(model) |
166 | 178 |
|
167 | 179 | model = model.cuda() |
168 | 180 | if args.apex_amp: |
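
Note: for context, a minimal sketch of what `memory_efficient_fusion` does when `--aot-autograd` is set, assuming functorch and a CUDA device are available. The toy module and shapes are illustrative only, not part of the diff:

    import torch
    from functorch.compile import memory_efficient_fusion

    # Toy stand-in for the timm model; memory_efficient_fusion wraps it so AOT Autograd
    # can capture the graph and hand it to a fuser (e.g. nvFuser via --fuser nvfuser).
    model = torch.nn.Sequential(torch.nn.Linear(128, 128), torch.nn.GELU()).cuda().eval()
    fused_model = memory_efficient_fusion(model)

    with torch.no_grad():
        out = fused_model(torch.randn(8, 128, device='cuda'))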
|