@@ -14,6 +14,7 @@ You may obtain a copy of the License at
1414 limitations under the License.
1515******************************************************************************/
1616
17+ using System ;
1718using System . Linq ;
1819using Tensorflow . Contexts ;
1920using static Tensorflow . Binding ;
@@ -448,15 +449,34 @@ public static Tensor size(Tensor input, TF_DataType out_type = TF_DataType.TF_IN
448449 return _op . outputs [ 0 ] ;
449450 }
450451
451- /// <summary>
452- /// Return a slice from 'input'
453- /// </summary>
454- /// <param name="input"></param>
455- /// <param name="begin"></param>
456- /// <param name="size"></param>
457- /// <param name="name"></param>
458- /// <returns></returns>
459- public static Tensor slice ( Tensor input , Tensor begin , Tensor size , string name = null )
452+ public static Tensor slice ( Tensor input , Tensor [ ] begin , Tensor [ ] size , string name = null )
453+ {
454+ if ( tf . executing_eagerly ( ) )
455+ {
456+ var result = slice_eager_fallback ( input , begin , size , name , tf . Context ) ;
457+ return result ;
458+ }
459+
460+ var _op = tf . OpDefLib . _apply_op_helper ( "Slice" , name , new { input , begin , size } ) ;
461+ return _op . outputs [ 0 ] ;
462+ }
463+
464+ private static Tensor slice_eager_fallback ( Tensor inputs , Tensor [ ] begin , Tensor [ ] size , string name , Context ctx )
465+ {
466+ var ( _attr_T , input ) = tf . Runner . ArgsToMatchingEager ( ctx , args : new [ ] { inputs } ) ;
467+ var ( _attr_Tidx , _inputs_Index ) = tf . Runner . ArgsToMatchingEager ( ctx , args : new object [ ] { begin , size } ) ;
468+ var _inputs_flat = input . concat ( _inputs_Index ) ;
469+ var _attrs = new object [ ] { "T" , _attr_T , "Index" , _attr_Tidx } ;
470+
471+ var results = tf . Runner . Execute ( ctx , "Slice" , 1 , _inputs_flat , _attrs , name : name ) ;
472+ if ( tf . Runner . MustRecordGradient ( ) )
473+ {
474+ tf . Runner . RecordGradient ( "Slice" , _inputs_flat , _attrs , results ) ;
475+ }
476+ return results [ 0 ] ;
477+ }
478+
479+ public static Tensor slice < Tb , Ts > ( Tensor input , Tb begin , Ts size , string name = null )
460480 {
461481 var _op = tf . OpDefLib . _apply_op_helper ( "Slice" , name , new { input , begin , size } ) ;
462482 return _op . outputs [ 0 ] ;
@@ -605,12 +625,6 @@ public static Tensor strided_slice_grad(Tensor shape, Tensor begin, Tensor end,
605625 "shrink_axis_mask" , shrink_axis_mask ) . FirstOrDefault ( ) ,
606626 shape , begin , end , strides , dy ) ;
607627
608- public static Tensor slice < Tb , Ts > ( Tensor input , Tb begin , Ts size , string name = null )
609- {
610- var _op = tf . OpDefLib . _apply_op_helper ( "Slice" , name , new { input , begin , size } ) ;
611- return _op . outputs [ 0 ] ;
612- }
613-
614628 /// <summary>
615629 /// Removes dimensions of size 1 from the shape of a tensor.
616630 /// Given a tensor `input`, this operation returns a tensor of the same type with
0 commit comments