
Commit db8e33c

Merge pull request #1294 from xwang233/add-aot-autograd
Add AOT Autograd support
2 parents e4360e6 + 2d7ab06

File tree

  benchmark.py
  train.py

2 files changed: +30 -5 lines changed

benchmark.py

Lines changed: 16 additions & 4 deletions
@@ -51,6 +51,12 @@
     FlopCountAnalysis = None
     has_fvcore_profiling = False
 
+try:
+    from functorch.compile import memory_efficient_fusion
+    has_functorch = True
+except ImportError as e:
+    has_functorch = False
+
 
 torch.backends.cudnn.benchmark = True
 _logger = logging.getLogger('validate')
@@ -95,10 +101,13 @@
                     help='use PyTorch Native AMP for mixed precision training. Overrides --precision arg.')
 parser.add_argument('--precision', default='float32', type=str,
                     help='Numeric precision. One of (amp, float32, float16, bfloat16, tf32)')
-parser.add_argument('--torchscript', dest='torchscript', action='store_true',
-                    help='convert model torchscript for inference')
 parser.add_argument('--fuser', default='', type=str,
                     help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
+scripting_group = parser.add_mutually_exclusive_group()
+scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
+                             help='convert model torchscript for inference')
+scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
+                             help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)")
 
 
 # train optimizer parameters
@@ -188,7 +197,7 @@ def profile_fvcore(model, input_size=(3, 224, 224), batch_size=1, detailed=False
 
 class BenchmarkRunner:
     def __init__(
-            self, model_name, detail=False, device='cuda', torchscript=False, precision='float32',
+            self, model_name, detail=False, device='cuda', torchscript=False, aot_autograd=False, precision='float32',
             fuser='', num_warm_iter=10, num_bench_iter=50, use_train_size=False, **kwargs):
         self.model_name = model_name
         self.detail = detail
@@ -220,11 +229,14 @@ def __init__(
         if torchscript:
             self.model = torch.jit.script(self.model)
             self.scripted = True
-
         data_config = resolve_data_config(kwargs, model=self.model, use_test_size=not use_train_size)
         self.input_size = data_config['input_size']
         self.batch_size = kwargs.pop('batch_size', 256)
 
+        if aot_autograd:
+            assert has_functorch, "functorch is needed for --aot-autograd"
+            self.model = memory_efficient_fusion(self.model)
+
         self.example_inputs = None
         self.num_warm_iter = num_warm_iter
         self.num_bench_iter = num_bench_iter
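
For reference, the wrapping added above is a single call: functorch's memory_efficient_fusion traces a module's forward and backward ahead of time (AOT Autograd) and hands the resulting graphs to the active fuser, which is why the help text recommends pairing the flag with `--fuser nvfuser`. A minimal standalone sketch of the same guarded-import-and-wrap pattern (the toy model, shapes, and CUDA device are illustrative, not part of this commit):

import torch

# Guarded import, as in the diff: functorch is treated as an optional dependency.
try:
    from functorch.compile import memory_efficient_fusion
    has_functorch = True
except ImportError:
    has_functorch = False

# Illustrative model and input; any nn.Module on a CUDA device would do.
model = torch.nn.Sequential(
    torch.nn.Linear(128, 256),
    torch.nn.GELU(),
    torch.nn.Linear(256, 128),
).cuda()

if has_functorch:
    # AOT Autograd traces forward and backward graphs ahead of time so the
    # selected fuser (e.g. nvfuser) can optimize both.
    model = memory_efficient_fusion(model)

x = torch.randn(32, 128, device='cuda')
model(x).sum().backward()

From the CLI, the new path would be exercised with something like `python benchmark.py --model resnet50 --aot-autograd --fuser nvfuser`; only the two flags placed in the mutually exclusive group are new in this commit.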

train.py

Lines changed: 14 additions & 1 deletion
@@ -61,6 +61,13 @@
 except ImportError:
     has_wandb = False
 
+try:
+    from functorch.compile import memory_efficient_fusion
+    has_functorch = True
+except ImportError as e:
+    has_functorch = False
+
+
 torch.backends.cudnn.benchmark = True
 _logger = logging.getLogger('train')
 
@@ -123,8 +130,11 @@
                    help='Validation batch size override (default: None)')
 group.add_argument('--channels-last', action='store_true', default=False,
                    help='Use channels_last memory layout')
-group.add_argument('--torchscript', dest='torchscript', action='store_true',
+scripting_group = group.add_mutually_exclusive_group()
+scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
                    help='torch.jit.script the full model')
+scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
+                             help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)")
 group.add_argument('--fuser', default='', type=str,
                    help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
 group.add_argument('--grad-checkpointing', action='store_true', default=False,
@@ -445,6 +455,9 @@ def main():
         assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
         assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
         model = torch.jit.script(model)
+    if args.aot_autograd:
+        assert has_functorch, "functorch is needed for --aot-autograd"
+        model = memory_efficient_fusion(model)
 
     optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))
 
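
The argparse change mirrors benchmark.py: --torchscript and --aot-autograd now live in a mutually exclusive group, since scripting and AOT compilation each wrap the model in a different way and should not be combined. A small self-contained sketch of how such a group behaves (the parser and the parse_args inputs here are illustrative):

import argparse

parser = argparse.ArgumentParser()
scripting_group = parser.add_mutually_exclusive_group()
scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
                             help='torch.jit.script the full model')
scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
                             help='Enable AOT Autograd support')

# Either flag alone parses fine ...
args = parser.parse_args(['--aot-autograd'])
assert args.aot_autograd and not args.torchscript

# ... but combining them is rejected by argparse with
#   "argument --aot-autograd: not allowed with argument --torchscript"
# parser.parse_args(['--torchscript', '--aot-autograd'])  # would raise SystemExit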
