
Using multiple losses/datasets on the same model grows the update dictionary #71

@ppoulin91

Description


In class Loss:

@property
def losses(self):
    """ Gets individual losses (one for each batch example). """
    if self._losses is None:
        model_output = self.model.get_output(self.dataset.symb_inputs)
        self._losses = self._compute_losses(model_output)

    return self._losses

Bug example:
# Get train losses
train_loss = DummyLoss(model, trainset)
train_losses = train_loss.losses

# Get valid losses
valid_loss = DummyLoss(model, validset)
valid_losses = valid_loss.losses

The call to train_loss.losses triggers model.get_output, using trainset.symb_inputs.
The model's updates dictionary now contains elements specific to trainset.symb_inputs.

The call to valid_loss.losses triggers model.get_output again, this time using validset.symb_inputs.
The model extends its updates dictionary, which now contains elements specific to both trainset.symb_inputs and validset.symb_inputs.
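The mechanism can be reproduced without Theano. The sketch below uses hypothetical stand-ins (DummyDataset, DummyModel, and plain strings in place of symbolic variables), not the project's real classes. It assumes get_output extends model.updates with dict.update(), which is what the growth described above implies, but that detail is not taken from the actual model code.

```python
# Hypothetical stand-ins: the real model/dataset/Theano classes are not reproduced here.
class DummyDataset:
    def __init__(self, name):
        # A plain string stands in for a symbolic input variable.
        self.symb_inputs = name + "_inputs"


class DummyModel:
    def __init__(self):
        self.updates = {}

    def get_output(self, inputs):
        # Assumption: each call registers an update tied to the given inputs
        # and extends (rather than replaces) the shared dictionary.
        self.updates.update({"state_for_" + inputs: inputs})
        return "output(" + inputs + ")"


class DummyLoss:
    def __init__(self, model, dataset):
        self.model = model
        self.dataset = dataset
        self._losses = None

    def _compute_losses(self, model_output):
        return "losses(" + model_output + ")"

    @property
    def losses(self):
        """Lazily builds and caches the per-example losses."""
        if self._losses is None:
            model_output = self.model.get_output(self.dataset.symb_inputs)
            self._losses = self._compute_losses(model_output)
        return self._losses


model = DummyModel()

train_loss = DummyLoss(model, DummyDataset("train"))
train_losses = train_loss.losses
print(sorted(model.updates))  # ['state_for_train_inputs']

valid_loss = DummyLoss(model, DummyDataset("valid"))
valid_losses = valid_loss.losses
print(sorted(model.updates))  # ['state_for_train_inputs', 'state_for_valid_inputs']
```

Both losses share one model, so the second lazy evaluation leaves the train-specific entry in place next to the valid-specific one.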

The easy fix would be to force models to replace the updates dictionary instead of extending it. However, it would be preferable to find a better way of managing the updates and losses.
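A minimal sketch contrasting the two options, again with hypothetical stand-in classes (ReplacingModel, StatelessModel, StatelessLoss) rather than the project's real API. The first replaces the dictionary on every call. The second is only one possible "better way": get_output returns its updates instead of storing them on the model, so each loss keeps the updates for its own inputs.

```python
class ReplacingModel:
    """Easy fix: each call replaces the updates dictionary."""

    def __init__(self):
        self.updates = {}

    def get_output(self, inputs):
        # Assignment instead of dict.update(): only the latest inputs remain.
        self.updates = {"state_for_" + inputs: inputs}
        return "output(" + inputs + ")"


class StatelessModel:
    """Alternative (hypothetical): return the updates instead of storing them."""

    def get_output(self, inputs):
        updates = {"state_for_" + inputs: inputs}
        return "output(" + inputs + ")", updates


class StatelessLoss:
    def __init__(self, model, symb_inputs):
        self.model = model
        self.symb_inputs = symb_inputs
        self._losses = None
        self._updates = None

    @property
    def losses(self):
        if self._losses is None:
            model_output, self._updates = self.model.get_output(self.symb_inputs)
            self._losses = "losses(" + model_output + ")"
        return self._losses

    @property
    def updates(self):
        # Building the losses also fills in the updates for these inputs only.
        _ = self.losses
        return self._updates


replacing = ReplacingModel()
replacing.get_output("train_inputs")
replacing.get_output("valid_inputs")
print(replacing.updates)  # {'state_for_valid_inputs': 'valid_inputs'}

stateless = StatelessModel()
train_loss = StatelessLoss(stateless, "train_inputs")
valid_loss = StatelessLoss(stateless, "valid_inputs")
print(train_loss.updates)  # {'state_for_train_inputs': 'train_inputs'}
print(valid_loss.updates)  # {'state_for_valid_inputs': 'valid_inputs'}
```

The first approach has a trade-off: model.updates only reflects the most recent get_output call, so the train-specific updates are gone once valid_loss.losses has been evaluated. Returning the updates keeps them separated per loss, at the cost of changing the get_output signature.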
