Add missing reset for extra_loss_arg tensor list to preserve reference to user data. #97
This PR fixes an issue in handling the optional `extra_loss_args` arguments to the TorchFort supervised training function. There was a missing call to the `reset` routine, which is needed to preserve references to the user data in cases where the TorchFort backend has to migrate the user data from CPU to GPU or GPU to CPU for a training step. Without this reset, user changes to the arrays passed to `torchfort_tensor_list_add_tensor` for the extra loss arguments list would not propagate after the first training step. Workloads where the model and extra loss args data are already present on the same device are not impacted by this.

I've adjusted the supervised training tests to better cover this scenario.
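To make the failure mode concrete, here is a minimal Python sketch of the stale-reference behavior described above. It is a toy model only, not TorchFort code: `ExtraArgsList`, `migrate`, `train_step`, and `second_step_value` are hypothetical names, plain Python lists stand in for tensors, and a list copy stands in for a CPU/GPU transfer.

```python
class ExtraArgsList:
    """Toy stand-in for a tensor list that references user-owned buffers."""

    def __init__(self):
        self._user_buffers = []  # references to the caller's arrays
        self._working = []       # what a training step actually reads

    def add_tensor(self, buf):
        # Store a reference (not a copy) to the user's buffer.
        self._user_buffers.append(buf)
        self._working.append(buf)

    def migrate(self):
        # Simulated device transfer: references are replaced by copies.
        self._working = [list(b) for b in self._working]

    def reset(self):
        # Point the working entries back at the user's buffers.
        self._working = list(self._user_buffers)


def train_step(args, needs_migration, do_reset):
    """Return a snapshot of the values this step observed."""
    if needs_migration:
        args.migrate()
    seen = [list(t) for t in args._working]
    if do_reset:
        args.reset()
    return seen


def second_step_value(needs_migration, do_reset):
    """Run two steps, updating the user array in place between them."""
    user_data = [1.0]
    args = ExtraArgsList()
    args.add_tensor(user_data)
    train_step(args, needs_migration, do_reset)
    user_data[0] = 2.0  # user modifies the array after the first step
    return train_step(args, needs_migration, do_reset)[0][0]


if __name__ == "__main__":
    print("migration, no reset:", second_step_value(True, False))
    print("migration, reset:   ", second_step_value(True, True))
    print("same device:        ", second_step_value(False, False))
```

In this toy model, only the combination of a migration and no reset leaves the second step reading the old value, which mirrors why same-device workloads were unaffected.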