Skip to content

Commit 500c190

Browse files
committed
Add --aot-autograd (functorch efficient mem fusion) support to validate.py
1 parent 28e0152 commit 500c190

File tree

1 file changed

+15
-3
lines changed

1 file changed

+15
-3
lines changed

validate.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,12 @@
3838
except AttributeError:
3939
pass
4040

41+
try:
42+
from functorch.compile import memory_efficient_fusion
43+
has_functorch = True
44+
except ImportError as e:
45+
has_functorch = False
46+
4147
torch.backends.cudnn.benchmark = True
4248
_logger = logging.getLogger('validate')
4349

@@ -101,8 +107,11 @@
101107
help='Use Tensorflow preprocessing pipeline (require CPU TF installed')
102108
parser.add_argument('--use-ema', dest='use_ema', action='store_true',
103109
help='use ema version of weights if present')
104-
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
105-
help='convert model torchscript for inference')
110+
scripting_group = parser.add_mutually_exclusive_group()
111+
scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
112+
help='torch.jit.script the full model')
113+
scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
114+
help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)")
106115
parser.add_argument('--fuser', default='', type=str,
107116
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
108117
parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
@@ -162,7 +171,10 @@ def validate(args):
162171

163172
if args.torchscript:
164173
torch.jit.optimized_execution(True)
165-
model = torch.jit.script(model)
174+
model = torch.jit.trace(model, example_inputs=torch.randn((args.batch_size,) + data_config['input_size']))
175+
if args.aot_autograd:
176+
assert has_functorch, "functorch is needed for --aot-autograd"
177+
model = memory_efficient_fusion(model)
166178

167179
model = model.cuda()
168180
if args.apex_amp:

0 commit comments

Comments
 (0)