@@ -212,8 +212,9 @@ public static Tensor[] _MulGrad(Operation op, Tensor[] grads)
                 };
             }

-            var (sx, sy) = SmartBroadcastGradientArgs(x, y);
-            var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy);
+            var broads = SmartBroadcastGradientArgs(x, y);
+            var (sx, rx, must_reduce_x) = broads[0];
+            var (sy, ry, must_reduce_y) = broads[1];

             x = math_ops.conj(x);
             y = math_ops.conj(y);
@@ -222,33 +223,21 @@ public static Tensor[] _MulGrad(Operation op, Tensor[] grads)

             if (op is EagerOperation op_eager1 &&
                 op_eager1.SkipInputIndices.Contains(0))
-            {
-                return new Tensor[]
-                {
-                    gen_math_ops.mul(grad, math_ops.conj(y)),
-                    null
-                };
-            }
-            // else if not must_reduce_x:
-            //     gx = gen_math_ops.mul(grad, y)
+                gx = null;
+            else if (!must_reduce_x)
+                gx = gen_math_ops.mul(grad, y);
             else
-            {
                 gx = array_ops.reshape(
                     math_ops.reduce_sum(gen_math_ops.mul(grad, y), rx), sx);
-            }

             if (op is EagerOperation op_eager2 &&
                 op_eager2.SkipInputIndices.Contains(1))
-            {
-
-            }
-            // else if not must_reduce_y:
-            //     gy = gen_math_ops.mul(x, grad)
+                gy = null;
+            else if (!must_reduce_y)
+                gy = gen_math_ops.mul(x, grad);
             else
-            {
                 gy = array_ops.reshape(
                     math_ops.reduce_sum(gen_math_ops.mul(x, grad), ry), sy);
-            }

             return new Tensor[] { gx, gy };
         }
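Context for the hunk above: for z = x * y, the gradient with respect to x is grad * y summed over every axis along which x was broadcast, and the new must_reduce flags let the reduce_sum/reshape round-trip be skipped when no broadcasting occurred. A minimal sketch of the broadcast case, assuming the TensorFlow.NET eager API (tf.constant, tf.GradientTape, tf.reduce_sum) and illustrative shapes:

    using Tensorflow;
    using static Tensorflow.Binding;

    // x is broadcast along axis 1 of y, so dz/dx must be summed back to [3, 1].
    var x = tf.constant(new float[,] { { 1f }, { 2f }, { 3f } });       // shape [3, 1]
    var y = tf.constant(new float[,] { { 1f, 2f, 3f, 4f },
                                       { 5f, 6f, 7f, 8f },
                                       { 9f, 10f, 11f, 12f } });        // shape [3, 4]

    using var tape = tf.GradientTape();
    tape.watch(x);
    var loss = tf.reduce_sum(x * y);    // broadcast multiply, then scalar loss
    var gx = tape.gradient(loss, x);    // shape [3, 1]: grad * y reduced over axis 1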
@@ -479,8 +468,9 @@ public static Tensor[] _SubGrad(Operation op, Tensor[] grads)
                 _ShapesFullySpecifiedAndEqual(x, y, grad))
                 return new Tensor[] { grad, -grad };

-            var (sx, sy) = SmartBroadcastGradientArgs(x, y);
-            var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy);
+            var broads = SmartBroadcastGradientArgs(x, y);
+            var (sx, rx, must_reduce_x) = broads[0];
+            var (sy, ry, must_reduce_y) = broads[1];

             var gx = array_ops.reshape(math_ops.reduce_sum(grad, rx), sx);
             var gy = array_ops.reshape(math_ops.reduce_sum(-grad, ry), sy);
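The reduction axes consumed here come from gen_array_ops.broadcast_gradient_args, which, given the two shape tensors, returns for each input the axes along which it was broadcast and over which its gradient must therefore be summed. A worked illustration, with example shapes assumed:

    // broadcast_gradient_args(sx, sy) -> (rx, ry):
    //
    //   sx = [3, 1], sy = [3, 4]  ->  rx = [1],    ry = []
    //   sx = [1],    sy = [2, 3]  ->  rx = [0, 1], ry = []
    //
    // _SubGrad then folds the upstream gradient back to each input's shape:
    //   gx = reshape(reduce_sum(grad, rx), sx)     // d(x - y)/dx =  1
    //   gy = reshape(reduce_sum(-grad, ry), sy)    // d(x - y)/dy = -1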
@@ -728,8 +718,10 @@ public static Tensor[] _PowGrad(Operation op, Tensor[] grads)

             var z = op.outputs[0];

-            var (sx, sy) = SmartBroadcastGradientArgs(x, y);
-            var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy);
+            var broads = SmartBroadcastGradientArgs(x, y);
+            var (sx, rx, must_reduce_x) = broads[0];
+            var (sy, ry, must_reduce_y) = broads[1];
+
             x = math_ops.conj(x);
             y = math_ops.conj(y);
             z = math_ops.conj(z);
@@ -761,22 +753,29 @@ public static Tensor[] _PowGrad(Operation op, Tensor[] grads)
         /// <param name="x">First operand of the broadcasting op.</param>
         /// <param name="y">Second operand of the broadcasting op.</param>
         /// <returns>Per input: its shape tensor, the axes its gradient must be summed over, and whether that reduction is needed.</returns>
-        private static (Tensor, Tensor) SmartBroadcastGradientArgs(Tensor x, Tensor y)
+        private static (Tensor, Tensor, bool)[] SmartBroadcastGradientArgs(Tensor x, Tensor y)
         {
             Tensor sx, sy;
             if (x.TensorShape.is_fully_defined() &&
                 y.TensorShape.is_fully_defined())
             {
                 sx = array_ops.shape(x);
                 sy = array_ops.shape(y);
             }
             else
             {
                 sx = array_ops.shape_internal(x, optimize: false);
                 sy = array_ops.shape_internal(y, optimize: false);
             }

-            return (sx, sy);
+            // Both branches fall through to a single broadcast_gradient_args call.
+            // This revision always reports that a reduction is required.
+            var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy);
+            return new[]
+            {
+                (sx, rx, true),
+                (sy, ry, true)
+            };
         }
     }
 }
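Taken together, SmartBroadcastGradientArgs now returns one (shape, reduction axes, must_reduce) triple per input instead of a bare pair of shape tensors. Note that in this revision the flag is always true, so the new fast paths in _MulGrad cannot fire yet; presumably a follow-up change computes it from the statically known shapes. A sketch of how a gradient function consumes the new contract, mirroring the _MulGrad hunk above:

    var broads = SmartBroadcastGradientArgs(x, y);
    var (sx, rx, must_reduce_x) = broads[0];   // x: shape, axes to sum over, flag
    var (sy, ry, must_reduce_y) = broads[1];   // y: likewise

    // When the flag is false, the reduce_sum/reshape round-trip is skipped:
    var gx = must_reduce_x
        ? array_ops.reshape(math_ops.reduce_sum(gen_math_ops.mul(grad, y), rx), sx)
        : gen_math_ops.mul(grad, y);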