@@ -14,7 +14,7 @@
 from collections import OrderedDict
 
 from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
-from timm.data import Dataset, create_loader, resolve_data_config
+from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config
 from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging
 
 torch.backends.cudnn.benchmark = True
@@ -24,7 +24,7 @@
                     help='path to dataset')
 parser.add_argument('--model', '-m', metavar='MODEL', default='dpn92',
                     help='model architecture (default: dpn92)')
-parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
+parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                     help='number of data loading workers (default: 2)')
 parser.add_argument('-b', '--batch-size', default=256, type=int,
                     metavar='N', help='mini-batch size (default: 256)')
@@ -91,9 +91,14 @@ def validate(args):
 
     criterion = nn.CrossEntropyLoss().cuda()
 
+    if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
+        dataset = DatasetTar(args.data, load_bytes=args.tf_preprocessing)
+    else:
+        dataset = Dataset(args.data, load_bytes=args.tf_preprocessing)
+
     crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
     loader = create_loader(
-        Dataset(args.data, load_bytes=args.tf_preprocessing),
+        dataset,
         input_size=data_config['input_size'],
         batch_size=args.batch_size,
         use_prefetcher=args.prefetcher,
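The added hunk routes a .tar archive to DatasetTar and anything else to the folder-based Dataset before handing the result to create_loader, so validation can run straight from a tarball without unpacking it. Below is a minimal standalone sketch of that selection logic, using only the calls visible in this diff; the build_val_loader helper name and the example path are hypothetical, and data_config is assumed to come from resolve_data_config as imported above.

import os

from timm.data import Dataset, DatasetTar, create_loader


def build_val_loader(data_path, data_config, batch_size=256,
                     use_prefetcher=True, tf_preprocessing=False):
    # A single .tar file is read in place via DatasetTar; anything else falls
    # back to the folder-based Dataset, mirroring the check added in this diff.
    if os.path.splitext(data_path)[1] == '.tar' and os.path.isfile(data_path):
        dataset = DatasetTar(data_path, load_bytes=tf_preprocessing)
    else:
        dataset = Dataset(data_path, load_bytes=tf_preprocessing)

    # Both dataset types feed the same loader; use_prefetcher assumes a CUDA
    # device is available, as elsewhere in this script.
    return create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=batch_size,
        use_prefetcher=use_prefetcher)

# Hypothetical usage with a validation tarball instead of an unpacked folder:
#   loader = build_val_loader('/data/imagenet-val.tar', data_config)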