
Commit fee607e

Mixup implementation in progress

* initial impl w/ label smoothing converging, but needs more testing

1 parent c3fbdd4 · commit fee607e

5 files changed: +45 −5 lines

README.md

Lines changed: 3 additions & 2 deletions
@@ -29,7 +29,7 @@ I've included a few of my favourite models, but this is not an exhaustive collection
 * PNasNet (from [Cadene](https://github.com/Cadene/pretrained-models.pytorch))
 * DPN (from [me](https://github.com/rwightman/pytorch-dpn-pretrained), weights hosted by Cadene)
 * DPN-68, DPN-68b, DPN-92, DPN-98, DPN-131, DPN-107
-* My generic MobileNet (GenMobileNet) - A generic model that implements many of the mobile optimized architecture search derived models that utilize similar DepthwiseSeparable, InvertedResidual, etc blocks
+* Generic MobileNet (from my standalone [GenMobileNet](https://github.com/rwightman/genmobilenet-pytorch)) - A generic model that implements many of the mobile optimized architecture search derived models that utilize similar DepthwiseSeparable and InvertedResidual blocks
 * MNASNet B1, A1 (Squeeze-Excite), and Small
 * MobileNet-V1
 * MobileNet-V2
@@ -49,7 +49,8 @@ Several (less common) features that I often utilize in my projects are included.
 * PyTorch w/ single GPU single process (AMP optional)
 * A dynamic global pool implementation that allows selecting from average pooling, max pooling, average + max, or concat([average, max]) at model creation. All global pooling is adaptive average by default and compatible with pretrained weights.
 * A 'Test Time Pool' wrapper that can wrap any of the included models and usually provide improved performance doing inference with input images larger than the training size. Idea adapted from original DPN implementation when I ported (https://github.com/cypw/DPNs)
-* Training schedules and techniques that provide competitive results (Cosine LR, Random Erasing, Smoothed Softmax, etc)
+* Training schedules and techniques that provide competitive results (Cosine LR, Random Erasing, Label Smoothing, etc)
+* Mixup (as in https://arxiv.org/abs/1710.09412) - currently implementing/testing
 * An inference script that dumps output to CSV is provided as an example

 ### Custom Weights
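For context on the new README entry: mixup (https://arxiv.org/abs/1710.09412) trains on convex combinations of pairs of examples and their labels. A minimal sketch of the formulation, independent of this repository's code (the function name below is illustrative only):

```python
import numpy as np

def mixup_pair(x1, y1, x2, y2, alpha=0.2):
    # lam ~ Beta(alpha, alpha); alpha controls how aggressively samples are blended
    lam = np.random.beta(alpha, alpha)
    x = lam * x1 + (1. - lam) * x2   # mixed input
    y = lam * y1 + (1. - lam) * y2   # mixed soft label (y1/y2 one-hot or smoothed)
    return x, y
```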

loss/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-from loss.cross_entropy import LabelSmoothingCrossEntropy
+from loss.cross_entropy import LabelSmoothingCrossEntropy, SparseLabelCrossEntropy

loss/cross_entropy.py

Lines changed: 10 additions & 0 deletions
@@ -1,3 +1,4 @@
+import torch
 import torch.nn as nn
 import torch.nn.functional as F

@@ -24,3 +25,12 @@ def forward(self, x, target):
         loss = self.confidence * nll_loss + self.smoothing * smooth_loss
         return loss.mean()

+
+class SparseLabelCrossEntropy(nn.Module):
+
+    def __init__(self):
+        super(SparseLabelCrossEntropy, self).__init__()
+
+    def forward(self, x, target):
+        loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1)
+        return loss.mean()

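Despite the name, SparseLabelCrossEntropy computes cross entropy against a dense per-sample target distribution (such as the soft labels produced by mixup) rather than class indices. A hedged usage sketch, not part of the commit, assuming the package layout above:

```python
import torch
import torch.nn.functional as F
from loss import SparseLabelCrossEntropy  # exported by loss/__init__.py in this commit

loss_fn = SparseLabelCrossEntropy()
logits = torch.randn(4, 10)                      # batch of 4 samples, 10 classes
targets = F.softmax(torch.randn(4, 10), dim=-1)  # any row-normalized distribution
loss = loss_fn(logits, targets)                  # scalar mean over the batch
```

With one-hot targets this reduces to the standard nn.CrossEntropyLoss value, which is why validation in train.py keeps using nn.CrossEntropyLoss.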
train.py

Lines changed: 17 additions & 2 deletions
@@ -13,7 +13,7 @@
 from data import *
 from models import create_model, resume_checkpoint
 from utils import *
-from loss import LabelSmoothingCrossEntropy
+from loss import LabelSmoothingCrossEntropy, SparseLabelCrossEntropy
 from optim import create_optimizer
 from scheduler import create_scheduler

@@ -79,6 +79,10 @@
                     help='SGD momentum (default: 0.9)')
 parser.add_argument('--weight-decay', type=float, default=0.0001,
                     help='weight decay (default: 0.0001)')
+parser.add_argument('--mixup', type=float, default=0.0,
+                    help='mixup alpha, mixup enabled if > 0. (default: 0.)')
+parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
+                    help='turn off mixup after this epoch, disabled if 0 (default: 0)')
 parser.add_argument('--smoothing', type=float, default=0.1,
                     help='label smoothing (default: 0.1)')
 parser.add_argument('--bn-tf', action='store_true', default=False,
@@ -246,7 +250,11 @@ def main():
         distributed=args.distributed,
     )

-    if args.smoothing:
+    if args.mixup > 0.:
+        # smoothing is handled with mixup label transform
+        train_loss_fn = SparseLabelCrossEntropy().cuda()
+        validate_loss_fn = nn.CrossEntropyLoss().cuda()
+    elif args.smoothing:
         train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
         validate_loss_fn = nn.CrossEntropyLoss().cuda()
     else:
@@ -314,6 +322,13 @@ def train_epoch(
         last_batch = batch_idx == last_idx
         data_time_m.update(time.time() - end)

+        if args.mixup > 0.:
+            lam = 1.
+            if not args.mixup_off_epoch or epoch < args.mixup_off_epoch:
+                lam = np.random.beta(args.mixup, args.mixup)
+            input.mul_(lam).add_(1 - lam, input.flip(0))
+            target = mixup_target(target, args.num_classes, lam, args.smoothing)
+
         output = model(input)

         loss = loss_fn(output, target)

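The train_epoch change mixes each batch with its flipped copy in place (input.flip(0)), so sample i is paired with sample N-1-i and no explicit permutation is needed; lam is drawn from Beta(alpha, alpha) and the same lam mixes the smoothed one-hot targets. A standalone, out-of-place sketch of that transform (the helper name apply_mixup is illustrative; mixup_target is the function added to utils.py in this commit):

```python
import numpy as np
import torch
from utils import mixup_target  # added in utils.py in this commit

def apply_mixup(input, target, num_classes, alpha=0.2, smoothing=0.1):
    # lam ~ Beta(alpha, alpha) blends the batch with its reversed copy
    lam = float(np.random.beta(alpha, alpha))
    mixed_input = input.mul(lam).add(input.flip(0), alpha=1. - lam)
    mixed_target = mixup_target(target, num_classes, lam, smoothing)
    return mixed_input, mixed_target
```

From the command line, mixup is enabled by passing `--mixup <alpha>` (e.g. `--mixup 0.2`) to train.py, and `--mixup-off-epoch N` turns it off after epoch N.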
utils.py

Lines changed: 14 additions & 0 deletions
@@ -5,6 +5,7 @@
 import glob
 import csv
 import operator
+import numpy as np
 from collections import OrderedDict


@@ -139,6 +140,19 @@ def accuracy(output, target, topk=(1,)):
     return res


+def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
+    x = x.long().view(-1, 1)
+    return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)
+
+
+def mixup_target(target, num_classes, lam=1., smoothing=0.0):
+    off_value = smoothing / num_classes
+    on_value = 1. - smoothing + off_value
+    y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value)
+    y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value)
+    return lam*y1 + (1. - lam)*y2
+
+
 def get_outdir(path, *paths, inc=False):
     outdir = os.path.join(path, *paths)
     if not os.path.exists(outdir):

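mixup_target builds label-smoothed one-hot vectors for the batch and for its flipped copy, then blends them with the same lam used on the inputs. A small worked example under the commit's defaults (note: one_hot places tensors on device='cuda', so this assumes a GPU):

```python
import torch
from utils import mixup_target  # as defined above

target = torch.tensor([0, 2], device='cuda')  # two samples, 3 classes
y = mixup_target(target, num_classes=3, lam=0.7, smoothing=0.0)
# row 0: class 0 mixed with class 2 -> [0.7, 0.0, 0.3]
# row 1: class 2 mixed with class 0 -> [0.3, 0.0, 0.7]
```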