Skip to content

Commit 1ccce50

Browse files
authored
Merge pull request #1327 from rwightman/edgenext_csp_and_more
EdgeNeXt, additional DarkNets, and more
2 parents 2456223 + 1c5cb81 commit 1ccce50

29 files changed

+2496
-404
lines changed

benchmark.py

Lines changed: 51 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,24 +6,23 @@
66
Hacked together by Ross Wightman (https://github.com/rwightman)
77
"""
88
import argparse
9-
import os
109
import csv
1110
import json
12-
import time
1311
import logging
14-
import torch
15-
import torch.nn as nn
16-
import torch.nn.parallel
12+
import time
1713
from collections import OrderedDict
1814
from contextlib import suppress
1915
from functools import partial
2016

17+
import torch
18+
import torch.nn as nn
19+
import torch.nn.parallel
20+
21+
from timm.data import resolve_data_config
2122
from timm.models import create_model, is_model, list_models
2223
from timm.optim import create_optimizer_v2
23-
from timm.data import resolve_data_config
2424
from timm.utils import setup_default_logging, set_jit_fuser
2525

26-
2726
has_apex = False
2827
try:
2928
from apex import amp
@@ -71,6 +70,8 @@
7170
help="Benchmark mode. One of 'inference', 'train', 'both'. Defaults to 'both'")
7271
parser.add_argument('--detail', action='store_true', default=False,
7372
help='Provide train fwd/bwd/opt breakdown detail if True. Defaults to False')
73+
parser.add_argument('--no-retry', action='store_true', default=False,
74+
help='Do not decay batch size and retry on error.')
7475
parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
7576
help='Output csv file for validation results (summary)')
7677
parser.add_argument('--num-warm-iter', default=10, type=int,
@@ -169,10 +170,9 @@ def resolve_precision(precision: str):
169170

170171

171172
def profile_deepspeed(model, input_size=(3, 224, 224), batch_size=1, detailed=False):
172-
macs, _ = get_model_profile(
173+
_, macs, _ = get_model_profile(
173174
model=model,
174-
input_res=(batch_size,) + input_size, # input shape or input to the input_constructor
175-
input_constructor=None, # if specified, a constructor taking input_res is used as input to the model
175+
input_shape=(batch_size,) + input_size, # input shape/resolution
176176
print_profile=detailed, # prints the model graph with the measured profile attached to each module
177177
detailed=detailed, # print the detailed profile
178178
warm_up=10, # the number of warm-ups before measuring the time of each module
@@ -197,8 +197,19 @@ def profile_fvcore(model, input_size=(3, 224, 224), batch_size=1, detailed=False
197197

198198
class BenchmarkRunner:
199199
def __init__(
200-
self, model_name, detail=False, device='cuda', torchscript=False, aot_autograd=False, precision='float32',
201-
fuser='', num_warm_iter=10, num_bench_iter=50, use_train_size=False, **kwargs):
200+
self,
201+
model_name,
202+
detail=False,
203+
device='cuda',
204+
torchscript=False,
205+
aot_autograd=False,
206+
precision='float32',
207+
fuser='',
208+
num_warm_iter=10,
209+
num_bench_iter=50,
210+
use_train_size=False,
211+
**kwargs
212+
):
202213
self.model_name = model_name
203214
self.detail = detail
204215
self.device = device
@@ -225,11 +236,12 @@ def __init__(
225236
self.num_classes = self.model.num_classes
226237
self.param_count = count_params(self.model)
227238
_logger.info('Model %s created, param count: %d' % (model_name, self.param_count))
239+
240+
data_config = resolve_data_config(kwargs, model=self.model, use_test_size=not use_train_size)
228241
self.scripted = False
229242
if torchscript:
230243
self.model = torch.jit.script(self.model)
231244
self.scripted = True
232-
data_config = resolve_data_config(kwargs, model=self.model, use_test_size=not use_train_size)
233245
self.input_size = data_config['input_size']
234246
self.batch_size = kwargs.pop('batch_size', 256)
235247

@@ -255,7 +267,13 @@ def _init_input(self):
255267

256268
class InferenceBenchmarkRunner(BenchmarkRunner):
257269

258-
def __init__(self, model_name, device='cuda', torchscript=False, **kwargs):
270+
def __init__(
271+
self,
272+
model_name,
273+
device='cuda',
274+
torchscript=False,
275+
**kwargs
276+
):
259277
super().__init__(model_name=model_name, device=device, torchscript=torchscript, **kwargs)
260278
self.model.eval()
261279

@@ -324,7 +342,13 @@ def _step():
324342

325343
class TrainBenchmarkRunner(BenchmarkRunner):
326344

327-
def __init__(self, model_name, device='cuda', torchscript=False, **kwargs):
345+
def __init__(
346+
self,
347+
model_name,
348+
device='cuda',
349+
torchscript=False,
350+
**kwargs
351+
):
328352
super().__init__(model_name=model_name, device=device, torchscript=torchscript, **kwargs)
329353
self.model.train()
330354

@@ -491,7 +515,7 @@ def decay_batch_exp(batch_size, factor=0.5, divisor=16):
491515
return max(0, int(out_batch_size))
492516

493517

494-
def _try_run(model_name, bench_fn, initial_batch_size, bench_kwargs):
518+
def _try_run(model_name, bench_fn, bench_kwargs, initial_batch_size, no_batch_size_retry=False):
495519
batch_size = initial_batch_size
496520
results = dict()
497521
error_str = 'Unknown'
@@ -506,8 +530,11 @@ def _try_run(model_name, bench_fn, initial_batch_size, bench_kwargs):
506530
if 'channels_last' in error_str:
507531
_logger.error(f'{model_name} not supported in channels_last, skipping.')
508532
break
509-
_logger.warning(f'"{error_str}" while running benchmark. Reducing batch size to {batch_size} for retry.')
533+
_logger.error(f'"{error_str}" while running benchmark.')
534+
if no_batch_size_retry:
535+
break
510536
batch_size = decay_batch_exp(batch_size)
537+
_logger.warning(f'Reducing batch size to {batch_size} for retry.')
511538
results['error'] = error_str
512539
return results
513540

@@ -549,7 +576,13 @@ def benchmark(args):
549576

550577
model_results = OrderedDict(model=model)
551578
for prefix, bench_fn in zip(prefixes, bench_fns):
552-
run_results = _try_run(model, bench_fn, initial_batch_size=batch_size, bench_kwargs=bench_kwargs)
579+
run_results = _try_run(
580+
model,
581+
bench_fn,
582+
bench_kwargs=bench_kwargs,
583+
initial_batch_size=batch_size,
584+
no_batch_size_retry=args.no_retry,
585+
)
553586
if prefix and 'error' not in run_results:
554587
run_results = {'_'.join([prefix, k]): v for k, v in run_results.items()}
555588
model_results.update(run_results)

timm/data/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from .dataset_factory import create_dataset
77
from .loader import create_loader
88
from .mixup import Mixup, FastCollateMixup
9-
from .parsers import create_parser
9+
from .parsers import create_parser,\
10+
get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions
1011
from .real_labels import RealLabelsImagenet
1112
from .transforms import *
12-
from .transforms_factory import create_transform
13+
from .transforms_factory import create_transform

timm/data/config.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,15 @@ def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, v
6464
new_config['std'] = default_cfg['std']
6565

6666
# resolve default crop percentage
67-
new_config['crop_pct'] = DEFAULT_CROP_PCT
67+
crop_pct = DEFAULT_CROP_PCT
6868
if 'crop_pct' in args and args['crop_pct'] is not None:
69-
new_config['crop_pct'] = args['crop_pct']
70-
elif 'crop_pct' in default_cfg:
71-
new_config['crop_pct'] = default_cfg['crop_pct']
69+
crop_pct = args['crop_pct']
70+
else:
71+
if use_test_size and 'test_crop_pct' in default_cfg:
72+
crop_pct = default_cfg['test_crop_pct']
73+
elif 'crop_pct' in default_cfg:
74+
crop_pct = default_cfg['crop_pct']
75+
new_config['crop_pct'] = crop_pct
7276

7377
if verbose:
7478
_logger.info('Data processing configuration for current model + dataset:')

timm/data/dataset_factory.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@
2626
kmnist=KMNIST,
2727
fashion_mnist=FashionMNIST,
2828
)
29-
_TRAIN_SYNONYM = {'train', 'training'}
30-
_EVAL_SYNONYM = {'val', 'valid', 'validation', 'eval', 'evaluation'}
29+
_TRAIN_SYNONYM = dict(train=None, training=None)
30+
_EVAL_SYNONYM = dict(val=None, valid=None, validation=None, eval=None, evaluation=None)
3131

3232

3333
def _search_split(root, split):

timm/data/parsers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
from .parser_factory import create_parser
2+
from .img_extensions import *

timm/data/parsers/constants.py

Lines changed: 0 additions & 1 deletion
This file was deleted.
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from copy import deepcopy
2+
3+
__all__ = ['get_img_extensions', 'is_img_extension', 'set_img_extensions', 'add_img_extensions', 'del_img_extensions']
4+
5+
6+
IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg') # singleton, kept public for bwd compat use
7+
_IMG_EXTENSIONS_SET = set(IMG_EXTENSIONS) # set version, private, kept in sync
8+
9+
10+
def _set_extensions(extensions):
11+
global IMG_EXTENSIONS
12+
global _IMG_EXTENSIONS_SET
13+
dedupe = set() # NOTE de-duping tuple while keeping original order
14+
IMG_EXTENSIONS = tuple(x for x in extensions if x not in dedupe and not dedupe.add(x))
15+
_IMG_EXTENSIONS_SET = set(extensions)
16+
17+
18+
def _valid_extension(x: str):
19+
return x and isinstance(x, str) and len(x) >= 2 and x.startswith('.')
20+
21+
22+
def is_img_extension(ext):
23+
return ext in _IMG_EXTENSIONS_SET
24+
25+
26+
def get_img_extensions(as_set=False):
27+
return deepcopy(_IMG_EXTENSIONS_SET if as_set else IMG_EXTENSIONS)
28+
29+
30+
def set_img_extensions(extensions):
31+
assert len(extensions)
32+
for x in extensions:
33+
assert _valid_extension(x)
34+
_set_extensions(extensions)
35+
36+
37+
def add_img_extensions(ext):
38+
if not isinstance(ext, (list, tuple, set)):
39+
ext = (ext,)
40+
for x in ext:
41+
assert _valid_extension(x)
42+
extensions = IMG_EXTENSIONS + tuple(ext)
43+
_set_extensions(extensions)
44+
45+
46+
def del_img_extensions(ext):
47+
if not isinstance(ext, (list, tuple, set)):
48+
ext = (ext,)
49+
extensions = tuple(x for x in IMG_EXTENSIONS if x not in ext)
50+
_set_extensions(extensions)

timm/data/parsers/parser_factory.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import os
22

33
from .parser_image_folder import ParserImageFolder
4-
from .parser_image_tar import ParserImageTar
54
from .parser_image_in_tar import ParserImageInTar
65

76

timm/data/parsers/parser_image_folder.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,35 @@
66
Hacked together by / Copyright 2020 Ross Wightman
77
"""
88
import os
9+
from typing import Dict, List, Optional, Set, Tuple, Union
910

1011
from timm.utils.misc import natural_key
1112

12-
from .parser import Parser
1313
from .class_map import load_class_map
14-
from .constants import IMG_EXTENSIONS
14+
from .img_extensions import get_img_extensions
15+
from .parser import Parser
16+
17+
18+
def find_images_and_targets(
19+
folder: str,
20+
types: Optional[Union[List, Tuple, Set]] = None,
21+
class_to_idx: Optional[Dict] = None,
22+
leaf_name_only: bool = True,
23+
sort: bool = True
24+
):
25+
""" Walk folder recursively to discover images and map them to classes by folder names.
1526
27+
Args:
28+
folder: root of folder to recursively search
29+
types: types (file extensions) to search for in path
30+
class_to_idx: specify mapping for class (folder name) to class index if set
31+
leaf_name_only: use only leaf-name of folder walk for class names
32+
sort: re-sort found images by name (for consistent ordering)
1633
17-
def find_images_and_targets(folder, types=IMG_EXTENSIONS, class_to_idx=None, leaf_name_only=True, sort=True):
34+
Returns:
35+
A list of image and target tuples, class_to_idx mapping
36+
"""
37+
types = get_img_extensions(as_set=True) if not types else set(types)
1838
labels = []
1939
filenames = []
2040
for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True):
@@ -51,7 +71,8 @@ def __init__(
5171
self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx)
5272
if len(self.samples) == 0:
5373
raise RuntimeError(
54-
f'Found 0 images in subfolders of {root}. Supported image extensions are {", ".join(IMG_EXTENSIONS)}')
74+
f'Found 0 images in subfolders of {root}. '
75+
f'Supported image extensions are {", ".join(get_img_extensions())}')
5576

