File tree Expand file tree Collapse file tree 1 file changed +4
-4
lines changed
Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Original file line number Diff line number Diff line change @@ -229,14 +229,14 @@ def __init__(
229229 if torchscript :
230230 self .model = torch .jit .script (self .model )
231231 self .scripted = True
232- if aot_autograd :
233- assert has_functorch , "functorch is needed for --aot-autograd"
234- self .model = memory_efficient_fusion (self .model )
235-
236232 data_config = resolve_data_config (kwargs , model = self .model , use_test_size = not use_train_size )
237233 self .input_size = data_config ['input_size' ]
238234 self .batch_size = kwargs .pop ('batch_size' , 256 )
239235
236+ if aot_autograd :
237+ assert has_functorch , "functorch is needed for --aot-autograd"
238+ self .model = memory_efficient_fusion (self .model )
239+
240240 self .example_inputs = None
241241 self .num_warm_iter = num_warm_iter
242242 self .num_bench_iter = num_bench_iter
You can’t perform that action at this time.
0 commit comments