Skip to content

Commit 94ad2f5

Browse files
committed
Add layers of ZeroPadding2D and UpSampling2D.
1 parent 374ff58 commit 94ad2f5

File tree

8 files changed

+119
-14
lines changed

8 files changed

+119
-14
lines changed
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
namespace Tensorflow.Keras.ArgsDefinition
2+
{
3+
public class UpSampling2DArgs : LayerArgs
4+
{
5+
public TensorShape Size { get; set; }
6+
public string DataFormat { get; set; }
7+
/// <summary>
8+
/// 'nearest', 'bilinear'
9+
/// </summary>
10+
public string Interpolation { get; set; } = "nearest";
11+
}
12+
}

src/TensorFlowNET.Core/Keras/ArgsDefinition/ZeroPadding2DArgs.cs renamed to src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/ZeroPadding2DArgs.cs

File renamed without changes.
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
using NumSharp;
2+
using System;
3+
using System.Collections.Generic;
4+
using System.Text;
5+
using Tensorflow.Keras.ArgsDefinition;
6+
7+
namespace Tensorflow.Keras.Layers
8+
{
9+
public partial class LayersApi
10+
{
11+
/// <summary>
12+
/// Zero-padding layer for 2D input (e.g. picture).
13+
/// </summary>
14+
/// <param name="padding"></param>
15+
/// <returns></returns>
16+
public ZeroPadding2D ZeroPadding2D(NDArray padding)
17+
=> new ZeroPadding2D(new ZeroPadding2DArgs
18+
{
19+
Padding = padding
20+
});
21+
22+
/// <summary>
23+
/// Upsampling layer for 2D inputs.<br/>
24+
/// Repeats the rows and columns of the data by size[0] and size[1] respectively.
25+
/// </summary>
26+
/// <param name="size"></param>
27+
/// <param name="data_format"></param>
28+
/// <param name="interpolation"></param>
29+
/// <returns></returns>
30+
public UpSampling2D UpSampling2D(TensorShape size = null,
31+
string data_format = null,
32+
string interpolation = "nearest")
33+
=> new UpSampling2D(new UpSampling2DArgs
34+
{
35+
Size = size ?? (2, 2)
36+
});
37+
}
38+
}

src/TensorFlowNET.Keras/Layers/LayersApi.cs

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
namespace Tensorflow.Keras.Layers
88
{
9-
public class LayersApi
9+
public partial class LayersApi
1010
{
1111
/// <summary>
1212
/// Functional interface for the batch normalization layer.
@@ -372,19 +372,8 @@ public Rescaling Rescaling(float scale,
372372
InputShape = input_shape
373373
});
374374

375-
/// <summary>
376-
/// Zero-padding layer for 2D input (e.g. picture).
377-
/// </summary>
378-
/// <param name="padding"></param>
379-
/// <returns></returns>
380-
public ZeroPadding2D ZeroPadding2D(NDArray padding)
381-
=> new ZeroPadding2D(new ZeroPadding2DArgs
382-
{
383-
Padding = padding
384-
});
385-
386-
public Tensor add(params Tensor[] inputs)
387-
=> new Add(new MergeArgs { Inputs = inputs }).Apply(inputs);
375+
public Add Add(params Tensor[] inputs)
376+
=> new Add(new MergeArgs { Inputs = inputs });
388377

389378
public GlobalAveragePooling2D GlobalAveragePooling2D()
390379
=> new GlobalAveragePooling2D(new Pooling2DArgs { });
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Keras.ArgsDefinition;
5+
using Tensorflow.Keras.Engine;
6+
using Tensorflow.Keras.Utils;
7+
using static Tensorflow.Binding;
8+
using static Tensorflow.KerasApi;
9+
10+
namespace Tensorflow.Keras.Layers
11+
{
12+
public class UpSampling2D : Layer
13+
{
14+
UpSampling2DArgs args;
15+
int[] size;
16+
string data_format;
17+
string interpolation => args.Interpolation;
18+
19+
public UpSampling2D(UpSampling2DArgs args) : base(args)
20+
{
21+
this.args = args;
22+
data_format = conv_utils.normalize_data_format(args.DataFormat);
23+
size = conv_utils.normalize_tuple(args.Size, 2, "size");
24+
inputSpec = new InputSpec(ndim: 4);
25+
}
26+
27+
protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false)
28+
{
29+
return keras.backend.resize_images(inputs,
30+
size[0], size[1],
31+
data_format,
32+
interpolation: interpolation);
33+
}
34+
}
35+
}
File renamed without changes.
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
using Microsoft.VisualStudio.TestTools.UnitTesting;
2+
using NumSharp;
3+
using Tensorflow;
4+
using static Tensorflow.KerasApi;
5+
6+
namespace TensorFlowNET.UnitTest.Keras
7+
{
8+
[TestClass]
9+
public class LayersReshapingTest : EagerModeTestBase
10+
{
11+
[TestMethod]
12+
public void ZeroPadding2D()
13+
{
14+
var input_shape = new[] { 1, 1, 2, 2 };
15+
var x = np.arange(np.prod(input_shape)).reshape(input_shape);
16+
var zero_padding_2d = keras.layers.ZeroPadding2D(new[,] { { 1, 0 }, { 1, 0 } });
17+
var y = zero_padding_2d.Apply(x);
18+
Assert.AreEqual((1, 2, 3, 2), y.shape);
19+
}
20+
21+
[TestMethod]
22+
public void UpSampling2D()
23+
{
24+
var input_shape = new[] { 2, 2, 1, 3 };
25+
var x = np.arange(np.prod(input_shape)).reshape(input_shape);
26+
var y = keras.layers.UpSampling2D(size: (1, 2)).Apply(x);
27+
Assert.AreEqual((2, 2, 2, 3), y.shape);
28+
}
29+
}
30+
}

test/TensorFlowNET.UnitTest/Keras/LayersTest.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Microsoft.VisualStudio.TestTools.UnitTesting;
22
using NumSharp;
3+
using Tensorflow;
34
using static Tensorflow.KerasApi;
45

56
namespace TensorFlowNET.UnitTest.Keras

0 commit comments

Comments
 (0)