From 758f3631f5cf4a1bd7fc6213f1eac3b35a48fb43 Mon Sep 17 00:00:00 2001
From: Eduardo Patrocinio
Date: Wed, 26 Nov 2025 19:25:41 -0500
Subject: [PATCH] Fix pre/post-training evaluation to use same batch in
 nn_tutorial

The tutorial was comparing loss on different batches:
- Pre-training: evaluated on the first 64 instances (batch 0)
- Post-training: evaluated on the last batch from the training loop

This made the comparison misleading, since it did not measure
improvement on the same data.

Changes:
- Save the initial batch (xb_initial, yb_initial) after the first
  evaluation
- Use the saved initial batch for the post-training evaluation
- Add a clarifying comment about the fair comparison
- Both evaluations now use the same data (the first 64 training
  instances)

This provides an accurate before/after comparison showing the model's
improvement on the same batch of data.
---
 beginner_source/nn_tutorial.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/beginner_source/nn_tutorial.py b/beginner_source/nn_tutorial.py
index e04815bd27e..4c0537101a7 100644
--- a/beginner_source/nn_tutorial.py
+++ b/beginner_source/nn_tutorial.py
@@ -174,6 +174,10 @@ def nll(input, target):
 
 yb = y_train[0:bs]
 print(loss_func(preds, yb))
 
+# Save the first batch for comparison after training
+xb_initial = xb
+yb_initial = yb
+
 ###############################################################################
 # Let's also implement a function to calculate the accuracy of our model.
@@ -244,9 +248,10 @@ def accuracy(out, yb):
 #
 # Let's check the loss and accuracy and compare those to what we got
 # earlier. We expect that the loss will have decreased and accuracy to
-# have increased, and they have.
+# have increased, and they have. We evaluate on the same initial batch
+# we used before training for a fair comparison.
 
-print(loss_func(model(xb), yb), accuracy(model(xb), yb))
+print(loss_func(model(xb_initial), yb_initial), accuracy(model(xb_initial), yb_initial))
 
 ###############################################################################
 # Using ``torch.nn.functional``
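
Note for reviewers (not part of the patch): a self-contained sketch of the
save-then-reevaluate pattern this change enforces. The random data and the
model/loss_func/accuracy stand-ins below are assumptions that mirror the
tutorial's definitions, not the patched file itself.

import math
import torch

# Toy stand-ins for the tutorial's MNIST setup (illustrative only).
bs = 64
x_train = torch.randn(1000, 784)
y_train = torch.randint(0, 10, (1000,))

weights = torch.randn(784, 10) / math.sqrt(784)
weights.requires_grad_()
bias = torch.zeros(10, requires_grad=True)

def model(xb):
    return (xb @ weights + bias).log_softmax(-1)

def loss_func(input, target):  # negative log-likelihood, as in the tutorial
    return -input[range(target.shape[0]), target].mean()

def accuracy(out, yb):
    return (torch.argmax(out, dim=1) == yb).float().mean()

# Save the first batch once, before any training.
xb_initial = x_train[0:bs]
yb_initial = y_train[0:bs]
print("before:", loss_func(model(xb_initial), yb_initial),
      accuracy(model(xb_initial), yb_initial))

# ... the tutorial's training loop runs here, rebinding xb/yb each step ...

# Evaluate on the saved batch, not on whatever xb/yb last pointed to.
print("after:", loss_func(model(xb_initial), yb_initial),
      accuracy(model(xb_initial), yb_initial))

Aliasing the tensors (xb_initial = xb in the patch) is safe because the
training loop rebinds xb and yb to new slices each step rather than mutating
them in place, so the saved references keep the original batch.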