Skip to content

Commit 905c4a9

Browse files
committed
Add keras.layers.Reshape.
1 parent db94e67 commit 905c4a9

File tree

5 files changed

+64
-3
lines changed

5 files changed

+64
-3
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
namespace Tensorflow.Keras.ArgsDefinition
2+
{
3+
public class ReshapeArgs : LayerArgs
4+
{
5+
public TensorShape TargetShape { get; set; }
6+
}
7+
}

src/TensorFlowNET.Keras/Layers/LayersApi.Reshaping.cs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,5 +34,16 @@ public UpSampling2D UpSampling2D(TensorShape size = null,
3434
{
3535
Size = size ?? (2, 2)
3636
});
37+
38+
/// <summary>
39+
/// Layer that reshapes inputs into the given shape.
40+
/// </summary>
41+
/// <param name="target_shape"></param>
42+
/// <returns></returns>
43+
public Reshape Reshape(TensorShape target_shape)
44+
=> new Reshape(new ReshapeArgs
45+
{
46+
TargetShape = target_shape
47+
});
3748
}
3849
}

src/TensorFlowNET.Keras/Layers/LayersApi.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -372,8 +372,8 @@ public Rescaling Rescaling(float scale,
372372
InputShape = input_shape
373373
});
374374

375-
public Add Add(params Tensor[] inputs)
376-
=> new Add(new MergeArgs { Inputs = inputs });
375+
public Add Add()
376+
=> new Add(new MergeArgs { });
377377

378378
public GlobalAveragePooling2D GlobalAveragePooling2D()
379379
=> new GlobalAveragePooling2D(new Pooling2DArgs { });
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
using Tensorflow.Keras.ArgsDefinition;
2+
using Tensorflow.Keras.Engine;
3+
using static Tensorflow.KerasApi;
4+
using static Tensorflow.Binding;
5+
using System.Collections.Generic;
6+
using System;
7+
8+
namespace Tensorflow.Keras.Layers
9+
{
10+
/// <summary>
11+
/// Layer that reshapes inputs into the given shape.
12+
/// </summary>
13+
public class Reshape : Layer
14+
{
15+
ReshapeArgs args;
16+
public Reshape(ReshapeArgs args)
17+
: base(args)
18+
{
19+
this.args = args;
20+
}
21+
22+
protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false)
23+
{
24+
var shape = new List<int> { inputs.shape[0] };
25+
shape.AddRange(args.TargetShape.dims);
26+
27+
var result = array_ops.reshape(inputs, shape.ToArray());
28+
if (!tf.Context.executing_eagerly())
29+
// result = result.set_shape(compute_output_shape(inputs.shape));
30+
throw new NotImplementedException("");
31+
return result;
32+
}
33+
}
34+
}

test/TensorFlowNET.UnitTest/Keras/Layers.Reshaping.Test.cs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using Microsoft.VisualStudio.TestTools.UnitTesting;
22
using NumSharp;
3-
using Tensorflow;
3+
using static Tensorflow.Binding;
44
using static Tensorflow.KerasApi;
55

66
namespace TensorFlowNET.UnitTest.Keras
@@ -26,5 +26,14 @@ public void UpSampling2D()
2626
var y = keras.layers.UpSampling2D(size: (1, 2)).Apply(x);
2727
Assert.AreEqual((2, 2, 2, 3), y.shape);
2828
}
29+
30+
[TestMethod]
31+
public void Reshape()
32+
{
33+
var inputs = tf.zeros((10, 5, 20));
34+
var outputs = keras.layers.LeakyReLU().Apply(inputs);
35+
outputs = keras.layers.Reshape((20, 5)).Apply(outputs);
36+
Assert.AreEqual((10, 20, 5), outputs.shape);
37+
}
2938
}
3039
}

0 commit comments

Comments
 (0)