
Commit 37c71a5

Some further create_optimizer_v2 tweaks, remove some redundant code, add back safe model str. Benchmark step times per batch.
1 parent 2bb65bd commit 37c71a5

File tree

3 files changed, +41 -38 lines changed


benchmark.py

Lines changed: 23 additions & 20 deletions
@@ -217,25 +217,26 @@ def _step():
             delta_fwd = _step()
             total_step += delta_fwd
             num_samples += self.batch_size
-            if (i + 1) % self.log_freq == 0:
+            num_steps = i + 1
+            if num_steps % self.log_freq == 0:
                 _logger.info(
-                    f"Infer [{i + 1}/{self.num_bench_iter}]."
+                    f"Infer [{num_steps}/{self.num_bench_iter}]."
                     f" {num_samples / total_step:0.2f} samples/sec."
-                    f" {1000 * total_step / num_samples:0.3f} ms/sample.")
+                    f" {1000 * total_step / num_steps:0.3f} ms/step.")
         t_run_end = self.time_fn(True)
         t_run_elapsed = t_run_end - t_run_start

         results = dict(
             samples_per_sec=round(num_samples / t_run_elapsed, 2),
-            step_time=round(1000 * total_step / num_samples, 3),
+            step_time=round(1000 * total_step / self.num_bench_iter, 3),
             batch_size=self.batch_size,
             img_size=self.input_size[-1],
             param_count=round(self.param_count / 1e6, 2),
         )

         _logger.info(
             f"Inference benchmark of {self.model_name} done. "
-            f"{results['samples_per_sec']:.2f} samples/sec, {results['step_time']:.2f} ms/sample")
+            f"{results['samples_per_sec']:.2f} samples/sec, {results['step_time']:.2f} ms/step")

         return results

@@ -254,8 +255,8 @@ def __init__(self, model_name, device='cuda', torchscript=False, **kwargs):

         self.optimizer = create_optimizer_v2(
             self.model,
-            opt_name=kwargs.pop('opt', 'sgd'),
-            lr=kwargs.pop('lr', 1e-4))
+            optimizer_name=kwargs.pop('opt', 'sgd'),
+            learning_rate=kwargs.pop('lr', 1e-4))

     def _gen_target(self, batch_size):
         return torch.empty(

@@ -309,23 +310,24 @@ def _step(detail=False):
                 total_fwd += delta_fwd
                 total_bwd += delta_bwd
                 total_opt += delta_opt
-                if (i + 1) % self.log_freq == 0:
+                num_steps = (i + 1)
+                if num_steps % self.log_freq == 0:
                     total_step = total_fwd + total_bwd + total_opt
                     _logger.info(
-                        f"Train [{i + 1}/{self.num_bench_iter}]."
+                        f"Train [{num_steps}/{self.num_bench_iter}]."
                         f" {num_samples / total_step:0.2f} samples/sec."
-                        f" {1000 * total_fwd / num_samples:0.3f} ms/sample fwd,"
-                        f" {1000 * total_bwd / num_samples:0.3f} ms/sample bwd,"
-                        f" {1000 * total_opt / num_samples:0.3f} ms/sample opt."
+                        f" {1000 * total_fwd / num_steps:0.3f} ms/step fwd,"
+                        f" {1000 * total_bwd / num_steps:0.3f} ms/step bwd,"
+                        f" {1000 * total_opt / num_steps:0.3f} ms/step opt."
                     )
             total_step = total_fwd + total_bwd + total_opt
             t_run_elapsed = self.time_fn() - t_run_start
             results = dict(
                 samples_per_sec=round(num_samples / t_run_elapsed, 2),
-                step_time=round(1000 * total_step / num_samples, 3),
-                fwd_time=round(1000 * total_fwd / num_samples, 3),
-                bwd_time=round(1000 * total_bwd / num_samples, 3),
-                opt_time=round(1000 * total_opt / num_samples, 3),
+                step_time=round(1000 * total_step / self.num_bench_iter, 3),
+                fwd_time=round(1000 * total_fwd / self.num_bench_iter, 3),
+                bwd_time=round(1000 * total_bwd / self.num_bench_iter, 3),
+                opt_time=round(1000 * total_opt / self.num_bench_iter, 3),
                 batch_size=self.batch_size,
                 img_size=self.input_size[-1],
                 param_count=round(self.param_count / 1e6, 2),

@@ -337,15 +339,16 @@ def _step(detail=False):
                 delta_step = _step(False)
                 num_samples += self.batch_size
                 total_step += delta_step
-                if (i + 1) % self.log_freq == 0:
+                num_steps = (i + 1)
+                if num_steps % self.log_freq == 0:
                     _logger.info(
-                        f"Train [{i + 1}/{self.num_bench_iter}]."
+                        f"Train [{num_steps}/{self.num_bench_iter}]."
                         f" {num_samples / total_step:0.2f} samples/sec."
-                        f" {1000 * total_step / num_samples:0.3f} ms/sample.")
+                        f" {1000 * total_step / num_steps:0.3f} ms/step.")
             t_run_elapsed = self.time_fn() - t_run_start
             results = dict(
                 samples_per_sec=round(num_samples / t_run_elapsed, 2),
-                step_time=round(1000 * total_step / num_samples, 3),
+                step_time=round(1000 * total_step / self.num_bench_iter, 3),
                 batch_size=self.batch_size,
                 img_size=self.input_size[-1],
                 param_count=round(self.param_count / 1e6, 2),
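The changes above switch the benchmark reporting from per-sample to per-step (per-batch) timing, dividing the accumulated step time by the number of iterations rather than the number of samples. For a fixed batch size the two figures differ only by that factor; a minimal sketch with made-up numbers (none of these values come from the commit):

# Per-step vs per-sample timing for a fixed batch size (illustrative numbers only).
batch_size = 256
num_bench_iter = 40                 # benchmark iterations (steps)
total_step = 12.8                   # seconds accumulated across all steps
num_samples = batch_size * num_bench_iter

ms_per_step = 1000 * total_step / num_bench_iter    # what the new code reports
ms_per_sample = 1000 * total_step / num_samples     # what the old code reported
assert abs(ms_per_step - ms_per_sample * batch_size) < 1e-6
print(f"{ms_per_step:0.3f} ms/step == {ms_per_sample:0.3f} ms/sample x {batch_size}")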

timm/optim/optim_factory.py

Lines changed: 17 additions & 17 deletions
@@ -44,35 +44,35 @@ def optimizer_kwargs(cfg):
     """ cfg/argparse to kwargs helper
     Convert optimizer args in argparse args or cfg like object to keyword args for updated create fn.
     """
-    kwargs = dict(opt_name=cfg.opt, lr=cfg.lr, weight_decay=cfg.weight_decay)
+    kwargs = dict(
+        optimizer_name=cfg.opt,
+        learning_rate=cfg.lr,
+        weight_decay=cfg.weight_decay,
+        momentum=cfg.momentum)
     if getattr(cfg, 'opt_eps', None) is not None:
         kwargs['eps'] = cfg.opt_eps
     if getattr(cfg, 'opt_betas', None) is not None:
         kwargs['betas'] = cfg.opt_betas
     if getattr(cfg, 'opt_args', None) is not None:
         kwargs.update(cfg.opt_args)
-    kwargs['momentum'] = cfg.momentum
     return kwargs


 def create_optimizer(args, model, filter_bias_and_bn=True):
     """ Legacy optimizer factory for backwards compatibility.
     NOTE: Use create_optimizer_v2 for new code.
     """
-    opt_args = dict(lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum)
-    if hasattr(args, 'opt_eps') and args.opt_eps is not None:
-        opt_args['eps'] = args.opt_eps
-    if hasattr(args, 'opt_betas') and args.opt_betas is not None:
-        opt_args['betas'] = args.opt_betas
-    if hasattr(args, 'opt_args') and args.opt_args is not None:
-        opt_args.update(args.opt_args)
-    return create_optimizer_v2(model, opt_name=args.opt, filter_bias_and_bn=filter_bias_and_bn, **opt_args)
+    return create_optimizer_v2(
+        model,
+        **optimizer_kwargs(cfg=args),
+        filter_bias_and_bn=filter_bias_and_bn,
+    )


 def create_optimizer_v2(
         model: nn.Module,
-        opt_name: str = 'sgd',
-        lr: Optional[float] = None,
+        optimizer_name: str = 'sgd',
+        learning_rate: Optional[float] = None,
         weight_decay: float = 0.,
         momentum: float = 0.9,
         filter_bias_and_bn: bool = True,

@@ -86,8 +86,8 @@ def create_optimizer_v2(

     Args:
         model (nn.Module): model containing parameters to optimize
-        opt_name: name of optimizer to create
-        lr: initial learning rate
+        optimizer_name: name of optimizer to create
+        learning_rate: initial learning rate
         weight_decay: weight decay to apply in optimizer
         momentum: momentum for momentum based optimizers (others may use betas via kwargs)
         filter_bias_and_bn: filter out bias, bn and other 1d params from weight decay

@@ -96,7 +96,7 @@ def create_optimizer_v2(
     Returns:
         Optimizer
     """
-    opt_lower = opt_name.lower()
+    opt_lower = optimizer_name.lower()
     if weight_decay and filter_bias_and_bn:
         skip = {}
         if hasattr(model, 'no_weight_decay'):

@@ -108,7 +108,7 @@ def create_optimizer_v2(
     if 'fused' in opt_lower:
         assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'

-    opt_args = dict(lr=lr, weight_decay=weight_decay, **kwargs)
+    opt_args = dict(lr=learning_rate, weight_decay=weight_decay, **kwargs)
     opt_split = opt_lower.split('_')
     opt_lower = opt_split[-1]
     if opt_lower == 'sgd' or opt_lower == 'nesterov':

@@ -132,7 +132,7 @@ def create_optimizer_v2(
     elif opt_lower == 'adadelta':
         optimizer = optim.Adadelta(parameters, **opt_args)
     elif opt_lower == 'adafactor':
-        if not lr:
+        if not learning_rate:
             opt_args['lr'] = None
         optimizer = Adafactor(parameters, **opt_args)
     elif opt_lower == 'adahessian':
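With these renames, callers pass optimizer_name and learning_rate instead of opt_name and lr, and optimizer_kwargs now carries momentum so the legacy create_optimizer can simply forward **optimizer_kwargs(cfg=args). A minimal sketch of an updated call (the model and hyper-parameter values are illustrative, not taken from this commit):

import torch.nn as nn
from timm.optim.optim_factory import create_optimizer_v2

model = nn.Linear(10, 10)        # any nn.Module with parameters

optimizer = create_optimizer_v2(
    model,
    optimizer_name='adamw',      # was opt_name
    learning_rate=1e-3,          # was lr
    weight_decay=0.05,
)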

train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -552,7 +552,7 @@ def main():
552552
else:
553553
exp_name = '-'.join([
554554
datetime.now().strftime("%Y%m%d-%H%M%S"),
555-
args.model,
555+
safe_model_name(args.model),
556556
str(data_config['input_size'][-1])
557557
])
558558
output_dir = get_outdir(args.output if args.output else './output/train', exp_name)
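The train.py change routes the model name through safe_model_name before it is joined into the experiment directory name, keeping the path filesystem-friendly. A rough sketch of the intent, using a hypothetical sanitizer in place of the real helper (the exact replacement rule is an assumption, not taken from this diff):

from datetime import datetime

def _safe_name(name: str) -> str:
    # Hypothetical stand-in for timm's safe_model_name helper: replace anything
    # that is not alphanumeric so the result is safe to use in a directory name.
    return ''.join(c if c.isalnum() else '_' for c in name)

exp_name = '-'.join([
    datetime.now().strftime("%Y%m%d-%H%M%S"),
    _safe_name('tf_efficientnet_b0'),       # illustrative model name
    '224',                                  # illustrative input size
])
# e.g. '20210317-101530-tf_efficientnet_b0-224'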
