Commit 4d20567

Mixup and prefetcher improvements

* Do mixup in a custom collate fn when the prefetcher is enabled, reducing the performance impact
* Move the mixup code to its own file
* Add an arg to disable the prefetcher
* Fix missing CUDA transfer when the prefetcher is off
* Random erasing when the prefetcher is off wasn't updated to match the new args; fixed
* Default random erasing to off (prob = 0.) for training
1 parent 780c0a9 commit 4d20567
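For orientation before the per-file diffs: when the CUDA prefetcher is active, mixup now happens inside the DataLoader's collate fn on the raw uint8 numpy images, so the training loop no longer pays for an extra blend-and-flip on already-batched GPU tensors. A minimal sketch of the new wiring, assuming the repo's Dataset/create_loader APIs as used in train.py (the path and hyperparameter values are illustrative):

from data import Dataset, create_loader, FastCollateMixup

dataset_train = Dataset('/data/imagenet/train')  # illustrative path
collate_fn = FastCollateMixup(mixup_alpha=0.2, label_smoothing=0.1, num_classes=1000)
loader_train = create_loader(
    dataset_train,
    input_size=(3, 224, 224),
    batch_size=32,
    is_training=True,
    use_prefetcher=True,   # with the prefetcher on, mixup runs in the collate fn
    collate_fn=collate_fn,
)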

6 files changed, +91 -29 lines


data/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -3,3 +3,4 @@
 from data.dataset import Dataset
 from data.transforms import *
 from data.loader import create_loader
+from data.mixup import mixup_target, FastCollateMixup

data/loader.py

Lines changed: 18 additions & 1 deletion
@@ -1,6 +1,7 @@
 import torch.utils.data
 from data.transforms import *
 from data.distributed_sampler import OrderedDistributedSampler
+from data.mixup import FastCollateMixup


 def fast_collate(batch):
@@ -60,6 +61,18 @@ def __len__(self):
     def sampler(self):
         return self.loader.sampler

+    @property
+    def mixup_enabled(self):
+        if isinstance(self.loader.collate_fn, FastCollateMixup):
+            return self.loader.collate_fn.mixup_enabled
+        else:
+            return False
+
+    @mixup_enabled.setter
+    def mixup_enabled(self, x):
+        if isinstance(self.loader.collate_fn, FastCollateMixup):
+            self.loader.collate_fn.mixup_enabled = x
+

 def create_loader(
         dataset,
@@ -75,6 +88,7 @@ def create_loader(
         num_workers=1,
         distributed=False,
         crop_pct=None,
+        collate_fn=None,
 ):
     if isinstance(input_size, tuple):
         img_size = input_size[-2:]
@@ -108,13 +122,16 @@ def create_loader(
             # of samples per-process, will slightly alter validation results
             sampler = OrderedDistributedSampler(dataset)

+    if collate_fn is None:
+        collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate
+
     loader = torch.utils.data.DataLoader(
         dataset,
         batch_size=batch_size,
         shuffle=sampler is None and is_training,
         num_workers=num_workers,
         sampler=sampler,
-        collate_fn=fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate,
+        collate_fn=collate_fn,
         drop_last=is_training,
     )
     if use_prefetcher:
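The new mixup_enabled property forwards to the collate fn when (and only when) it is a FastCollateMixup, so callers can toggle mixup without rebuilding the loader. A sketch of the intended use, mirroring the train_epoch change further down (names as in train.py; the property lives on the prefetch loader wrapper shown above):

# Turn mixup off once the cutoff epoch is reached; the setter is a
# silent no-op if the loader's collate fn isn't a FastCollateMixup.
if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
    loader_train.mixup_enabled = False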

data/mixup.py

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+import numpy as np
+import torch
+
+
+def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
+    x = x.long().view(-1, 1)
+    return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)
+
+
+def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'):
+    off_value = smoothing / num_classes
+    on_value = 1. - smoothing + off_value
+    y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device)
+    y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device)
+    return lam*y1 + (1. - lam)*y2
+
+
+class FastCollateMixup:
+
+    def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000):
+        self.mixup_alpha = mixup_alpha
+        self.label_smoothing = label_smoothing
+        self.num_classes = num_classes
+        self.mixup_enabled = True
+
+    def __call__(self, batch):
+        batch_size = len(batch)
+        lam = 1.
+        if self.mixup_enabled:
+            lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
+
+        target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
+        target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
+
+        tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
+        for i in range(batch_size):
+            mixed = batch[i][0].astype(np.float32) * lam + \
+                batch[batch_size - i - 1][0].astype(np.float32) * (1 - lam)
+            np.round(mixed, out=mixed)
+            tensor[i] += torch.from_numpy(mixed.astype(np.uint8))
+
+        return tensor, target
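Note the design choice in __call__: each image is blended in float32 against its partner from the batch reversed along the first axis, rounded, and stored back as uint8, so the prefetcher keeps its cheap uint8 host-to-device copy at the cost of a tiny quantization error. To make the target side concrete, a small worked example (values rounded to 4 places):

import torch
from data.mixup import mixup_target

targets = torch.tensor([0, 2])  # batch of two, three classes
y = mixup_target(targets, num_classes=3, lam=0.7, smoothing=0.1, device='cpu')
# smoothing=0.1 over 3 classes: off_value = 0.0333, on_value = 0.9333
# row 0 blends one_hot(0) with one_hot(2) at 0.7/0.3:
#   y[0] = [0.6633, 0.0333, 0.3033]; each row still sums to 1.0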

data/transforms.py

Lines changed: 2 additions & 2 deletions
@@ -159,7 +159,7 @@ def transforms_imagenet_train(
         color_jitter=(0.4, 0.4, 0.4),
         interpolation='random',
         random_erasing=0.4,
-        random_erasing_pp=True,
+        random_erasing_mode='const',
         use_prefetcher=False,
         mean=IMAGENET_DEFAULT_MEAN,
         std=IMAGENET_DEFAULT_STD
@@ -183,7 +183,7 @@ def transforms_imagenet_train(
             std=torch.tensor(std))
     ]
     if random_erasing > 0.:
-        tfl.append(RandomErasing(random_erasing, per_pixel=random_erasing_pp, device='cpu'))
+        tfl.append(RandomErasing(random_erasing, mode=random_erasing_mode, device='cpu'))
     return transforms.Compose(tfl)
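The boolean per_pixel flag becomes a mode string, lining up with the new --remode arg in train.py. A hedged example call, using only the parameters visible in this diff (all other arguments keep their defaults):

# Random erasing applied with probability 0.4, filling with a constant value;
# pass a mode string where the old API took per_pixel=True/False.
transform = transforms_imagenet_train(
    random_erasing=0.4,
    random_erasing_mode='const',
)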

train.py

Lines changed: 28 additions & 13 deletions
@@ -10,7 +10,7 @@
 except ImportError:
     has_apex = False

-from data import Dataset, create_loader, resolve_data_config
+from data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target
 from models import create_model, resume_checkpoint
 from utils import *
 from loss import LabelSmoothingCrossEntropy, SparseLabelCrossEntropy
@@ -66,9 +66,9 @@
 parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',
                     help='LR scheduler (default: "step")')
 parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
-                    help='Dropout rate (default: 0.1)')
-parser.add_argument('--reprob', type=float, default=0.4, metavar='PCT',
-                    help='Random erase prob (default: 0.4)')
+                    help='Dropout rate (default: 0.)')
+parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
+                    help='Random erase prob (default: 0.)')
 parser.add_argument('--remode', type=str, default='const',
                     help='Random erase mode (default: "const")')
 parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
@@ -109,6 +109,8 @@
                     help='save images of input batches every log interval for debugging')
 parser.add_argument('--amp', action='store_true', default=False,
                     help='use NVIDIA amp for mixed precision training')
+parser.add_argument('--no-prefetcher', action='store_true', default=False,
+                    help='disable fast prefetcher')
 parser.add_argument('--output', default='', type=str, metavar='PATH',
                     help='path to output folder (default: none, current dir)')
 parser.add_argument('--eval-metric', default='prec1', type=str, metavar='EVAL_METRIC',
@@ -119,6 +121,7 @@
 def main():
     args = parser.parse_args()

+    args.prefetcher = not args.no_prefetcher
     args.distributed = False
     if 'WORLD_SIZE' in os.environ:
         args.distributed = int(os.environ['WORLD_SIZE']) > 1
@@ -130,6 +133,7 @@ def main():
     args.world_size = 1
     r = -1
     if args.distributed:
+        args.num_gpu = 1
         args.device = 'cuda:%d' % args.local_rank
         torch.cuda.set_device(args.local_rank)
         torch.distributed.init_process_group(backend='nccl',
@@ -216,19 +220,24 @@ def main():
         exit(1)
     dataset_train = Dataset(train_dir)

+    collate_fn = None
+    if args.prefetcher and args.mixup > 0:
+        collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes)
+
     loader_train = create_loader(
         dataset_train,
         input_size=data_config['input_size'],
         batch_size=args.batch_size,
         is_training=True,
-        use_prefetcher=True,
+        use_prefetcher=args.prefetcher,
         rand_erase_prob=args.reprob,
         rand_erase_mode=args.remode,
         interpolation='random',  # FIXME cleanly resolve this? data_config['interpolation'],
         mean=data_config['mean'],
         std=data_config['std'],
         num_workers=args.workers,
         distributed=args.distributed,
+        collate_fn=collate_fn,
     )

     eval_dir = os.path.join(args.data, 'validation')
@@ -242,7 +251,7 @@ def main():
         input_size=data_config['input_size'],
         batch_size=4 * args.batch_size,
         is_training=False,
-        use_prefetcher=True,
+        use_prefetcher=args.prefetcher,
         interpolation=data_config['interpolation'],
         mean=data_config['mean'],
         std=data_config['std'],
@@ -309,6 +318,10 @@ def train_epoch(
         epoch, model, loader, optimizer, loss_fn, args,
         lr_scheduler=None, saver=None, output_dir='', use_amp=False):

+    if args.prefetcher and args.mixup > 0 and loader.mixup_enabled:
+        if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
+            loader.mixup_enabled = False
+
     batch_time_m = AverageMeter()
     data_time_m = AverageMeter()
     losses_m = AverageMeter()
@@ -321,13 +334,15 @@ def train_epoch(
     for batch_idx, (input, target) in enumerate(loader):
         last_batch = batch_idx == last_idx
         data_time_m.update(time.time() - end)
-
-        if args.mixup > 0.:
-            lam = 1.
-            if not args.mixup_off_epoch or epoch < args.mixup_off_epoch:
-                lam = np.random.beta(args.mixup, args.mixup)
-            input.mul_(lam).add_(1 - lam, input.flip(0))
-            target = mixup_target(target, args.num_classes, lam, args.smoothing)
+        if not args.prefetcher:
+            input = input.cuda()
+            target = target.cuda()
+            if args.mixup > 0.:
+                lam = 1.
+                if not args.mixup_off_epoch or epoch < args.mixup_off_epoch:
+                    lam = np.random.beta(args.mixup, args.mixup)
+                input.mul_(lam).add_(1 - lam, input.flip(0))
+                target = mixup_target(target, args.num_classes, lam, args.smoothing)

         output = model(input)
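Taken together, the prefetcher is now opt-out and random erasing is opt-in. To train on the slower non-prefetcher path (inputs moved to CUDA in the loop, mixup falling back to the in-loop tensor implementation) while restoring the old erasing behavior, an invocation might look like this (dataset path illustrative):

python train.py /data/imagenet --no-prefetcher --reprob 0.4 --remode const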

utils.py

Lines changed: 0 additions & 13 deletions
@@ -140,19 +140,6 @@ def accuracy(output, target, topk=(1,)):
     return res


-def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
-    x = x.long().view(-1, 1)
-    return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)
-
-
-def mixup_target(target, num_classes, lam=1., smoothing=0.0):
-    off_value = smoothing / num_classes
-    on_value = 1. - smoothing + off_value
-    y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value)
-    y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value)
-    return lam*y1 + (1. - lam)*y2
-
-
 def get_outdir(path, *paths, inc=False):
     outdir = os.path.join(path, *paths)
     if not os.path.exists(outdir):
