Commit 91399b1

Fixed add GradientOperatorMulTest #642

1 parent 05561ea commit 91399b1

4 files changed: +30 -31 lines changed

src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs

Lines changed: 3 additions & 2 deletions

@@ -1,4 +1,5 @@
-using System.Linq;
+using System;
+using System.Linq;
 using Tensorflow.Gradients;
 using static Tensorflow.Binding;
 using static Tensorflow.tensorflow;

@@ -37,7 +38,7 @@ public bool RecordGradient(string op_name,
                }*/
            }

-           // Console.WriteLine($"RecordGradient: should_record={should_record}, op_name={op_name}");
+           Console.WriteLine($"RecordGradient: should_record={should_record}, op_name={op_name}");
            if (!should_record) return should_record;

            Tensor[] op_outputs;

src/TensorFlowNET.Core/Gradients/math_grad.cs

Lines changed: 25 additions & 26 deletions

@@ -212,8 +212,9 @@ public static Tensor[] _MulGrad(Operation op, Tensor[] grads)
                };
            }

-           var (sx, sy) = SmartBroadcastGradientArgs(x, y);
-           var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy);
+           var broads = SmartBroadcastGradientArgs(x, y);
+           var (sx, rx, must_reduce_x) = broads[0];
+           var (sy, ry, must_reduce_y) = broads[1];

            x = math_ops.conj(x);
            y = math_ops.conj(y);

@@ -222,33 +223,21 @@ public static Tensor[] _MulGrad(Operation op, Tensor[] grads)

            if (op is EagerOperation op_eager1 &&
                op_eager1.SkipInputIndices.Contains(0))
-           {
-               return new Tensor[]
-               {
-                   gen_math_ops.mul(grad, math_ops.conj(y)),
-                   null
-               };
-           }
-           // else if not must_reduce_x:
-           //     gx = gen_math_ops.mul(grad, y)
+               gy = null;
+           else if (!must_reduce_x)
+               gx = gen_math_ops.mul(grad, y);
            else
-           {
                gx = array_ops.reshape(
                    math_ops.reduce_sum(gen_math_ops.mul(grad, y), rx), sx);
-           }

            if (op is EagerOperation op_eager2 &&
                op_eager2.SkipInputIndices.Contains(1))
-           {
-
-           }
-           // else if not must_reduce_y:
-           //     gy = gen_math_ops.mul(x, grad)
+               gy = null;
+           else if (!must_reduce_y)
+               gy = gen_math_ops.mul(x, grad);
            else
-           {
                gy = array_ops.reshape(
                    math_ops.reduce_sum(gen_math_ops.mul(x, grad), ry), sy);
-           }

            return new Tensor[] { gx, gy };
        }

@@ -479,8 +468,9 @@ public static Tensor[] _SubGrad(Operation op, Tensor[] grads)
                _ShapesFullySpecifiedAndEqual(x, y, grad))
                return new Tensor[] { grad, -grad };

-           var (sx, sy) = SmartBroadcastGradientArgs(x, y);
-           var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy);
+           var broads = SmartBroadcastGradientArgs(x, y);
+           var (sx, rx, must_reduce_x) = broads[0];
+           var (sy, ry, must_reduce_y) = broads[1];

            var gx = array_ops.reshape(math_ops.reduce_sum(grad, rx), sx);
            var gy = array_ops.reshape(math_ops.reduce_sum(-grad, ry), sy);

@@ -728,8 +718,10 @@ public static Tensor[] _PowGrad(Operation op, Tensor[] grads)

            var z = op.outputs[0];

-           var (sx, sy) = SmartBroadcastGradientArgs(x, y);
-           var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy);
+           var broads = SmartBroadcastGradientArgs(x, y);
+           var (sx, rx, must_reduce_x) = broads[0];
+           var (sy, ry, must_reduce_y) = broads[1];
+
            x = math_ops.conj(x);
            y = math_ops.conj(y);
            z = math_ops.conj(z);

@@ -761,22 +753,29 @@
        /// <param name="x"></param>
        /// <param name="y"></param>
        /// <returns></returns>
-       private static (Tensor, Tensor) SmartBroadcastGradientArgs(Tensor x, Tensor y)
+       private static (Tensor, Tensor, bool)[] SmartBroadcastGradientArgs(Tensor x, Tensor y)
        {
            Tensor sx, sy;
            if (x.TensorShape.is_fully_defined() &&
                y.TensorShape.is_fully_defined())
            {
                sx = array_ops.shape(x);
                sy = array_ops.shape(y);
+
+               var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy);
+               return new[]
+               {
+                   (sx, rx, true),
+                   (sy, ry, true)
+               };
            }
            else
            {
                sx = array_ops.shape_internal(x, optimize: false);
                sy = array_ops.shape_internal(y, optimize: false);
            }

-           return (sx, sy);
+           throw new NotImplementedException("");
        }
    }
}
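
The refactored SmartBroadcastGradientArgs now returns one (shape, reduction_indices, must_reduce) triple per operand, letting callers such as _MulGrad skip the reduce_sum/reshape round-trip when an operand was not broadcast. Two details stand out in this commit: the fully-defined branch always reports must_reduce = true, so the fast path is never taken yet, and the dynamic-shape branch ends in throw new NotImplementedException(""). The standalone sketch below (a simplified illustration, not the library's code) shows the rule that broadcast_gradient_args encodes: an operand's gradient must be summed over every output axis where that operand's left-padded dimension is 1 but the broadcast dimension is larger.

using System;
using System.Collections.Generic;
using System.Linq;

// Simplified, self-contained illustration of the rule behind
// broadcast_gradient_args; not the library's implementation.
static class BroadcastGradientDemo
{
    // For operand shapes sx and sy, return the output axes over which each
    // operand's gradient must be summed before reshaping it back.
    static (int[] rx, int[] ry) BroadcastGradientArgs(int[] sx, int[] sy)
    {
        int n = Math.Max(sx.Length, sy.Length);
        // Left-pad the shorter shape with 1s, as broadcasting does.
        int[] px = Enumerable.Repeat(1, n - sx.Length).Concat(sx).ToArray();
        int[] py = Enumerable.Repeat(1, n - sy.Length).Concat(sy).ToArray();

        var rx = new List<int>();
        var ry = new List<int>();
        for (int axis = 0; axis < n; axis++)
        {
            int outDim = Math.Max(px[axis], py[axis]);
            // A size-1 dimension was stretched to outDim, so the incoming
            // gradient carries outDim copies that must be summed.
            if (px[axis] == 1 && outDim > 1) rx.Add(axis);
            if (py[axis] == 1 && outDim > 1) ry.Add(axis);
        }
        return (rx.ToArray(), ry.ToArray());
    }

    static void Main()
    {
        // x: [2, 3] broadcast against y: [3].
        var (rx, ry) = BroadcastGradientArgs(new[] { 2, 3 }, new[] { 3 });
        Console.WriteLine($"rx = [{string.Join(", ", rx)}]"); // rx = []
        Console.WriteLine($"ry = [{string.Join(", ", ry)}]"); // ry = [0]
    }
}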

src/TensorFlowNET.Keras/Tensorflow.Keras.csproj

Lines changed: 1 addition & 1 deletion

@@ -12,7 +12,7 @@
     <Copyright>Apache 2.0, Haiping Chen 2020</Copyright>
     <PackageId>TensorFlow.Keras</PackageId>
     <PackageProjectUrl>https://github.com/SciSharp/TensorFlow.NET</PackageProjectUrl>
-    <PackageIcon>https://avatars3.githubusercontent.com/u/44989469?s=200&amp;v=4</PackageIcon>
+    <PackageIconUrl>https://avatars3.githubusercontent.com/u/44989469?s=200&amp;v=4</PackageIconUrl>
     <RepositoryUrl>https://github.com/SciSharp/TensorFlow.NET</RepositoryUrl>
     <PackageReleaseNotes>Keras for .NET is a C# version of Keras ported from the python version.</PackageReleaseNotes>
     <Description>Keras for .NET
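
Background on this csproj change: in NuGet, <PackageIcon> expects the package-relative path of an icon file shipped inside the .nupkg, so giving it a URL yields no usable icon, while the older <PackageIconUrl> accepts a remote URL (it has since been deprecated in favor of a packed icon file). A file-based setup would look roughly like the sketch below; icon.png is a hypothetical file name:

<PropertyGroup>
  <!-- PackageIcon names a file packed into the .nupkg, not a URL. -->
  <PackageIcon>icon.png</PackageIcon>
</PropertyGroup>
<ItemGroup>
  <!-- Pack the icon into the package root so NuGet can resolve it. -->
  <None Include="icon.png" Pack="true" PackagePath="" />
</ItemGroup>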

test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs

Lines changed: 1 addition & 2 deletions

@@ -44,8 +44,7 @@ public void GradientOperatorMulTest()
            using var gt = tf.GradientTape();
            var y = x * w;
            var gr = gt.gradient(y, w);
-           Assert.AreNotEqual(null, gr);
+           Assert.AreEqual(new float[] { 0, 0 }, gr.numpy());
        }
-
    }
}
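
The diff shows only the tail of GradientOperatorMulTest; a plausible reconstruction of the whole test is sketched below. The declarations of x and w are assumptions (the commit does not show them): for gr.numpy() to equal { 0, 0 }, x would have to be a zero scalar multiplied against a two-element w, since d(x*w)/dw = x.

// Hypothetical reconstruction; only the last four lines appear in the
// commit, and the declarations of x and w are assumptions.
[TestMethod]
public void GradientOperatorMulTest()
{
    var x = tf.constant(0f);                   // assumed: scalar zero
    var w = tf.Variable(new float[] { 1, 1 }); // assumed: 2-element variable
    using var gt = tf.GradientTape();
    var y = x * w;
    var gr = gt.gradient(y, w);
    // dy/dw = x = 0, broadcast over w's two elements.
    Assert.AreEqual(new float[] { 0, 0 }, gr.numpy());
}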
