In class LossView:
    def __init__(self, loss, batch_scheduler):
        super().__init__()
        self.batch_scheduler = batch_scheduler
        # Gather updates from the optimizer and the batch scheduler.
        graph_updates = OrderedDict()
        graph_updates.update(loss.updates)
        graph_updates.update(batch_scheduler.updates)
        self.compute_loss = theano.function([],
                                            loss.losses,
                                            updates=graph_updates,
                                            givens=batch_scheduler.givens,
                                            name="compute_loss")
loss.updates is evaluated before loss.losses, but it should be the other way around.
Explanation:
loss.losses triggers a call to model.get_output, which updates the model.updates dictionary.
loss.updates should therefore fetch this refreshed dictionary rather than the old precomputed one, which may contain updates specific to another dataset.
As a result, using the model in an experiment after having used several LossView objects on different datasets triggers a MissingInputError on trainset.symb_inputs.
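To illustrate the ordering problem without Theano, here is a minimal pure-Python sketch. ToyModel, ToyLoss and build_updates are hypothetical stand-ins, not the library's real classes. The only behavior carried over from the issue is that evaluating losses calls model.get_output, which replaces model.updates, while updates returns whatever model.updates holds at the moment it is read.

```python
from collections import OrderedDict


class ToyModel:
    """Hypothetical model: get_output() rebuilds self.updates for the given inputs."""

    def __init__(self):
        self.updates = OrderedDict()

    def get_output(self, inputs):
        # Building the graph replaces the updates with ones tied to `inputs`,
        # mimicking updates that reference a specific dataset's symbolic inputs.
        self.updates = OrderedDict([("state", "update for " + inputs)])
        return "output(" + inputs + ")"


class ToyLoss:
    """Hypothetical loss: `losses` builds the graph, `updates` reads model.updates."""

    def __init__(self, model, dataset):
        self.model = model
        self.dataset = dataset

    @property
    def losses(self):
        # Side effect: calls model.get_output, which overwrites model.updates.
        return "loss(" + self.model.get_output(self.dataset) + ")"

    @property
    def updates(self):
        # Snapshot of whatever model.updates holds at the moment of access.
        return OrderedDict(self.model.updates)


def build_updates(loss, losses_first):
    """Mimics how LossView.__init__ gathers graph updates, in either order."""
    graph_updates = OrderedDict()
    if losses_first:
        losses = loss.losses                 # refresh model.updates first...
        graph_updates.update(loss.updates)   # ...then read the fresh dictionary
    else:
        graph_updates.update(loss.updates)   # reads a possibly stale dictionary
        losses = loss.losses                 # refresh happens too late
    return losses, graph_updates


# Original order: the second view inherits the first dataset's updates.
model = ToyModel()
build_updates(ToyLoss(model, "trainset"), losses_first=False)
_, stale = build_updates(ToyLoss(model, "validset"), losses_first=False)
print(stale["state"])  # update for trainset

# Proposed order: each view gets the updates matching its own dataset.
model = ToyModel()
build_updates(ToyLoss(model, "trainset"), losses_first=True)
_, fresh = build_updates(ToyLoss(model, "validset"), losses_first=True)
print(fresh["state"])  # update for validset
```

With the original order, the second view picks up the update built for the first dataset; evaluating losses before reading updates yields the dictionary that matches the current dataset, which is the reordering proposed above.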