Skip to content

Commit 7c7ecd2

Browse files
committed
Add --use-train-size flag to force use of train input_size (over test input size) for validation. Default test-time pooling to use train input size (fixes issues).
1 parent ce65a7b commit 7c7ecd2

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

timm/models/layers/test_time_pool.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def forward(self, x):
3636
return x.view(x.size(0), -1)
3737

3838

39-
def apply_test_time_pool(model, config, use_test_size=True):
39+
def apply_test_time_pool(model, config, use_test_size=False):
4040
test_time_pool = False
4141
if not hasattr(model, 'default_cfg') or not model.default_cfg:
4242
return model, False

validate.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@
6767
metavar='N', help='Input image dimension, uses model default if empty')
6868
parser.add_argument('--input-size', default=None, nargs=3, type=int,
6969
metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
70+
parser.add_argument('--use-train-size', action='store_true', default=False,
71+
help='force use of train input size, even when test size is specified in pretrained cfg')
7072
parser.add_argument('--crop-pct', default=None, type=float,
7173
metavar='N', help='Input image center crop pct')
7274
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
@@ -164,10 +166,15 @@ def validate(args):
164166
param_count = sum([m.numel() for m in model.parameters()])
165167
_logger.info('Model %s created, param count: %d' % (args.model, param_count))
166168

167-
data_config = resolve_data_config(vars(args), model=model, use_test_size=True, verbose=True)
169+
data_config = resolve_data_config(
170+
vars(args),
171+
model=model,
172+
use_test_size=not args.use_train_size,
173+
verbose=True
174+
)
168175
test_time_pool = False
169176
if args.test_pool:
170-
model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True)
177+
model, test_time_pool = apply_test_time_pool(model, data_config)
171178

172179
if args.torchscript:
173180
torch.jit.optimized_execution(True)

0 commit comments

Comments
 (0)