Skip to content

Commit 8b9fca4

Browse files
committed
Add save and restore model from config.
1 parent 91399b1 commit 8b9fca4

File tree

18 files changed

+269
-26
lines changed

18 files changed

+269
-26
lines changed

src/TensorFlowNET.Core/Binding.Util.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,12 @@ public static void add<T>(this IList<T> list, IEnumerable<T> elements)
5858
public static void append<T>(this IList<T> list, T element)
5959
=> list.Insert(list.Count, element);
6060

61+
public static void append<T>(this IList<T> list, IList<T> elements)
62+
{
63+
for (int i = 0; i < elements.Count(); i++)
64+
list.Insert(list.Count, elements[i]);
65+
}
66+
6167
public static T[] concat<T>(this IList<T> list1, IList<T> list2)
6268
{
6369
var list = new List<T>();

src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ public bool RecordGradient(string op_name,
3838
}*/
3939
}
4040

41-
Console.WriteLine($"RecordGradient: should_record={should_record}, op_name={op_name}");
41+
// Console.WriteLine($"RecordGradient: should_record={should_record}, op_name={op_name}");
4242
if (!should_record) return should_record;
4343

4444
Tensor[] op_outputs;

src/TensorFlowNET.Core/Gradients/math_grad.cs

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -761,21 +761,19 @@ private static (Tensor, Tensor, bool)[] SmartBroadcastGradientArgs(Tensor x, Ten
761761
{
762762
sx = array_ops.shape(x);
763763
sy = array_ops.shape(y);
764-
765-
var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy);
766-
return new[]
767-
{
768-
(sx, rx, true),
769-
(sy, ry, true)
770-
};
771764
}
772765
else
773766
{
774767
sx = array_ops.shape_internal(x, optimize: false);
775768
sy = array_ops.shape_internal(y, optimize: false);
776769
}
777770

778-
throw new NotImplementedException("");
771+
var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy);
772+
return new[]
773+
{
774+
(sx, rx, true),
775+
(sy, ry, true)
776+
};
779777
}
780778
}
781779
}

src/TensorFlowNET.Core/Keras/Engine/INode.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,6 @@ public interface INode
1313
INode[] ParentNodes { get; }
1414
IEnumerable<(ILayer, int, int, Tensor)> iterate_inbound();
1515
bool is_input { get; }
16-
NodeConfig serialize(Func<string, int, string> make_node_key, Dictionary<string, int> node_conversion_map);
16+
List<NodeConfig> serialize(Func<string, int, string> make_node_key, Dictionary<string, int> node_conversion_map);
1717
}
1818
}

src/TensorFlowNET.Core/Keras/Engine/KerasHistory.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ public class KerasHistory
88
ILayer layer;
99
public ILayer Layer => layer;
1010
int node_index;
11+
public int NodeIndex => node_index;
1112
int tensor_index;
13+
public int TensorIndex => tensor_index;
1214
Tensor tensor;
1315

1416
public KerasHistory(ILayer layer, int node_index, int tensor_index, Tensor tensor)

src/TensorFlowNET.Core/Keras/Saving/LayerConfig.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,6 @@ public class LayerConfig
1111
public string Name { get; set; }
1212
public string ClassName { get; set; }
1313
public LayerArgs Config { get; set; }
14-
public List<INode> InboundNodes { get; set; }
14+
public List<NodeConfig> InboundNodes { get; set; }
1515
}
1616
}

src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@ public class ModelConfig
99
{
1010
public string Name { get; set; }
1111
public List<LayerConfig> Layers { get; set; }
12-
public List<ILayer> InputLayers { get; set; }
13-
public List<ILayer> OutputLayers { get; set; }
12+
public List<NodeConfig> InputLayers { get; set; }
13+
public List<NodeConfig> OutputLayers { get; set; }
14+
15+
public override string ToString()
16+
=> $"{Name}, {Layers.Count} Layers, {InputLayers.Count} Input Layers, {OutputLayers.Count} Output Layers";
1417
}
1518
}

