Skip to content

Commit b58c241

Browse files
committed
Add keras.Concatenate.
1 parent 94ad2f5 commit b58c241

File tree

12 files changed

+122
-10
lines changed

12 files changed

+122
-10
lines changed

src/TensorFlowNET.Core/Keras/ArgsDefinition/MergeArgs.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,6 @@ namespace Tensorflow.Keras.ArgsDefinition
77
public class MergeArgs : LayerArgs
88
{
99
public Tensors Inputs { get; set; }
10+
public int Axis { get; set; }
1011
}
1112
}

src/TensorFlowNET.Core/Tensors/tensor_util.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@ bool hasattr(Graph property, string attr)
407407

408408
var ret = tensor.TensorShape.unknown_shape(shape.dims[0]);
409409
var value = constant_value(tensor);
410-
if (value != null)
410+
if (!(value is null))
411411
{
412412
int[] d_ = { };
413413
foreach (int d in value)
@@ -418,7 +418,6 @@ bool hasattr(Graph property, string attr)
418418
d_[d_.Length] = -1; // None
419419
}
420420
ret = ret.merge_with(new TensorShape(d_));
421-
422421
}
423422
return ret;
424423
}

src/TensorFlowNET.Keras/BackendImpl.cs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,5 +226,25 @@ public Tensor resize_images(Tensor x, int height_factor, int width_factor,
226226
x.set_shape(output_shape);
227227
return x;
228228
}
229+
230+
/// <summary>
231+
/// Concatenates a list of tensors alongside the specified axis.
232+
/// </summary>
233+
/// <param name="tensors">list of tensors to concatenate.</param>
234+
/// <param name="axis">concatenation axis.</param>
235+
/// <returns></returns>
236+
public Tensor concatenate(Tensors tensors, int axis = -1)
237+
{
238+
if(axis < 0)
239+
{
240+
var rank = tensors[0].NDims;
241+
if (rank > -1)
242+
axis %= rank;
243+
else
244+
axis = 0;
245+
}
246+
247+
return array_ops.concat(tensors, axis);
248+
}
229249
}
230250
}

src/TensorFlowNET.Keras/Engine/Layer.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,13 +177,13 @@ protected void MaybeBuild(Tensors inputs)
177177
tf.init_scope();
178178

179179
tf.Context.eager_mode();
180-
build(inputs.shape);
180+
build(inputs);
181181
tf.Context.restore_mode();
182182

183183
built = true;
184184
}
185185

186-
protected virtual void build(TensorShape input_shape)
186+
protected virtual void build(Tensors inputs)
187187
{
188188
built = true;
189189
}

src/TensorFlowNET.Keras/Layers/BatchNormalization.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,9 @@ public BatchNormalization(BatchNormalizationArgs args) : base(args)
5252
axis = args.Axis.dims;
5353
}
5454

55-
protected override void build(TensorShape input_shape)
55+
protected override void build(Tensors inputs)
5656
{
57+
TensorShape input_shape = inputs.shape;
5758
var ndims = input_shape.ndim;
5859
foreach (var (idx, x) in enumerate(axis))
5960
if (x < 0)

src/TensorFlowNET.Keras/Layers/Convolutional.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,9 @@ public Convolutional(ConvolutionalArgs args) : base(args)
5656
_tf_data_format = conv_utils.convert_data_format(data_format, rank + 2);
5757
}
5858

59-
protected override void build(TensorShape input_shape)
59+
protected override void build(Tensors inputs)
6060
{
61+
TensorShape input_shape = inputs.shape;
6162
int channel_axis = data_format == "channels_first" ? 1 : -1;
6263
int input_channel = channel_axis < 0 ?
6364
input_shape.dims[input_shape.ndim + channel_axis] :

src/TensorFlowNET.Keras/Layers/Dense.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,9 @@ public Dense(DenseArgs args) :
4141
this.inputSpec = new InputSpec(min_ndim: 2);
4242
}
4343

44-
protected override void build(TensorShape input_shape)
44+
protected override void build(Tensors inputs)
4545
{
46+
TensorShape input_shape = inputs.shape;
4647
var last_dim = input_shape.dims.Last();
4748
var axes = new Dictionary<int, int>();
4849
axes[-1] = last_dim;

src/TensorFlowNET.Keras/Layers/Embedding.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ public Embedding(EmbeddingArgs args)
5252
SupportsMasking = mask_zero;
5353
}
5454

55-
protected override void build(TensorShape input_shape)
55+
protected override void build(Tensors inputs)
5656
{
5757
tf.Context.eager_mode();
5858
embeddings = add_weight(shape: (input_dim, output_dim),
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
using NumSharp;
2+
using System;
3+
using System.Collections.Generic;
4+
using System.Text;
5+
using Tensorflow.Keras.ArgsDefinition;
6+
7+
namespace Tensorflow.Keras.Layers
8+
{
9+
public partial class LayersApi
10+
{
11+
/// <summary>
12+
/// Layer that concatenates a list of inputs.
13+
/// </summary>
14+
/// <param name="axis">Axis along which to concatenate.</param>
15+
/// <returns></returns>
16+
public Concatenate Concatenate(int axis = -1)
17+
=> new Concatenate(new MergeArgs
18+
{
19+
Axis = axis
20+
});
21+
}
22+
}

src/TensorFlowNET.Keras/Layers/Merge.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ public Merge(MergeArgs args) : base(args)
1414

1515
}
1616

17-
protected override void build(TensorShape input_shape)
17+
protected override void build(Tensors inputs)
1818
{
1919
// output_shape = input_shape.dims[1^];
2020
}
@@ -24,7 +24,7 @@ protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_tra
2424
return _merge_function(inputs);
2525
}
2626

27-
Tensors _merge_function(Tensors inputs)
27+
protected virtual Tensors _merge_function(Tensors inputs)
2828
{
2929
var output = inputs[0];
3030
foreach (var i in range(1, inputs.Length))

0 commit comments

Comments
 (0)