Skip to content

Commit f0f9ecc

Browse files
committed
Add --fuser arg to train/validate/benchmark scripts to select jit fuser type
1 parent 010b486 commit f0f9ecc

File tree

5 files changed

+49
-9
lines changed

5 files changed

+49
-9
lines changed

benchmark.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from timm.models import create_model, is_model, list_models
2222
from timm.optim import create_optimizer_v2
2323
from timm.data import resolve_data_config
24-
from timm.utils import AverageMeter, setup_default_logging
24+
from timm.utils import setup_default_logging, set_jit_fuser
2525

2626

2727
has_apex = False
@@ -95,7 +95,8 @@
9595
help='Numeric precision. One of (amp, float32, float16, bfloat16, tf32)')
9696
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
9797
help='convert model torchscript for inference')
98-
98+
parser.add_argument('--fuser', default='', type=str,
99+
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
99100

100101

101102
# train optimizer parameters
@@ -186,14 +187,16 @@ def profile_fvcore(model, input_size=(3, 224, 224), batch_size=1, detailed=False
186187
class BenchmarkRunner:
187188
def __init__(
188189
self, model_name, detail=False, device='cuda', torchscript=False, precision='float32',
189-
num_warm_iter=10, num_bench_iter=50, use_train_size=False, **kwargs):
190+
fuser='', num_warm_iter=10, num_bench_iter=50, use_train_size=False, **kwargs):
190191
self.model_name = model_name
191192
self.detail = detail
192193
self.device = device
193194
self.use_amp, self.model_dtype, self.data_dtype = resolve_precision(precision)
194195
self.channels_last = kwargs.pop('channels_last', False)
195196
self.amp_autocast = torch.cuda.amp.autocast if self.use_amp else suppress
196197

198+
if fuser:
199+
set_jit_fuser(fuser)
197200
self.model = create_model(
198201
model_name,
199202
num_classes=kwargs.pop('num_classes', None),

timm/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from .clip_grad import dispatch_clip_grad
44
from .cuda import ApexScaler, NativeScaler
55
from .distributed import distribute_bn, reduce_tensor
6-
from .jit import set_jit_legacy
6+
from .jit import set_jit_legacy, set_jit_fuser
77
from .log import setup_default_logging, FormatterNoInfo
88
from .metrics import AverageMeter, accuracy
99
from .misc import natural_key, add_bool_arg

timm/utils/jit.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
33
Hacked together by / Copyright 2020 Ross Wightman
44
"""
5+
import os
6+
57
import torch
68

79

@@ -16,3 +18,33 @@ def set_jit_legacy():
1618
torch._C._jit_set_profiling_mode(False)
1719
torch._C._jit_override_can_fuse_on_gpu(True)
1820
#torch._C._jit_set_texpr_fuser_enabled(True)
21+
22+
23+
def set_jit_fuser(fuser):
    """Configure which TorchScript JIT fuser backend is active.

    Args:
        fuser (str): Fuser selector. One of:
            'te' -- TensorExpr fuser (profiling executor; the modern default),
            'old' / 'legacy' -- the pre-profiling legacy GPU fuser,
            'nvfuser' / 'nvf' -- NVIDIA's nvFuser backend.

    Raises:
        ValueError: If ``fuser`` is not one of the recognized names.

    Side effects: mutates global TorchScript executor/fuser state via
    ``torch._C`` internals, and for nvfuser also sets process environment
    variables consumed by the CUDA fuser.
    """
    if fuser == "te":
        # TensorExpr fuser: requires the profiling executor/mode to be on.
        # GPU fusion enabled, CPU fusion left off (default 'te' configuration).
        torch._C._jit_set_profiling_executor(True)
        torch._C._jit_set_profiling_mode(True)
        torch._C._jit_override_can_fuse_on_cpu(False)
        torch._C._jit_override_can_fuse_on_gpu(True)
        torch._C._jit_set_texpr_fuser_enabled(True)
    elif fuser in ("old", "legacy"):
        # Legacy fuser: turn profiling off and disable the TensorExpr fuser so
        # the old GPU fusion path is used.
        torch._C._jit_set_profiling_executor(False)
        torch._C._jit_set_profiling_mode(False)
        torch._C._jit_override_can_fuse_on_gpu(True)
        torch._C._jit_set_texpr_fuser_enabled(False)
    elif fuser in ("nvfuser", "nvf"):
        # NOTE(review): these env vars tune nvFuser behavior (disable fallback,
        # disable FMA contraction, opt level) -- presumably for benchmarking
        # determinism; confirm they are intended outside benchmark use.
        os.environ['PYTORCH_CUDA_FUSER_DISABLE_FALLBACK'] = '1'
        os.environ['PYTORCH_CUDA_FUSER_DISABLE_FMA'] = '1'
        os.environ['PYTORCH_CUDA_FUSER_JIT_OPT_LEVEL'] = '0'
        # Disable TensorExpr and both override paths so only nvFuser fuses.
        # (Fix: removed no-op calls to _jit_can_fuse_on_cpu()/_jit_can_fuse_on_gpu(),
        # which are query functions whose return values were being discarded.)
        torch._C._jit_set_texpr_fuser_enabled(False)
        torch._C._jit_set_profiling_executor(True)
        torch._C._jit_set_profiling_mode(True)
        torch._C._jit_override_can_fuse_on_cpu(False)
        torch._C._jit_override_can_fuse_on_gpu(False)
        torch._C._jit_set_nvfuser_guard_mode(True)
        torch._C._jit_set_nvfuser_enabled(True)
    else:
        # Fix: was `assert False, ...` -- asserts are stripped under `python -O`,
        # which would make an invalid fuser name a silent no-op. Raise instead.
        raise ValueError(f"Invalid jit fuser ({fuser})")

train.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,8 @@
295295
help='use the multi-epochs-loader to save time at the beginning of every epoch')
296296
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
297297
help='convert model torchscript for inference')
298+
parser.add_argument('--fuser', default='', type=str,
299+
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
298300
parser.add_argument('--log-wandb', action='store_true', default=False,
299301
help='log training and validation metrics to wandb')
300302

@@ -364,6 +366,9 @@ def main():
364366

365367
random_seed(args.seed, args.rank)
366368

369+
if args.fuser:
370+
set_jit_fuser(args.fuser)
371+
367372
model = create_model(
368373
args.model,
369374
pretrained=args.pretrained,

validate.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
2323
from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
24-
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_legacy
24+
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser
2525

2626
has_apex = False
2727
try:
@@ -102,8 +102,8 @@
102102
help='use ema version of weights if present')
103103
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
104104
help='convert model torchscript for inference')
105-
parser.add_argument('--legacy-jit', dest='legacy_jit', action='store_true',
106-
help='use legacy jit mode for pytorch 1.5/1.5.1/1.6 to get back fusion performance')
105+
parser.add_argument('--fuser', default='', type=str,
106+
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
107107
parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
108108
help='Output csv file for validation results (summary)')
109109
parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME',
@@ -133,8 +133,8 @@ def validate(args):
133133
else:
134134
_logger.info('Validating in float32. AMP not enabled.')
135135

136-
if args.legacy_jit:
137-
set_jit_legacy()
136+
if args.fuser:
137+
set_jit_fuser(args.fuser)
138138

139139
# create model
140140
model = create_model(

0 commit comments

Comments
 (0)