Skip to content

Commit d452d8c

Browse files
authored
Merge pull request #1144 from Wanglongzhi2001/master
fix: fix the bug of load LSTM model and add test
2 parents 3006c86 + 68772b2 commit d452d8c

34 files changed

+81
-40
lines changed

src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUCellArgs.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
using System.Collections.Generic;
44
using System.Text;
55

6-
namespace Tensorflow.Keras.ArgsDefinition.Rnn
6+
namespace Tensorflow.Keras.ArgsDefinition
77
{
88
public class GRUCellArgs : AutoSerializeLayerArgs
99
{

src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMArgs.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
namespace Tensorflow.Keras.ArgsDefinition.Rnn
1+
namespace Tensorflow.Keras.ArgsDefinition
22
{
33
public class LSTMArgs : RNNArgs
44
{

src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMCellArgs.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using Newtonsoft.Json;
22
using static Tensorflow.Binding;
33

4-
namespace Tensorflow.Keras.ArgsDefinition.Rnn
4+
namespace Tensorflow.Keras.ArgsDefinition
55
{
66
// TODO: complete the implementation
77
public class LSTMCellArgs : AutoSerializeLayerArgs

src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
using Newtonsoft.Json;
22
using System.Collections.Generic;
3-
using Tensorflow.Keras.Layers.Rnn;
3+
using Tensorflow.Keras.Layers;
44

5-
namespace Tensorflow.Keras.ArgsDefinition.Rnn
5+
namespace Tensorflow.Keras.ArgsDefinition
66
{
77
// TODO(Rinne): add regularizers.
88
public class RNNArgs : AutoSerializeLayerArgs
@@ -23,16 +23,22 @@ public class RNNArgs : AutoSerializeLayerArgs
2323
public int? InputDim { get; set; }
2424
public int? InputLength { get; set; }
2525
// TODO: Add `num_constants` and `zero_output_for_mask`.
26-
26+
[JsonProperty("units")]
2727
public int Units { get; set; }
28+
[JsonProperty("activation")]
2829
public Activation Activation { get; set; }
30+
[JsonProperty("recurrent_activation")]
2931
public Activation RecurrentActivation { get; set; }
32+
[JsonProperty("use_bias")]
3033
public bool UseBias { get; set; } = true;
3134
public IInitializer KernelInitializer { get; set; }
3235
public IInitializer RecurrentInitializer { get; set; }
3336
public IInitializer BiasInitializer { get; set; }
37+
[JsonProperty("dropout")]
3438
public float Dropout { get; set; } = .0f;
39+
[JsonProperty("zero_output_for_mask")]
3540
public bool ZeroOutputForMask { get; set; } = false;
41+
[JsonProperty("recurrent_dropout")]
3642
public float RecurrentDropout { get; set; } = .0f;
3743
}
3844
}

src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RnnOptionalArgs.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
using System.Text;
44
using Tensorflow.Common.Types;
55

6-
namespace Tensorflow.Keras.ArgsDefinition.Rnn
6+
namespace Tensorflow.Keras.ArgsDefinition
77
{
88
public class RnnOptionalArgs: IOptionalArgs
99
{

src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNArgs.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
namespace Tensorflow.Keras.ArgsDefinition.Rnn
1+
namespace Tensorflow.Keras.ArgsDefinition
22
{
33
public class SimpleRNNArgs : RNNArgs
44
{

src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNCellArgs.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using Newtonsoft.Json;
22

3-
namespace Tensorflow.Keras.ArgsDefinition.Rnn
3+
namespace Tensorflow.Keras.ArgsDefinition
44
{
55
public class SimpleRNNCellArgs: AutoSerializeLayerArgs
66
{

src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/StackedRNNCellsArgs.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using System.Collections.Generic;
2-
using Tensorflow.Keras.Layers.Rnn;
2+
using Tensorflow.Keras.Layers;
33

4-
namespace Tensorflow.Keras.ArgsDefinition.Rnn
4+
namespace Tensorflow.Keras.ArgsDefinition
55
{
66
public class StackedRNNCellsArgs : LayerArgs
77
{

src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using System;
22
using Tensorflow.Framework.Models;
33
using Tensorflow.Keras.Engine;
4-
using Tensorflow.Keras.Layers.Rnn;
4+
using Tensorflow.Keras.Layers;
55
using Tensorflow.NumPy;
66
using static Google.Protobuf.Reflection.FieldDescriptorProto.Types;
77

src/TensorFlowNET.Core/Keras/Layers/Rnn/IRnnCell.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
using System.Text;
44
using Tensorflow.Common.Types;
55

6-
namespace Tensorflow.Keras.Layers.Rnn
6+
namespace Tensorflow.Keras.Layers
77
{
88
public interface IRnnCell: ILayer
99
{

0 commit comments

Comments
 (0)