Skip to content

Commit 374ff58

Browse files
committed
keras.backend.resize_images
1 parent 3e76c19 commit 374ff58

File tree

4 files changed

+68
-4
lines changed

4 files changed

+68
-4
lines changed

src/TensorFlowNET.Core/APIs/tf.image.cs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,11 @@ public Tensor resize_images(Tensor images, Tensor size, string method = ResizeMe
6060

6161
public Tensor resize_images_v2(Tensor images, TensorShape size, string method = ResizeMethod.BILINEAR, bool preserve_aspect_ratio = false, bool antialias = false,
6262
string name = null)
63-
=> image_ops_impl.resize_images(images, tf.constant(size.dims), method, preserve_aspect_ratio, antialias, name);
63+
=> image_ops_impl.resize_images_v2(images, size, method, preserve_aspect_ratio, antialias, name);
64+
65+
public Tensor resize_images_v2(Tensor images, Tensor size, string method = ResizeMethod.BILINEAR, bool preserve_aspect_ratio = false, bool antialias = false,
66+
string name = null)
67+
=> image_ops_impl.resize_images_v2(images, size, method, preserve_aspect_ratio, antialias, name);
6468

6569
public Tensor resize_images_with_pad(Tensor image, int target_height, int target_width, string method, bool antialias)
6670
=> image_ops_impl.resize_images_with_pad(image, target_height, target_width, method, antialias);
@@ -209,6 +213,9 @@ public Tensor extract_glimpse(Tensor input, Tensor size, Tensor offsets, bool ce
209213
public Tensor resize(Tensor image, TensorShape size, string method = ResizeMethod.BILINEAR)
210214
=> image_ops_impl.resize_images_v2(image, size, method: method);
211215

216+
public Tensor resize(Tensor image, Tensor size, string method = ResizeMethod.BILINEAR)
217+
=> image_ops_impl.resize_images_v2(image, size, method: method);
218+
212219
public Tensor resize_bilinear(Tensor images, Tensor size, bool align_corners = false, bool half_pixel_centers = false, string name = null)
213220
=> gen_image_ops.resize_bilinear(images, size, align_corners: align_corners, half_pixel_centers: half_pixel_centers, name: name);
214221

src/TensorFlowNET.Core/Operations/gen_array_ops.cs

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,18 +175,32 @@ public static Tensor pad(Tensor input, Tensor paddings, string name = null)
175175
{
176176
if (tf.Context.executing_eagerly())
177177
{
178-
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
178+
/*var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
179179
"Pad", name,
180180
null,
181181
input, paddings);
182-
return results[0];
182+
return results[0];*/
183+
return pad_eager_fallback(input, paddings, name: name, ctx: tf.Context);
183184
}
184185

185186
var _op = tf.OpDefLib._apply_op_helper("Pad", name: name, args: new { input, paddings });
186187

187188
return _op.output;
188189
}
189190

191+
private static Tensor pad_eager_fallback(Tensor inputs, Tensor padding, string name = null, Context ctx = null)
192+
{
193+
var (_attr_T, input) = tf.Runner.ArgsToMatchingEager(ctx, args: new[] { inputs });
194+
var (_attr_Tpaddings, paddings) = tf.Runner.ArgsToMatchingEager(ctx, default_dtype: tf.int32, args: new[] { padding });
195+
var _inputs_flat = input.concat(paddings);
196+
var _attrs = new object[] { "T", _attr_T, "Tpaddings", _attr_Tpaddings };
197+
198+
var results = tf.Runner.Execute(ctx, "Pad", 1, _inputs_flat, _attrs, name: name);
199+
if (tf.Runner.MustRecordGradient())
200+
tf.Runner.RecordGradient("Pad", _inputs_flat, _attrs, results);
201+
return results[0];
202+
}
203+
190204
public static Tensor pack(Tensor[] values, int axis = 0, string name = null)
191205
{
192206
if (tf.Context.executing_eagerly())

src/TensorFlowNET.Core/Operations/image_ops_impl.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2210,7 +2210,7 @@ public static Tensor convert_image_dtype(Tensor image, TF_DataType dtype, bool s
22102210
/// <param name="antialias"></param>
22112211
/// <param name="name"></param>
22122212
/// <returns></returns>
2213-
public static Tensor resize_images_v2(Tensor images, TensorShape size, string method = ResizeMethod.BILINEAR,
2213+
public static Tensor resize_images_v2<T>(Tensor images, T size, string method = ResizeMethod.BILINEAR,
22142214
bool preserve_aspect_ratio = false,
22152215
bool antialias = false,
22162216
string name = null)

src/TensorFlowNET.Keras/BackendImpl.cs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,5 +183,48 @@ public Tensor categorical_crossentropy(Tensor target, Tensor output, bool from_l
183183

184184
throw new NotImplementedException("");
185185
}
186+
187+
/// <summary>
188+
/// Resizes the images contained in a 4D tensor.
189+
/// </summary>
190+
/// <param name="x"></param>
191+
/// <param name="height_factor"></param>
192+
/// <param name="width_factor"></param>
193+
/// <param name="data_format"></param>
194+
/// <param name="interpolation"></param>
195+
/// <returns></returns>
196+
public Tensor resize_images(Tensor x, int height_factor, int width_factor,
197+
string data_format, string interpolation = "nearest")
198+
{
199+
var (rows, cols) = (0, 0);
200+
if (data_format == "channels_first")
201+
(rows, cols) = (2, 3);
202+
else if (data_format == "channels_last")
203+
(rows, cols) = (1, 2);
204+
else
205+
throw new ValueError($"Invalid `data_format` argument: {data_format}");
206+
207+
var original_shape = x.shape;
208+
var new_shape = array_ops.shape(x)[new Slice(rows, cols + 1)];
209+
new_shape *= constant_op.constant(np.array(height_factor, width_factor));
210+
211+
if (data_format == "channels_first")
212+
// x = permute_dimensions(x, [0, 2, 3, 1]);
213+
throw new NotImplementedException("");
214+
if (interpolation == "nearest")
215+
x = tf.image.resize_images_v2(x, new_shape, method: ResizeMethod.NEAREST_NEIGHBOR);
216+
217+
if (data_format == "channels_first")
218+
// x = permute_dimensions(x, [0, 3, 1, 2]);
219+
throw new NotImplementedException("");
220+
221+
int new_height = original_shape[rows] < 0 ? -1 : original_shape[rows] * height_factor;
222+
int new_width = original_shape[cols] < 0 ? -1 : original_shape[cols] * width_factor;
223+
224+
TensorShape output_shape = data_format == "channels_first" ?
225+
(-1, -1, new_height, new_width) : (-1, new_height, new_width, -1);
226+
x.set_shape(output_shape);
227+
return x;
228+
}
186229
}
187230
}

0 commit comments

Comments
 (0)