Skip to content

Commit b59f5a7

Browse files
committed
fix array_ops.slice #646
1 parent 630edea commit b59f5a7

File tree

2 files changed

+32
-15
lines changed

2 files changed

+32
-15
lines changed

src/TensorFlowNET.Core/Operations/array_ops.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -812,6 +812,9 @@ private static Tensor[] split_eager_fallback<Ta, Tv>(Ta axis, Tv value, int num_
812812
return tf.Runner.Execute(ctx, "Split", num_split, _inputs_flat.ToArray(), _attrs, name: name);
813813
}
814814

815+
public static Tensor slice(Tensor input, Tensor[] begin, Tensor[] size, string name = null)
816+
=> gen_array_ops.slice(input, begin, size, name: name);
817+
815818
public static Tensor slice<Tb, Ts>(Tensor input, Tb begin, Ts size, string name = null)
816819
=> gen_array_ops.slice(input, begin, size, name: name);
817820

src/TensorFlowNET.Core/Operations/gen_array_ops.cs

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ You may obtain a copy of the License at
1414
limitations under the License.
1515
******************************************************************************/
1616

17+
using System;
1718
using System.Linq;
1819
using Tensorflow.Contexts;
1920
using static Tensorflow.Binding;
@@ -448,15 +449,34 @@ public static Tensor size(Tensor input, TF_DataType out_type = TF_DataType.TF_IN
448449
return _op.outputs[0];
449450
}
450451

451-
/// <summary>
452-
/// Return a slice from 'input'
453-
/// </summary>
454-
/// <param name="input"></param>
455-
/// <param name="begin"></param>
456-
/// <param name="size"></param>
457-
/// <param name="name"></param>
458-
/// <returns></returns>
459-
public static Tensor slice(Tensor input, Tensor begin, Tensor size, string name = null)
452+
public static Tensor slice(Tensor input, Tensor[] begin, Tensor[] size, string name = null)
453+
{
454+
if (tf.executing_eagerly())
455+
{
456+
var result = slice_eager_fallback(input, begin, size, name, tf.Context);
457+
return result;
458+
}
459+
460+
var _op = tf.OpDefLib._apply_op_helper("Slice", name, new { input, begin, size });
461+
return _op.outputs[0];
462+
}
463+
464+
private static Tensor slice_eager_fallback(Tensor inputs, Tensor[] begin, Tensor[] size, string name, Context ctx)
465+
{
466+
var (_attr_T, input) = tf.Runner.ArgsToMatchingEager(ctx, args: new[] { inputs });
467+
var (_attr_Tidx, _inputs_Index) = tf.Runner.ArgsToMatchingEager(ctx, args: new object[] { begin, size });
468+
var _inputs_flat = input.concat(_inputs_Index);
469+
var _attrs = new object[] { "T", _attr_T, "Index", _attr_Tidx };
470+
471+
var results = tf.Runner.Execute(ctx, "Slice", 1, _inputs_flat, _attrs, name: name);
472+
if (tf.Runner.MustRecordGradient())
473+
{
474+
tf.Runner.RecordGradient("Slice", _inputs_flat, _attrs, results);
475+
}
476+
return results[0];
477+
}
478+
479+
public static Tensor slice<Tb, Ts>(Tensor input, Tb begin, Ts size, string name = null)
460480
{
461481
var _op = tf.OpDefLib._apply_op_helper("Slice", name, new { input, begin, size });
462482
return _op.outputs[0];
@@ -605,12 +625,6 @@ public static Tensor strided_slice_grad(Tensor shape, Tensor begin, Tensor end,
605625
"shrink_axis_mask", shrink_axis_mask).FirstOrDefault(),
606626
shape, begin, end, strides, dy);
607627

608-
public static Tensor slice<Tb, Ts>(Tensor input, Tb begin, Ts size, string name = null)
609-
{
610-
var _op = tf.OpDefLib._apply_op_helper("Slice", name, new { input, begin, size });
611-
return _op.outputs[0];
612-
}
613-
614628
/// <summary>
615629
/// Removes dimensions of size 1 from the shape of a tensor.
616630
/// Given a tensor `input`, this operation returns a tensor of the same type with

0 commit comments

Comments
 (0)