Skip to content

Commit 2a8c4dc

Browse files
committed
Add validation script update for using test_input_size in model default_cfgs
1 parent 68a4144 commit 2a8c4dc

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

validate.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,10 @@ def validate(args):
147147
param_count = sum([m.numel() for m in model.parameters()])
148148
_logger.info('Model %s created, param count: %d' % (args.model, param_count))
149149

150-
data_config = resolve_data_config(vars(args), model=model)
151-
model, test_time_pool = (model, False) if args.no_test_pool else apply_test_time_pool(model, data_config)
150+
data_config = resolve_data_config(vars(args), model=model, use_test_size=True)
151+
test_time_pool = False
152+
if not args.no_test_pool:
153+
model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True)
152154

153155
if args.torchscript:
154156
torch.jit.optimized_execution(True)

0 commit comments

Comments
 (0)