Skip to content

Commit 1623bb8

Browse files
committed
fix _assign_new_value for BatchNormaliztion.
1 parent ae9a161 commit 1623bb8

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

src/TensorFlowNET.Core/Gradients/TapeTensor.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,8 @@ public Tensor ZerosLike()
2222

2323
public Tensor OnesLike()
2424
=> tf.ones(shape: shape, dtype: dtype);
25+
26+
public override string ToString()
27+
=> $"{id}, {shape}, {dtype.as_numpy_name()}";
2528
}
2629
}

src/TensorFlowNET.Keras/Layers/BatchNormalization.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ private Tensor _fused_batch_norm(Tensor inputs, Tensor training)
196196
_assign_moving_average(moving_variance, variance, momentum_tensor);
197197

198198
if (use_fused_avg_updates)
199-
_assign_new_value(moving_variance, mean);
199+
_assign_new_value(moving_variance, variance);
200200
else
201201
_assign_moving_average(moving_variance, variance, momentum_tensor);
202202

0 commit comments

Comments
 (0)