Description
In class Loss:
```python
@property
def losses(self):
    """ Gets individual losses (one for each batch example). """
    if self._losses is None:
        model_output = self.model.get_output(self.dataset.symb_inputs)
        self._losses = self._compute_losses(model_output)
    return self._losses
```
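The property is lazy. `model.get_output` runs only on the first access to `losses` for a given `Loss` instance, and the result is cached in `_losses`. Each new `Loss` object therefore calls `get_output` again on the shared model. The sketch below mimics that pattern in plain Python. It assumes a hypothetical `CountingModel` that only counts calls and a hypothetical `Dataset` holding `symb_inputs`. Strings stand in for the Theano symbolic inputs, and `_compute_losses` is a placeholder.

```python
class CountingModel(object):
    """Hypothetical stand-in for a model; counts get_output calls."""

    def __init__(self):
        self.num_calls = 0

    def get_output(self, inputs):
        self.num_calls += 1
        return "output(" + inputs + ")"


class Dataset(object):
    """Hypothetical dataset exposing only symb_inputs."""

    def __init__(self, symb_inputs):
        self.symb_inputs = symb_inputs


class Loss(object):
    def __init__(self, model, dataset):
        self.model = model
        self.dataset = dataset
        self._losses = None

    def _compute_losses(self, model_output):
        # Placeholder: a real subclass would build a symbolic loss here.
        return "losses[" + model_output + "]"

    @property
    def losses(self):
        """Gets individual losses (one for each batch example)."""
        if self._losses is None:
            model_output = self.model.get_output(self.dataset.symb_inputs)
            self._losses = self._compute_losses(model_output)
        return self._losses


model = CountingModel()
train_loss = Loss(model, Dataset("train_x"))
train_loss.losses
train_loss.losses  # cached: get_output is not called a second time
print(model.num_calls)  # 1

valid_loss = Loss(model, Dataset("valid_x"))
valid_loss.losses  # new Loss instance: get_output runs again on the same model
print(model.num_calls)  # 2
```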
Bug example:
```python
# Get train losses
train_loss = DummyLoss(model, trainset)
train_losses = train_loss.losses

# Get valid losses
valid_loss = DummyLoss(model, validset)
valid_losses = valid_loss.losses
```
1. Accessing `train_loss.losses` triggers `model.get_output`, using `trainset.symb_inputs`.
2. The model's `updates` dictionary now contains entries specific to `trainset.symb_inputs`.
3. Accessing `valid_loss.losses` triggers `model.get_output` again, this time using `validset.symb_inputs`.
4. The model updates its `updates` dictionary in place, so it now contains entries specific to both `trainset.symb_inputs` and `validset.symb_inputs`.
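The accumulation in these steps can be reproduced without Theano. This is a minimal sketch that assumes a hypothetical `Model` whose `get_output` merges input-specific entries into a persistent `updates` dictionary with `dict.update`. In real Theano code, the keys would be shared variables and the values their update expressions. Here, strings stand in for both.

```python
from collections import OrderedDict


class Model(object):
    """Hypothetical model whose get_output merges into a shared updates dict."""

    def __init__(self):
        self.updates = OrderedDict()

    def get_output(self, inputs):
        # Hypothetical entry tied to these particular inputs.
        new_updates = OrderedDict([("state_for_" + inputs, "f(" + inputs + ")")])
        # Behaviour being illustrated: old entries are kept, new ones are added.
        self.updates.update(new_updates)
        return "output(" + inputs + ")"


model = Model()
model.get_output("trainset.symb_inputs")
print(list(model.updates))
# ['state_for_trainset.symb_inputs']

model.get_output("validset.symb_inputs")
print(list(model.updates))
# ['state_for_trainset.symb_inputs', 'state_for_validset.symb_inputs']
```

Any code that later reads `model.updates`, for example to compile the training function, sees the entries built for both datasets, not just the one it was meant for.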
The easy fix would be to force models to replace the `updates` dictionary on each call to `get_output` instead of extending it. However, it would be preferable to find a better way of managing the updates and the losses.
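Both options can be compared in a short sketch. Neither is presented as the project's actual design. `ReplacingModel` implements the easy fix: each `get_output` call overwrites `updates`. `StatelessModel` shows one possible alternative. It returns the updates together with the output, so each `Loss` keeps the updates built for its own inputs and the model holds no shared mutable state. All class names are hypothetical. For brevity, this `Loss` takes `symb_inputs` directly rather than a dataset.

```python
from collections import OrderedDict


def _updates_for(inputs):
    # Hypothetical input-specific update entries.
    return OrderedDict([("state_for_" + inputs, "f(" + inputs + ")")])


class ReplacingModel(object):
    """Easy fix: each get_output call replaces the updates dict."""

    def __init__(self):
        self.updates = OrderedDict()

    def get_output(self, inputs):
        self.updates = _updates_for(inputs)  # entries from earlier calls are dropped
        return "output(" + inputs + ")"


class StatelessModel(object):
    """Alternative: return updates alongside the output instead of storing them."""

    def get_output(self, inputs):
        return "output(" + inputs + ")", _updates_for(inputs)


class Loss(object):
    """Loss that keeps the updates produced for its own inputs."""

    def __init__(self, model, symb_inputs):
        self.model = model
        self.symb_inputs = symb_inputs
        self._losses = None
        self._updates = None

    def _build(self):
        output, self._updates = self.model.get_output(self.symb_inputs)
        self._losses = "losses[" + output + "]"  # placeholder loss expression

    @property
    def losses(self):
        if self._losses is None:
            self._build()
        return self._losses

    @property
    def updates(self):
        if self._updates is None:
            self._build()
        return self._updates


replacing = ReplacingModel()
replacing.get_output("train")
replacing.get_output("valid")
print(list(replacing.updates))  # ['state_for_valid']

stateless = StatelessModel()
train_loss = Loss(stateless, "train")
valid_loss = Loss(stateless, "valid")
print(list(train_loss.updates))  # ['state_for_train']
print(list(valid_loss.updates))  # ['state_for_valid']
```

The easy fix stops the mixing, but it makes call order significant. After `valid_loss.losses` is built, the train-set entries are gone from `model.updates`. Tying the updates to the `Loss` that requested them avoids that dependency.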