
Commit 0e1fd11

Merge pull request #12 from rwightman/ema-cleanup
Model weights Exponential Moving Average
2 parents 1019414 + 7dab6d1 commit 0e1fd11
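
For context: with the new --model-ema option, train.py keeps a second copy of the model whose parameters are an exponential moving average of the live weights. After every optimizer step the EMA copy is updated with the standard rule

    ema_w = decay * ema_w + (1 - decay) * w

where decay defaults to 0.9998 (--model-ema-decay), and the averaged weights are written to checkpoints under a separate 'state_dict_ema' key so they can be extracted, resumed, or validated independently of the raw weights.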

File tree: 6 files changed (+270, -97 lines)


clean_checkpoint.py

Lines changed: 9 additions & 2 deletions
@@ -9,6 +9,8 @@
                     help='path to latest checkpoint (default: none)')
 parser.add_argument('--output', default='./cleaned.pth', type=str, metavar='PATH',
                     help='output path')
+parser.add_argument('--use-ema', dest='use_ema', action='store_true',
+                    help='use ema version of weights if present')


 def main():
@@ -24,8 +26,13 @@ def main():
     checkpoint = torch.load(args.checkpoint, map_location='cpu')

     new_state_dict = OrderedDict()
-    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
-        state_dict = checkpoint['state_dict']
+    if isinstance(checkpoint, dict):
+        state_dict_key = 'state_dict_ema' if args.use_ema else 'state_dict'
+        if state_dict_key in checkpoint:
+            state_dict = checkpoint[state_dict_key]
+        else:
+            print("Error: No state_dict found in checkpoint {}.".format(args.checkpoint))
+            exit(1)
     else:
         state_dict = checkpoint
     for k, v in state_dict.items():

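A possible invocation for extracting the EMA weights into a standalone deployment checkpoint. This is a sketch: the input-path flag is assumed to be --checkpoint (its add_argument call is not visible in this hunk) and the paths are illustrative:

    python clean_checkpoint.py --checkpoint ./output/train/<run>/model_best.pth.tar --output ./cleaned.pth --use-ema
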
models/helpers.py

Lines changed: 15 additions & 16 deletions
@@ -4,22 +4,24 @@
 from collections import OrderedDict


-def load_checkpoint(model, checkpoint_path):
+def load_checkpoint(model, checkpoint_path, use_ema=False):
     if checkpoint_path and os.path.isfile(checkpoint_path):
-        print("=> Loading checkpoint '{}'".format(checkpoint_path))
         checkpoint = torch.load(checkpoint_path)
-        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
+        state_dict_key = ''
+        if isinstance(checkpoint, dict):
+            state_dict_key = 'state_dict'
+            if use_ema and 'state_dict_ema' in checkpoint:
+                state_dict_key = 'state_dict_ema'
+        if state_dict_key and state_dict_key in checkpoint:
             new_state_dict = OrderedDict()
-            for k, v in checkpoint['state_dict'].items():
-                if k.startswith('module'):
-                    name = k[7:]  # remove `module.`
-                else:
-                    name = k
+            for k, v in checkpoint[state_dict_key].items():
+                # strip `module.` prefix
+                name = k[7:] if k.startswith('module') else k
                 new_state_dict[name] = v
             model.load_state_dict(new_state_dict)
         else:
             model.load_state_dict(checkpoint)
-        print("=> Loaded checkpoint '{}'".format(checkpoint_path))
+        print("=> Loaded {} from checkpoint '{}'".format(state_dict_key or 'weights', checkpoint_path))
     else:
         print("=> Error: No checkpoint found at '{}'".format(checkpoint_path))
         raise FileNotFoundError()
@@ -28,27 +30,24 @@ def load_checkpoint(model, checkpoint_path):
 def resume_checkpoint(model, checkpoint_path, start_epoch=None):
     optimizer_state = None
     if os.path.isfile(checkpoint_path):
-        print("=> loading checkpoint '{}'".format(checkpoint_path))
         checkpoint = torch.load(checkpoint_path)
         if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
             new_state_dict = OrderedDict()
             for k, v in checkpoint['state_dict'].items():
-                if k.startswith('module'):
-                    name = k[7:]  # remove `module.`
-                else:
-                    name = k
+                name = k[7:] if k.startswith('module') else k
                 new_state_dict[name] = v
             model.load_state_dict(new_state_dict)
             if 'optimizer' in checkpoint:
                 optimizer_state = checkpoint['optimizer']
-            print("=> loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
             start_epoch = checkpoint['epoch'] if start_epoch is None else start_epoch
+            print("=> Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
         else:
             model.load_state_dict(checkpoint)
             start_epoch = 0 if start_epoch is None else start_epoch
+            print("=> Loaded checkpoint '{}'".format(checkpoint_path))
         return optimizer_state, start_epoch
     else:
-        print("=> No checkpoint found at '{}'".format(checkpoint_path))
+        print("=> Error: No checkpoint found at '{}'".format(checkpoint_path))
         raise FileNotFoundError()

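With the new use_ema flag, restoring the averaged weights for evaluation becomes a one-liner. A minimal sketch, assuming a model is built with create_model as train.py does; the model name and checkpoint path are illustrative:

    from models import create_model, load_checkpoint

    model = create_model('seresnet34')  # illustrative model name
    # prefers 'state_dict_ema' when present, otherwise falls back to the regular 'state_dict'
    load_checkpoint(model, './output/train/<run>/model_best.pth.tar', use_ema=True)
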
optim/rmsprop_tf.py

Lines changed: 2 additions & 2 deletions
@@ -89,7 +89,7 @@ def step(self, closure=None):
                 state['step'] += 1

                 if group['weight_decay'] != 0:
-                    if group['decoupled_decay']:
+                    if 'decoupled_decay' in group and group['decoupled_decay']:
                         p.data.add_(-group['weight_decay'], p.data)
                     else:
                         grad = grad.add(group['weight_decay'], p.data)
@@ -109,7 +109,7 @@ def step(self, closure=None):
                 if group['momentum'] > 0:
                     buf = state['momentum_buffer']
                     # Tensorflow accumulates the LR scaling in the momentum buffer
-                    if group['lr_in_momentum']:
+                    if 'lr_in_momentum' in group and group['lr_in_momentum']:
                         buf.mul_(group['momentum']).addcdiv_(group['lr'], grad, avg)
                         p.data.add_(-buf)
                     else:

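The two guards above presumably make step() tolerant of param groups that do not carry these keys at all, for instance optimizer state resumed from a checkpoint written before the options existed; a missing key now behaves like the option being off. A terser equivalent (not what the patch uses) would be:

    if group.get('decoupled_decay', False):
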
train.py

Lines changed: 70 additions & 59 deletions
@@ -6,12 +6,13 @@
 try:
     from apex import amp
     from apex.parallel import DistributedDataParallel as DDP
+    from apex.parallel import convert_syncbn_model
     has_apex = True
 except ImportError:
     has_apex = False

 from data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target
-from models import create_model, resume_checkpoint
+from models import create_model, resume_checkpoint, load_checkpoint
 from utils import *
 from loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
 from optim import create_optimizer
@@ -41,8 +42,8 @@
                     help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
 parser.add_argument('--pretrained', action='store_true', default=False,
                     help='Start with pretrained version of specified network (if avail)')
-parser.add_argument('--img-size', type=int, default=224, metavar='N',
-                    help='Image patch size (default: 224)')
+parser.add_argument('--img-size', type=int, default=None, metavar='N',
+                    help='Image patch size (default: None => model default)')
 parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                     help='Override mean pixel value of dataset')
 parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
@@ -91,11 +92,17 @@
                     help='BatchNorm momentum override (if not None)')
 parser.add_argument('--bn-eps', type=float, default=None,
                     help='BatchNorm epsilon override (if not None)')
+parser.add_argument('--model-ema', action='store_true', default=False,
+                    help='Enable tracking moving average of model weights')
+parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
+                    help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
+parser.add_argument('--model-ema-decay', type=float, default=0.9998,
+                    help='decay factor for model weights moving average (default: 0.9998)')
 parser.add_argument('--seed', type=int, default=42, metavar='S',
                     help='random seed (default: 42)')
 parser.add_argument('--log-interval', type=int, default=50, metavar='N',
                     help='how many batches to wait before logging training status')
-parser.add_argument('--recovery-interval', type=int, default=1000, metavar='N',
+parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
                     help='how many batches to wait before writing recovery checkpoint')
 parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                     help='how many training processes to use (default: 1)')
@@ -109,6 +116,8 @@
                     help='save images of input bathes every log interval for debugging')
 parser.add_argument('--amp', action='store_true', default=False,
                     help='use NVIDIA amp for mixed precision training')
+parser.add_argument('--sync-bn', action='store_true',
+                    help='enabling apex sync BN.')
 parser.add_argument('--no-prefetcher', action='store_true', default=False,
                     help='disable fast prefetcher')
 parser.add_argument('--output', default='', type=str, metavar='PATH',
@@ -131,36 +140,24 @@ def main():

     args.device = 'cuda:0'
     args.world_size = 1
-    r = -1
+    args.rank = 0  # global rank
     if args.distributed:
         args.num_gpu = 1
         args.device = 'cuda:%d' % args.local_rank
         torch.cuda.set_device(args.local_rank)
-        torch.distributed.init_process_group(backend='nccl',
-                                             init_method='env://')
+        torch.distributed.init_process_group(
+            backend='nccl', init_method='env://')
         args.world_size = torch.distributed.get_world_size()
-        r = torch.distributed.get_rank()
+        args.rank = torch.distributed.get_rank()
+    assert args.rank >= 0

     if args.distributed:
         print('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
-              % (r, args.world_size))
+              % (args.rank, args.world_size))
     else:
         print('Training with a single process on %d GPUs.' % args.num_gpu)

-    # FIXME seed handling for multi-process distributed?
-    torch.manual_seed(args.seed)
-
-    output_dir = ''
-    if args.local_rank == 0:
-        if args.output:
-            output_base = args.output
-        else:
-            output_base = './output'
-        exp_name = '-'.join([
-            datetime.now().strftime("%Y%m%d-%H%M%S"),
-            args.model,
-            str(args.img_size)])
-        output_dir = get_outdir(output_base, 'train', exp_name)
+    torch.manual_seed(args.seed + args.rank)

     model = create_model(
         args.model,
@@ -191,6 +188,8 @@ def main():
         args.amp = False
         model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
     else:
+        if args.distributed and args.sync_bn and has_apex:
+            model = convert_syncbn_model(model)
         model.cuda()

     optimizer = create_optimizer(args, model)
@@ -205,8 +204,20 @@ def main():
         use_amp = False
         print('AMP disabled')

+    model_ema = None
+    if args.model_ema:
+        model_ema = ModelEma(
+            model,
+            decay=args.model_ema_decay,
+            device='cpu' if args.model_ema_force_cpu else '',
+            resume=args.resume)
+
     if args.distributed:
         model = DDP(model, delay_allreduce=True)
+        if model_ema is not None and not args.model_ema_force_cpu:
+            # must also distribute EMA model to allow validation
+            model_ema.ema = DDP(model_ema.ema, delay_allreduce=True)
+            model_ema.ema_has_module = True

     lr_scheduler, num_epochs = create_scheduler(args, optimizer)
     if start_epoch > 0:
@@ -271,23 +282,37 @@ def main():
         validate_loss_fn = train_loss_fn

     eval_metric = args.eval_metric
+    best_metric = None
+    best_epoch = None
     saver = None
-    if output_dir:
+    output_dir = ''
+    if args.local_rank == 0:
+        output_base = args.output if args.output else './output'
+        exp_name = '-'.join([
+            datetime.now().strftime("%Y%m%d-%H%M%S"),
+            args.model,
+            str(data_config['input_size'][-1])
+        ])
+        output_dir = get_outdir(output_base, 'train', exp_name)
         decreasing = True if eval_metric == 'loss' else False
         saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing)
-    best_metric = None
-    best_epoch = None
+
     try:
         for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(
                epoch, model, loader_train, optimizer, train_loss_fn, args,
-                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, use_amp=use_amp)
+                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
+                use_amp=use_amp, model_ema=model_ema)
+
+            eval_metrics = validate(model, loader_eval, validate_loss_fn, args)

-            eval_metrics = validate(
-                model, loader_eval, validate_loss_fn, args)
+            if model_ema is not None and not args.model_ema_force_cpu:
+                ema_eval_metrics = validate(
+                    model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)')
+                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                lr_scheduler.step(epoch, eval_metrics[eval_metric])
@@ -298,15 +323,12 @@ def main():

            if saver is not None:
                # save proper checkpoint with eval metric
-                best_metric, best_epoch = saver.save_checkpoint({
-                    'epoch': epoch + 1,
-                    'arch': args.model,
-                    'state_dict': model.state_dict(),
-                    'optimizer': optimizer.state_dict(),
-                    'args': args,
-                    },
+                save_metric = eval_metrics[eval_metric]
+                best_metric, best_epoch = saver.save_checkpoint(
+                    model, optimizer, args,
                    epoch=epoch + 1,
-                    metric=eval_metrics[eval_metric])
+                    model_ema=model_ema,
+                    metric=save_metric)

     except KeyboardInterrupt:
         pass
@@ -316,7 +338,7 @@

 def train_epoch(
         epoch, model, loader, optimizer, loss_fn, args,
-        lr_scheduler=None, saver=None, output_dir='', use_amp=False):
+        lr_scheduler=None, saver=None, output_dir='', use_amp=False, model_ema=None):

     if args.prefetcher and args.mixup > 0 and loader.mixup_enabled:
         if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
@@ -359,6 +381,8 @@ def train_epoch(
         optimizer.step()

         torch.cuda.synchronize()
+        if model_ema is not None:
+            model_ema.update(model)
         num_updates += 1

         batch_time_m.update(time.time() - end)
@@ -394,18 +418,11 @@ def train_epoch(
                    padding=0,
                    normalize=True)

-            if args.local_rank == 0 and (
-                    saver is not None and last_batch or (batch_idx + 1) % args.recovery_interval == 0):
+            if saver is not None and args.recovery_interval and (
+                    last_batch or (batch_idx + 1) % args.recovery_interval == 0):
                save_epoch = epoch + 1 if last_batch else epoch
-                saver.save_recovery({
-                    'epoch': save_epoch,
-                    'arch': args.model,
-                    'state_dict': model.state_dict(),
-                    'optimizer': optimizer.state_dict(),
-                    'args': args,
-                    },
-                    epoch=save_epoch,
-                    batch_idx=batch_idx)
+                saver.save_recovery(
+                    model, optimizer, args, save_epoch, model_ema=model_ema, batch_idx=batch_idx)

            if lr_scheduler is not None:
                lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
@@ -415,7 +432,7 @@ def train_epoch(
     return OrderedDict([('loss', losses_m.avg)])


-def validate(model, loader, loss_fn, args):
+def validate(model, loader, loss_fn, args, log_suffix=''):
     batch_time_m = AverageMeter()
     losses_m = AverageMeter()
     prec1_m = AverageMeter()
@@ -461,12 +478,13 @@ def validate(model, loader, loss_fn, args):
            batch_time_m.update(time.time() - end)
            end = time.time()
            if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
-                print('Test: [{0}/{1}]\t'
+                log_name = 'Test' + log_suffix
+                print('{0}: [{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                      'Loss {loss.val:.4f} ({loss.avg:.4f}) '
                      'Prec@1 {top1.val:.4f} ({top1.avg:.4f}) '
                      'Prec@5 {top5.val:.4f} ({top5.avg:.4f})'.format(
-                          batch_idx, last_idx,
+                          log_name, batch_idx, last_idx,
                          batch_time=batch_time_m, loss=losses_m,
                          top1=prec1_m, top5=prec5_m))

@@ -475,12 +493,5 @@ def validate(model, loader, loss_fn, args):
     return metrics


-def reduce_tensor(tensor, n):
-    rt = tensor.clone()
-    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
-    rt /= n
-    return rt
-
-
 if __name__ == '__main__':
     main()

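train.py now relies on a ModelEma helper defined in utils.py, one of the two changed files whose diff is not shown above. Below is a rough sketch of what such a helper has to do, based only on how it is used here; the real implementation may differ, and it additionally reconciles the 'module.' prefix that (Distributed)DataParallel wrapping adds to state-dict keys (the ema_has_module flag toggled in main()):

    import copy

    import torch


    class ModelEmaSketch:
        """Illustrative weight-EMA wrapper, mirroring how train.py calls ModelEma."""

        def __init__(self, model, decay=0.9998, device='', resume=''):
            self.ema = copy.deepcopy(model)   # train.py validates model_ema.ema directly
            self.ema.eval()
            self.decay = decay
            self.device = device              # 'cpu' when --model-ema-force-cpu is set, '' otherwise
            if device:
                self.ema.to(device=device)
            for p in self.ema.parameters():
                p.requires_grad_(False)
            if resume:
                # restore previously tracked averages from the 'state_dict_ema' checkpoint entry
                checkpoint = torch.load(resume, map_location='cpu')
                if isinstance(checkpoint, dict) and 'state_dict_ema' in checkpoint:
                    self.ema.load_state_dict(checkpoint['state_dict_ema'])

        def update(self, model):
            # ema_w = decay * ema_w + (1 - decay) * w, applied tensor-wise after each optimizer step
            with torch.no_grad():
                msd = model.state_dict()
                for k, ema_v in self.ema.state_dict().items():
                    model_v = msd[k].detach()
                    if self.device:
                        model_v = model_v.to(device=self.device)
                    if ema_v.dtype.is_floating_point:
                        ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)
                    else:
                        ema_v.copy_(model_v)  # integer buffers (e.g. BN num_batches_tracked) are copied as-is

From the command line the feature is switched on with --model-ema, tuned with --model-ema-decay, and pinned to the CPU with --model-ema-force-cpu (which also skips the extra EMA validation pass).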