5677
def __getitem__(self, index):
5778
path, target = self.samples[index]

timm/data/parsers/parser_image_in_tar.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,20 @@
99
1010
Hacked together by / Copyright 2020 Ross Wightman
1111
"""
12+
import logging
1213
import os
13-
import tarfile
1414
import pickle
15-
import logging
16-
import numpy as np
15+
import tarfile
1716
from glob import glob
18-
from typing import List, Dict
17+
from typing import List, Tuple, Dict, Set, Optional, Union
18+
19+
import numpy as np
1920

2021
from timm.utils.misc import natural_key
2122

22-
from .parser import Parser
2323
from .class_map import load_class_map
24-
from .constants import IMG_EXTENSIONS
25-
24+
from .img_extensions import get_img_extensions
25+
from .parser import Parser
2626

2727
_logger = logging.getLogger(__name__)
2828
CACHE_FILENAME_SUFFIX = '_tarinfos.pickle'
@@ -39,7 +39,7 @@ def reset(self):
3939
self.tf = None
4040

4141

42-
def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions=IMG_EXTENSIONS):
42+
def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions: Set[str]):
4343
sample_count = 0
4444
for i, ti in enumerate(tf):
4545
if not ti.isfile():
@@ -60,7 +60,14 @@ def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions=IMG_EXTE
6060
return sample_count
6161

6262

63-
def extract_tarinfos(root, class_name_to_idx=None, cache_tarinfo=None, extensions=IMG_EXTENSIONS, sort=True):
63+
def extract_tarinfos(
64+
root,
65+
class_name_to_idx: Optional[Dict] = None,
66+
cache_tarinfo: Optional[bool] = None,
67+
extensions: Optional[Union[List, Tuple, Set]] = None,
68+
sort: bool = True
69+
):
70+
extensions = get_img_extensions(as_set=True) if not extensions else set(extensions)
6471
root_is_tar = False
6572
if os.path.isfile(root):
6673
assert os.path.splitext(root)[-1].lower() == '.tar'
@@ -176,8 +183,8 @@ def __init__(self, root, class_map='', cache_tarfiles=True, cache_tarinfo=None):
176183
self.samples, self.targets, self.class_name_to_idx, tarfiles = extract_tarinfos(
177184
self.root,
178185
class_name_to_idx=class_name_to_idx,
179-
cache_tarinfo=cache_tarinfo,
180-
extensions=IMG_EXTENSIONS)
186+
cache_tarinfo=cache_tarinfo
187+
)
181188
self.class_idx_to_name = {v: k for k, v in self.class_name_to_idx.items()}
182189
if len(tarfiles) == 1 and tarfiles[0][0] is None:
183190
self.root_is_tar = True

0 commit comments

Comments
 (0)