src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,8 @@ public class NodeConfig
99
public string Name { get; set; }
1010
public int NodeIndex { get; set; }
1111
public int TensorIndex { get; set; }
12+
13+
public override string ToString()
14+
=> $"{Name}, {NodeIndex}, {TensorIndex}";
1215
}
1316
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Text;
5+
using Tensorflow.Keras.Layers;
6+
using Tensorflow.Keras.Saving;
7+
using Tensorflow.Keras.Utils;
8+
using static Tensorflow.Binding;
9+
10+
namespace Tensorflow.Keras.Engine
11+
{
12+
public partial class Functional
13+
{
14+
/// <summary>
15+
/// Adds layers that are not connected to the outputs to the model.
16+
/// </summary>
17+
/// <param name="created_layers"></param>
18+
public void connect_ancillary_layers(Dictionary<string, ILayer> created_layers)
19+
{
20+
21+
}
22+
}
23+
}
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Text;
5+
using Tensorflow.Keras.Layers;
6+
using Tensorflow.Keras.Saving;
7+
using Tensorflow.Keras.Utils;
8+
using static Tensorflow.Binding;
9+
10+
namespace Tensorflow.Keras.Engine
11+
{
12+
public partial class Functional
13+
{
14+
public static Functional from_config(ModelConfig config)
15+
{
16+
var (input_tensors, output_tensors, created_layers) = reconstruct_from_config(config);
17+
var model = new Functional(input_tensors, output_tensors, name: config.Name);
18+
model.connect_ancillary_layers(created_layers);
19+
return model;
20+
}
21+
22+
/// <summary>
23+
/// Reconstructs graph from config object.
24+
/// </summary>
25+
/// <param name="config"></param>
26+
/// <returns></returns>
27+
static (Tensors, Tensors, Dictionary<string, ILayer>) reconstruct_from_config(ModelConfig config)
28+
{
29+
// Layer instances created during the graph reconstruction process.
30+
var created_layers = new Dictionary<string, ILayer>();
31+
var node_index_map = new Dictionary<(string, int), int>();
32+
var node_count_by_layer = new Dictionary<ILayer, int>();
33+
var unprocessed_nodes = new Dictionary<ILayer, NodeConfig>();
34+
// First, we create all layers and enqueue nodes to be processed
35+
foreach (var layer_data in config.Layers)
36+
process_layer(created_layers, layer_data, unprocessed_nodes, node_count_by_layer);
37+
38+
// Then we process nodes in order of layer depth.
39+
// Nodes that cannot yet be processed (if the inbound node
40+
// does not yet exist) are re-enqueued, and the process
41+
// is repeated until all nodes are processed.
42+
while (unprocessed_nodes.Count > 0)
43+
{
44+
foreach(var layer_data in config.Layers)
45+
{
46+
var layer = created_layers[layer_data.Name];
47+
if (unprocessed_nodes.ContainsKey(layer))
48+
{
49+
var node_data = unprocessed_nodes[layer];
50+
// foreach (var node_data in unprocessed_nodes[layer])
51+
{
52+
process_node(layer, node_data, created_layers, node_count_by_layer, node_index_map);
53+
unprocessed_nodes.Remove(layer);
54+
}
55+
}
56+
}
57+
}
58+
59+
var input_tensors = new List<Tensor>();
60+
foreach (var layer_data in config.InputLayers)
61+
{
62+
var (layer_name, node_index, tensor_index) = (layer_data.Name, layer_data.NodeIndex, layer_data.TensorIndex);
63+
var layer = created_layers[layer_name];
64+
var layer_output_tensors = layer.InboundNodes[node_index].Outputs;
65+
input_tensors.append(layer_output_tensors[tensor_index]);
66+
}
67+
68+
var output_tensors = new List<Tensor>();
69+
foreach (var layer_data in config.OutputLayers)
70+
{
71+
var (layer_name, node_index, tensor_index) = (layer_data.Name, layer_data.NodeIndex, layer_data.TensorIndex);
72+
var layer = created_layers[layer_name];
73+
var layer_output_tensors = layer.InboundNodes[node_index].Outputs;
74+
output_tensors.append(layer_output_tensors[tensor_index]);
75+
}
76+
77+
return (input_tensors, output_tensors, created_layers);
78+
}
79+
80+
static void process_layer(Dictionary<string, ILayer> created_layers,
81+
LayerConfig layer_data,
82+
Dictionary<ILayer, NodeConfig> unprocessed_nodes,
83+
Dictionary<ILayer, int> node_count_by_layer)
84+
{
85+
ILayer layer = null;
86+
var layer_name = layer_data.Name;
87+
if (created_layers.ContainsKey(layer_name))
88+
layer = created_layers[layer_name];
89+
else
90+
{
91+
layer = layer_data.ClassName switch
92+
{
93+
"InputLayer" => InputLayer.from_config(layer_data.Config),
94+
"Dense" => Dense.from_config(layer_data.Config),
95+
_ => throw new NotImplementedException("")
96+
};
97+
98+
created_layers[layer_name] = layer;
99+
}
100+
node_count_by_layer[layer] = _should_skip_first_node(layer) ? 1 : 0;
101+
102+
var inbound_nodes_data = layer_data.InboundNodes;
103+
foreach (var node_data in inbound_nodes_data)
104+
{
105+
if (!unprocessed_nodes.ContainsKey(layer))
106+
unprocessed_nodes[layer] = node_data;
107+
else
108+
unprocessed_nodes.Add(layer, node_data);
109+
}
110+
}
111+
112+
static void process_node(ILayer layer,
113+
NodeConfig node_data,
114+
Dictionary<string, ILayer> created_layers,
115+
Dictionary<ILayer, int> node_count_by_layer,
116+
Dictionary<(string, int), int> node_index_map)
117+
{
118+
var input_tensors = new List<Tensor>();
119+
var inbound_layer_name = node_data.Name;
120+
var inbound_node_index = node_data.NodeIndex;
121+
var inbound_tensor_index = node_data.TensorIndex;
122+
123+
var inbound_layer = created_layers[inbound_layer_name];
124+
var inbound_node = inbound_layer.InboundNodes[inbound_node_index];
125+
input_tensors.Add(inbound_node.Outputs[inbound_node_index]);
126+
127+
var output_tensors = layer.Apply(input_tensors);
128+
129+
// Update node index map.
130+
var output_index = output_tensors[0].KerasHistory.NodeIndex;
131+
node_index_map[(layer.Name, node_count_by_layer[layer])] = output_index;
132+
node_count_by_layer[layer] += 1;
133+
}
134+
135+
static bool _should_skip_first_node(ILayer layer)
136+
{
137+
return layer is Functional && layer.Layers[0] is InputLayer;
138+
}
139+
}
140+
}

0 commit comments

Comments
 (0)