diff --git a/beginner_source/basics/optimization_tutorial.py b/beginner_source/basics/optimization_tutorial.py index 82bfaa8f07..63b1e1eff5 100644 --- a/beginner_source/basics/optimization_tutorial.py +++ b/beginner_source/basics/optimization_tutorial.py @@ -158,9 +158,9 @@ def train_loop(dataloader, model, loss_fn, optimizer): loss = loss_fn(pred, y) # Backpropagation + optimizer.zero_grad() loss.backward() optimizer.step() - optimizer.zero_grad() if batch % 100 == 0: loss, current = loss.item(), batch * batch_size + len(X)