Skip to content

Commit e6c1442

Browse files
committed
More appropriate/correct loss name
1 parent 99ab1b1 commit e6c1442

File tree

3 files changed

+5
-5
lines changed

3 files changed

+5
-5
lines changed

loss/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from loss.cross_entropy import LabelSmoothingCrossEntropy, SparseLabelCrossEntropy
1+
from loss.cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy

loss/cross_entropy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@ def forward(self, x, target):
2626
return loss.mean()
2727

2828

29-
class SparseLabelCrossEntropy(nn.Module):
29+
class SoftTargetCrossEntropy(nn.Module):
3030

3131
def __init__(self):
32-
super(SparseLabelCrossEntropy, self).__init__()
32+
super(SoftTargetCrossEntropy, self).__init__()
3333

3434
def forward(self, x, target):
3535
loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1)

train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target
1414
from models import create_model, resume_checkpoint
1515
from utils import *
16-
from loss import LabelSmoothingCrossEntropy, SparseLabelCrossEntropy
16+
from loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
1717
from optim import create_optimizer
1818
from scheduler import create_scheduler
1919

@@ -261,7 +261,7 @@ def main():
261261

262262
if args.mixup > 0.:
263263
# smoothing is handled with mixup label transform
264-
train_loss_fn = SparseLabelCrossEntropy().cuda()
264+
train_loss_fn = SoftTargetCrossEntropy().cuda()
265265
validate_loss_fn = nn.CrossEntropyLoss().cuda()
266266
elif args.smoothing:
267267
train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()

0 commit comments

Comments
 (0)