@@ -13,7 +13,7 @@
 from data import *
 from models import create_model, resume_checkpoint
 from utils import *
-from loss import LabelSmoothingCrossEntropy
+from loss import LabelSmoothingCrossEntropy, SparseLabelCrossEntropy
 from optim import create_optimizer
 from scheduler import create_scheduler
 
@@ -79,6 +79,10 @@
                     help='SGD momentum (default: 0.9)')
 parser.add_argument('--weight-decay', type=float, default=0.0001,
                     help='weight decay (default: 0.0001)')
+parser.add_argument('--mixup', type=float, default=0.0,
+                    help='mixup alpha, mixup enabled if > 0. (default: 0.)')
+parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
+                    help='turn off mixup after this epoch, disabled if 0 (default: 0)')
 parser.add_argument('--smoothing', type=float, default=0.1,
                     help='label smoothing (default: 0.1)')
 parser.add_argument('--bn-tf', action='store_true', default=False,
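
A usage note on the new options: mixup only activates when --mixup is set above zero, and a positive --mixup-off-epoch stops the mixing from that epoch onward while the smoothed targets remain in effect. A hedged example invocation (the script name and the elided required arguments are assumptions; only the flags shown are introduced or used by this change):

    python train.py ... --mixup 0.2 --mixup-off-epoch 80 --smoothing 0.1
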
@@ -246,7 +250,11 @@ def main():
         distributed=args.distributed,
     )
 
-    if args.smoothing:
+    if args.mixup > 0.:
+        # smoothing is handled with mixup label transform
+        train_loss_fn = SparseLabelCrossEntropy().cuda()
+        validate_loss_fn = nn.CrossEntropyLoss().cuda()
+    elif args.smoothing:
         train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
         validate_loss_fn = nn.CrossEntropyLoss().cuda()
     else:
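
SparseLabelCrossEntropy is defined in loss.py, which is outside this diff. As a rough sketch of what a loss consuming the dense mixed targets could look like (the class name below is hypothetical and the real implementation may differ), a cross entropy over soft target distributions can be written as:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class SoftTargetCrossEntropySketch(nn.Module):
        # Cross entropy against a dense (N, C) target distribution. Plain
        # nn.CrossEntropyLoss with integer class-index targets cannot consume
        # the mixed one-hot labels produced by mixup, hence the separate loss
        # whenever args.mixup > 0.
        def forward(self, x, target):
            # x: (N, C) logits, target: (N, C) soft labels summing to 1 per row
            loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1)
            return loss.mean()

Validation keeps nn.CrossEntropyLoss because validation targets stay as plain class indices.
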
@@ -314,6 +322,13 @@ def train_epoch(
         last_batch = batch_idx == last_idx
         data_time_m.update(time.time() - end)
 
+        if args.mixup > 0.:
+            lam = 1.
+            if not args.mixup_off_epoch or epoch < args.mixup_off_epoch:
+                lam = np.random.beta(args.mixup, args.mixup)
+            input = input.mul(lam).add_(input.flip(0).mul_(1. - lam))
+            target = mixup_target(target, args.num_classes, lam, args.smoothing)
+
         output = model(input)
 
         loss = loss_fn(output, target)
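
mixup_target comes in through the wildcard import from utils and is also not shown in this diff. A minimal sketch under the assumption that it builds label-smoothed one-hot vectors and blends them with the flip-paired sample using the same lam that blends the inputs (the function name below is illustrative):

    import torch

    def mixup_target_sketch(target, num_classes, lam=1., smoothing=0.0):
        # label-smoothed one-hot encoding of the integer class targets
        off_value = smoothing / num_classes
        on_value = 1. - smoothing + off_value
        y = torch.full((target.size(0), num_classes), off_value, device=target.device)
        y.scatter_(1, target.unsqueeze(1), on_value)
        # blend each row with its flip-paired row, mirroring input.flip(0) above
        return y * lam + y.flip(0) * (1. - lam)

With lam sampled from Beta(alpha, alpha), pairing by batch flip means sample i is always mixed with sample N-1-i, which is cheap and avoids drawing a separate permutation.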