Commit 59106f8

Add Merge and Add layer.
1 parent 67a661a commit 59106f8

8 files changed: +232 -84 lines
MergeArgs.cs (new file): 11 additions & 0 deletions

@@ -0,0 +1,11 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Keras.ArgsDefinition
+{
+    public class MergeArgs : LayerArgs
+    {
+        public Tensors Inputs { get; set; }
+    }
+}
Add.cs (new file): 15 additions & 0 deletions

@@ -0,0 +1,15 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Keras.ArgsDefinition;
+
+namespace Tensorflow.Keras.Layers
+{
+    public class Add : Merge
+    {
+        public Add(MergeArgs args) : base(args)
+        {
+
+        }
+    }
+}
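
The Add layer itself only wires a MergeArgs into the Merge base class (not shown in this commit). A minimal usage sketch, not from the commit, assuming the usual TF.NET entry points (`tf`, `keras`) are in scope and using the `add` helper this commit adds to LayersApi below:

// Sketch only: elementwise addition of two tensors through the new Add layer.
var a = tf.ones(new TensorShape(2, 3));
var b = tf.ones(new TensorShape(2, 3));
Tensor sum = keras.layers.add(a, b);   // each element: 1 + 1 = 2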

src/TensorFlowNET.Keras/Layers/BatchNormalization.cs

Lines changed: 48 additions & 28 deletions

@@ -119,14 +119,14 @@ protected override void build(TensorShape input_shape)
             built = true;
         }

-        protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false)
+        protected override Tensors Call(Tensors inputs, Tensor state = null, bool training = false)
         {
             Tensor outputs = null;
-
+            var training_tensor = tf.logical_and(training, Trainable);
             if (fused)
             {
-                Tensor training = tf.convert_to_tensor(is_training);
-                outputs = _fused_batch_norm(inputs, training: training);
+                // var training = tf.convert_to_tensor(training);
+                outputs = _fused_batch_norm(inputs, training: training_tensor);
                 return outputs;
             }

@@ -150,20 +150,21 @@ private Tensor _fused_batch_norm(Tensor inputs, Tensor training)
                     inputs,
                     gamma,
                     beta,
-                    epsilon: epsilon,
-                    data_format: _data_format);
+                    mean: moving_mean,
+                    variance: moving_variance,
+                    epsilon: epsilon, is_training: true,
+                    data_format: _data_format,
+                    exponential_avg_factor: exponential_avg_factor);
             };

             Func<Tensor[]> _fused_batch_norm_inference = () =>
             {
-                var moving_mean_tensor = moving_mean.AsTensor();
-                var moving_variance_tensor = moving_variance.AsTensor();
                 return tf.nn.fused_batch_norm(
                     inputs,
                     gamma,
                     beta,
-                    mean: moving_mean_tensor,
-                    variance: moving_variance_tensor,
+                    mean: moving_mean,
+                    variance: moving_variance,
                     epsilon: epsilon,
                     is_training: false,
                     data_format: _data_format);

@@ -176,35 +177,54 @@ private Tensor _fused_batch_norm(Tensor inputs, Tensor training)
             var (output, mean, variance) = (results[0], results[1], results[2]);
             var training_value = tf_utils.constant_value(training);

-            Tensor momentum_tensor;
-            if (training_value == null)
+            if (!training_value.HasValue || (training_value.HasValue && training_value.Value))
             {
-                momentum_tensor = tf_utils.smart_cond(training,
-                    () => new float[] { momentum }, () => new float[] { 1.0f })[0];
-            }
-            else
-            {
-                momentum_tensor = ops.convert_to_tensor(momentum);
-            }
+                Tensor momentum_tensor = null;
+                if (!use_fused_avg_updates)
+                {
+                    if (training_value == null)
+                        momentum_tensor = tf_utils.smart_cond(training,
+                            () => new float[] { momentum },
+                            () => new float[] { 1.0f })[0];
+                    else
+                        momentum_tensor = ops.convert_to_tensor(momentum);
+                }
+
+                if (use_fused_avg_updates)
+                    _assign_new_value(moving_mean, mean);
+                else
+                    _assign_moving_average(moving_variance, variance, momentum_tensor);

-            if (training_value == null)
-            {
-                var mean_update = _assign_moving_average(moving_mean.AsTensor(), mean, momentum_tensor);
-                var variance_update = _assign_moving_average(moving_variance.AsTensor(), variance, momentum_tensor);
-                add_update(new Tensor[] { mean_update }, inputs: true);
-                add_update(new Tensor[] { variance_update }, inputs: true);
+                if (use_fused_avg_updates)
+                    _assign_new_value(moving_variance, mean);
+                else
+                    _assign_moving_average(moving_variance, variance, momentum_tensor);
+
+                // var mean_update = _assign_moving_average(moving_mean.AsTensor(), mean, momentum_tensor);
+                // var variance_update = _assign_moving_average(moving_variance.AsTensor(), variance, momentum_tensor);
+                // add_update(new Tensor[] { mean_update }, inputs: true);
+                // add_update(new Tensor[] { variance_update }, inputs: true);
             }

             return output;
         }

-        public Tensor _assign_moving_average(RefVariable variable, Tensor value, Tensor momentum)
+        Tensor _assign_new_value(IVariableV1 variable, Tensor value)
+        {
+            return tf_with(ops.name_scope("AssignNewValue", null, new { variable, value, momentum }), scope =>
+            {
+                // var cm = ops.colocate_with(variable);
+                return state_ops.assign_sub(variable, value, name: scope);
+            });
+        }
+
+        Tensor _assign_moving_average(IVariableV1 variable, Tensor value, Tensor momentum)
         {
-            return tf_with(ops.name_scope(null, "AssignMovingAvg", new { variable, value, momentum }), scope =>
+            return tf_with(ops.name_scope("AssignMovingAvg", null, new { variable, value, momentum }), scope =>
             {
                 // var cm = ops.colocate_with(variable);
                 var decay = ops.convert_to_tensor(1.0f - momentum, name: "decay");
-                var update_delta = (variable - math_ops.cast(value, variable.dtype)) * decay;
+                var update_delta = (variable.AsTensor() - math_ops.cast(value, variable.dtype)) * decay;
                 return state_ops.assign_sub(variable, update_delta, name: scope);
             });
         }
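
The refactored _assign_moving_average implements a standard exponential moving average: subtracting (variable - value) * (1 - momentum) is the same as blending momentum * variable + (1 - momentum) * value. A standalone float sketch of that arithmetic, illustration only and not part of the commit:

// Mirrors the decay/update_delta arithmetic of _assign_moving_average on plain floats:
// variable - (variable - value) * decay  ==  momentum * variable + (1 - momentum) * value.
static float MovingAverage(float variable, float value, float momentum)
{
    var decay = 1.0f - momentum;
    var updateDelta = (variable - value) * decay;
    return variable - updateDelta;
}
// e.g. MovingAverage(0f, 1f, 0.99f) == 0.01f: the running statistic moves 1% toward the batch value.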

src/TensorFlowNET.Keras/Layers/Convolutional.cs

Lines changed: 1 addition & 0 deletions

@@ -20,6 +20,7 @@ limitations under the License.
 using Tensorflow.Keras.Engine;
 using Tensorflow.Keras.Utils;
 using Tensorflow.Operations;
+using static Tensorflow.Binding;

 namespace Tensorflow.Keras.Layers
 {
GlobalAveragePooling2D.cs (new file): 23 additions & 0 deletions

@@ -0,0 +1,23 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Keras.ArgsDefinition;
+
+namespace Tensorflow.Keras.Layers
+{
+    public class GlobalAveragePooling2D : GlobalPooling2D
+    {
+        public GlobalAveragePooling2D(Pooling2DArgs args)
+            : base(args)
+        {
+        }
+
+        protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false)
+        {
+            if (data_format == "channels_last")
+                return math_ops.reduce_mean(inputs, new int[] { 1, 2 }, false);
+            else
+                return math_ops.reduce_mean(inputs, new int[] { 2, 3 }, false);
+        }
+    }
+}
GlobalPooling2D.cs (new file): 23 additions & 0 deletions

@@ -0,0 +1,23 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Keras.Engine;
+using Tensorflow.Keras.Utils;
+
+namespace Tensorflow.Keras.Layers
+{
+    public abstract class GlobalPooling2D : Layer
+    {
+        Pooling2DArgs args;
+        protected string data_format => args.DataFormat;
+        protected InputSpec input_spec;
+
+        public GlobalPooling2D(Pooling2DArgs args) : base(args)
+        {
+            this.args = args;
+            args.DataFormat = conv_utils.normalize_data_format(data_format);
+            input_spec = new InputSpec(ndim: 4);
+        }
+    }
+}
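
GlobalAveragePooling2D reduces each feature map to its mean over the spatial axes, so a (batch, height, width, channels) input collapses to (batch, channels) under channels_last. A usage sketch, not from the commit, assuming the GlobalAveragePooling2D() helper added to LayersApi below:

// Sketch only: `images` stands for any hypothetical (batch, height, width, channels) tensor.
var pooling = keras.layers.GlobalAveragePooling2D();
var pooled = pooling.Apply(images);   // shape (batch, channels); mean over axes 1 and 2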

src/TensorFlowNET.Keras/Layers/LayersApi.cs

Lines changed: 76 additions & 56 deletions

@@ -28,8 +28,7 @@ public class LayersApi
         /// <param name="renorm"></param>
         /// <param name="renorm_momentum"></param>
         /// <returns></returns>
-        public Tensors batch_normalization(Tensor inputs,
-            int axis = -1,
+        public BatchNormalization BatchNormalization(int axis = -1,
             float momentum = 0.99f,
             float epsilon = 0.001f,
             bool center = true,

@@ -38,31 +37,26 @@ public Tensors batch_normalization(Tensor inputs,
             IInitializer gamma_initializer = null,
             IInitializer moving_mean_initializer = null,
             IInitializer moving_variance_initializer = null,
-            Tensor training = null,
             bool trainable = true,
             string name = null,
             bool renorm = false,
             float renorm_momentum = 0.99f)
-        {
-            var layer = new BatchNormalization(new BatchNormalizationArgs
-            {
-                Axis = axis,
-                Momentum = momentum,
-                Epsilon = epsilon,
-                Center = center,
-                Scale = scale,
-                BetaInitializer = beta_initializer,
-                GammaInitializer = gamma_initializer,
-                MovingMeanInitializer = moving_mean_initializer,
-                MovingVarianceInitializer = moving_variance_initializer,
-                Renorm = renorm,
-                RenormMomentum = renorm_momentum,
-                Trainable = trainable,
-                Name = name
-            });
-
-            return layer.Apply(inputs);
-        }
+            => new BatchNormalization(new BatchNormalizationArgs
+            {
+                Axis = axis,
+                Momentum = momentum,
+                Epsilon = epsilon,
+                Center = center,
+                Scale = scale,
+                BetaInitializer = beta_initializer ?? tf.zeros_initializer,
+                GammaInitializer = gamma_initializer ?? tf.ones_initializer,
+                MovingMeanInitializer = moving_mean_initializer ?? tf.zeros_initializer,
+                MovingVarianceInitializer = moving_variance_initializer ?? tf.ones_initializer,
+                Renorm = renorm,
+                RenormMomentum = renorm_momentum,
+                Trainable = trainable,
+                Name = name
+            });

         /// <summary>
         ///

@@ -115,53 +109,64 @@ public Conv2D Conv2D(int filters,
                 Activation = activation ?? keras.activations.Linear
             });

-        public Tensor conv2d(Tensor inputs,
-            int filters,
-            int[] kernel_size,
-            int[] strides = null,
+        public Conv2D Conv2D(int filters,
+            TensorShape kernel_size = null,
+            TensorShape strides = null,
             string padding = "valid",
-            string data_format = "channels_last",
-            int[] dilation_rate = null,
+            string data_format = null,
+            TensorShape dilation_rate = null,
+            int groups = 1,
+            string activation = null,
             bool use_bias = true,
+            string kernel_initializer = "glorot_uniform",
+            string bias_initializer = "zeros",
+            string kernel_regularizer = null,
+            string bias_regularizer = null,
+            string activity_regularizer = null)
+            => new Conv2D(new Conv2DArgs
+            {
+                Rank = 2,
+                Filters = filters,
+                KernelSize = kernel_size,
+                Strides = strides == null ? (1, 1) : strides,
+                Padding = padding,
+                DataFormat = data_format,
+                DilationRate = dilation_rate == null ? (1, 1) : dilation_rate,
+                Groups = groups,
+                UseBias = use_bias,
+                KernelInitializer = GetInitializerByName(kernel_initializer),
+                BiasInitializer = GetInitializerByName(bias_initializer),
+                Activation = GetActivationByName(activation)
+            });
+
+        public Dense Dense(int units,
             Activation activation = null,
             IInitializer kernel_initializer = null,
             IInitializer bias_initializer = null,
-            bool trainable = true,
-            string name = null)
-        {
-            if (strides == null)
-                strides = new int[] { 1, 1 };
-            if (dilation_rate == null)
-                dilation_rate = new int[] { 1, 1 };
-            if (bias_initializer == null)
-                bias_initializer = tf.zeros_initializer;
-
-            var layer = new Conv2D(new Conv2DArgs
+            TensorShape input_shape = null)
+            => new Dense(new DenseArgs
             {
-                Filters = filters,
-                KernelSize = kernel_size,
-                Strides = strides,
-                Padding = padding,
-                DataFormat = data_format,
-                DilationRate = dilation_rate,
-                Activation = activation,
-                UseBias = use_bias,
-                KernelInitializer = kernel_initializer,
-                BiasInitializer = bias_initializer,
-                Trainable = trainable,
-                Name = name
+                Units = units,
+                Activation = activation ?? keras.activations.Linear,
+                KernelInitializer = kernel_initializer ?? tf.glorot_uniform_initializer,
+                BiasInitializer = bias_initializer ?? tf.zeros_initializer,
+                InputShape = input_shape
             });

-            return layer.Apply(inputs);
-        }
+        public Dense Dense(int units)
+            => new Dense(new DenseArgs
+            {
+                Units = units,
+                Activation = GetActivationByName("linear")
+            });

         public Dense Dense(int units,
-            Activation activation = null,
+            string activation = null,
             TensorShape input_shape = null)
             => new Dense(new DenseArgs
             {
                 Units = units,
-                Activation = activation ?? keras.activations.Linear,
+                Activation = GetActivationByName(activation),
                 InputShape = input_shape
             });

@@ -367,6 +372,12 @@ public ZeroPadding2D ZeroPadding2D(NDArray padding)
                 Padding = padding
             });

+        public Tensor add(params Tensor[] inputs)
+            => new Add(new MergeArgs { Inputs = inputs }).Apply(inputs);
+
+        public GlobalAveragePooling2D GlobalAveragePooling2D()
+            => new GlobalAveragePooling2D(new Pooling2DArgs { });
+
         Activation GetActivationByName(string name)
             => name switch
             {

@@ -376,5 +387,14 @@ Activation GetActivationByName(string name)
                 "tanh" => keras.activations.Tanh,
                 _ => keras.activations.Linear
             };
+
+        IInitializer GetInitializerByName(string name)
+            => name switch
+            {
+                "glorot_uniform" => tf.glorot_uniform_initializer,
+                "zeros" => tf.zeros_initializer,
+                "ones" => tf.ones_initializer,
+                _ => tf.glorot_uniform_initializer
+            };
     }
 }
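
After this change, the LayersApi methods return configured layer objects instead of applying them to a tensor immediately, so callers chain Apply themselves. A sketch of composing the new surface, illustrative only: `input` is a hypothetical 4-D image tensor, and passing a (3, 3) tuple for kernel_size assumes the tuple-to-TensorShape conversion implied by the (1, 1) defaults used above.

// Minimal sketch (not from the commit) of the layer-returning API above.
var x = keras.layers.Conv2D(32, kernel_size: (3, 3), activation: "relu").Apply(input);
x = keras.layers.BatchNormalization().Apply(x);
x = keras.layers.GlobalAveragePooling2D().Apply(x);
var logits = keras.layers.Dense(10, activation: "linear").Apply(x);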
