@@ -16,6 +16,7 @@ limitations under the License.
 
 using System.Collections.Generic;
 using System.Linq;
+using Tensorflow.Eager;
 using Tensorflow.Framework;
 using static Tensorflow.Binding;
 
@@ -82,7 +83,14 @@ private static Tensor[] _ConcatGradHelper(Operation op, Tensor grad, int start_v
                 .ToArray();
 
             var out_grads = new List<Tensor>();
-            if (constant_op.is_constant(concat_dim))
+            if (concat_dim is EagerTensor)
+            {
+                var non_neg_concat_dim = (int)concat_dim % input_values[0].rank;
+                var sizes = input_values.Select(x => x.shape[non_neg_concat_dim]).ToArray();
+                var sizes_tensor = constant_op.constant(sizes);
+                out_grads = gen_array_ops.split_v(grad, sizes_tensor, sizes[0], non_neg_concat_dim).ToList();
+            }
+            else if (constant_op.is_constant(concat_dim))
             {
                 /*If concat_dim is a constant defined in a different context,
                 then we duplicate it in the current context to avoid passing it
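
The new eager branch above normalizes `concat_dim` with `%` so an axis in the allowed `[-rank, rank)` range maps onto a valid dimension index. A minimal standalone sketch of that normalization, assuming a plain `int` axis (hypothetical `NormalizeAxis` helper, not part of this diff); note that C#'s `%` keeps the sign of the dividend, so the fully defensive form adds `rank` before reducing again:

    // Map an axis in [-rank, rank) to a non-negative index in [0, rank).
    // C#'s % returns a negative remainder for a negative axis, so add
    // rank and take the remainder once more to stay non-negative.
    static int NormalizeAxis(int axis, int rank)
        => ((axis % rank) + rank) % rank;

    // NormalizeAxis(-1, 3) == 2 (the last dimension); NormalizeAxis(1, 3) == 1.
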
@@ -97,33 +105,33 @@ through an Enter node.
                     var value = tensor_util.constant_value(concat_dim);
                     concat_dim = constant_op.constant(value: value, dtype: concat_dim.dtype);
                 }
-            }
 
-            // Using mod here for convenience since concat_dim is already verified
-            // in concat implementation to be within the allowed [-rank, rank) range.
-            var non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0]);
+                // Using mod here for convenience since concat_dim is already verified
+                // in concat implementation to be within the allowed [-rank, rank) range.
+                var non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0]);
 
-            // Get the inputs' tensor shapes
-            var sizes = _ExtractInputShapes(input_values);
+                // Get the inputs' tensor shapes
+                var sizes = _ExtractInputShapes(input_values);
 
-            /* The magic number of 16 was found through benchmarking a range of sizes
-            on CPUs and a Maxwell TitanX. A speedup was seen in a large majority of
-            cases when switching implementations at N=16, but it is possible that
-            there will be a small number of performance regressions.*/
-            if (len(sizes) > 16)
-            {
-                // extract the size of each input along the concat dimension
-                var slice = array_ops.slice(array_ops.stack(sizes, axis: 1),
-                    new Tensor[] { non_neg_concat_dim, tf.constant(0) },
-                    new Tensor[] { tf.constant(1), tf.constant(-1) });
-                var squeeze_sizes = array_ops.squeeze(slice);
-                out_grads = array_ops.split(axis: grad, value: squeeze_sizes, num_split: (int)non_neg_concat_dim).ToList();
-            }
-            else
-            {
-                var offset = gen_array_ops.concat_offset(non_neg_concat_dim, sizes);
-                foreach (var (begin, size) in zip(offset, sizes))
-                    out_grads.Add(gen_array_ops.slice(grad, begin, size));
+                /* The magic number of 16 was found through benchmarking a range of sizes
+                on CPUs and a Maxwell TitanX. A speedup was seen in a large majority of
+                cases when switching implementations at N=16, but it is possible that
+                there will be a small number of performance regressions.*/
+                if (len(sizes) > 16)
+                {
+                    // extract the size of each input along the concat dimension
+                    var slice = array_ops.slice(array_ops.stack(sizes, axis: 1),
+                        new Tensor[] { non_neg_concat_dim, tf.constant(0) },
+                        new Tensor[] { tf.constant(1), tf.constant(-1) });
+                    var squeeze_sizes = array_ops.squeeze(slice);
+                    out_grads = array_ops.split(axis: grad, value: squeeze_sizes, num_split: (int)non_neg_concat_dim).ToList();
+                }
+                else
+                {
+                    var offset = gen_array_ops.concat_offset(non_neg_concat_dim, sizes);
+                    foreach (var (begin, size) in zip(offset, sizes))
+                        out_grads.Add(gen_array_ops.slice(grad, begin, size));
+                }
             }
 
             return (end_value_index <= dim_index ?
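
The gradient of concat is its inverse: `grad` is cut back into one piece per input along the concat axis, either with a single split (more than 16 inputs) or with one slice per input at the begin positions produced by `concat_offset`. A rough sketch of that offset arithmetic in plain C# (hypothetical `ConcatOffsets` helper, not the library op), assuming static extents along the axis:

    // Given each input's extent along the concat axis, compute the begin
    // offset of that input's slice inside the concatenated tensor.
    static int[] ConcatOffsets(int[] sizes)
    {
        var offsets = new int[sizes.Length];
        for (int i = 1; i < sizes.Length; i++)
            offsets[i] = offsets[i - 1] + sizes[i - 1];
        return offsets;
    }

    // Example: inputs of width 2, 3 and 4 give offsets { 0, 2, 5 };
    // slicing grad at those begins with the matching sizes recovers
    // the per-input gradients.

Per the comment in the diff, the crossover at N=16 came from benchmarking on CPUs and a Maxwell TitanX: past that point, one stacked-shape slice plus a single split beats issuing a slice op per input.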