Skip to content

Commit fafed7d

Browse files
committed
Massive updates for TensorFlowOpLayer #652
1 parent ea3fc6a commit fafed7d

File tree

19 files changed

+423
-94
lines changed

19 files changed

+423
-94
lines changed

src/TensorFlowNET.Core/Functions/ConcreteFunction.cs

Lines changed: 57 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using System;
2+
using System.Collections.Generic;
23
using System.Linq;
34
using Tensorflow.Framework.Models;
45
using Tensorflow.Graphs;
@@ -11,11 +12,34 @@ namespace Tensorflow.Functions
1112
/// </summary>
1213
public class ConcreteFunction : IDisposable
1314
{
14-
public string Name => _handle == IntPtr.Zero ? string.Empty : c_api.StringPiece(c_api.TF_FunctionName(_handle));
1515
IntPtr _handle;
16+
FuncGraph func_graph;
17+
18+
public string Name
19+
{
20+
get
21+
{
22+
if (func_graph != null)
23+
return func_graph.FuncName;
24+
25+
return _handle == IntPtr.Zero ? string.Empty : c_api.StringPiece(c_api.TF_FunctionName(_handle));
26+
}
27+
}
28+
1629
public Tensor[] Outputs;
30+
public Type ReturnType;
1731
public TensorSpec[] OutputStructure;
1832

33+
public ConcreteFunction(string name)
34+
{
35+
func_graph = new FuncGraph(name);
36+
}
37+
38+
public ConcreteFunction(FuncGraph graph, Dictionary<string, string> attrs)
39+
{
40+
func_graph = graph;
41+
}
42+
1943
public ConcreteFunction(Func<Tensor, Tensor> func, TF_DataType dtype)
2044
{
2145
string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}";
@@ -28,8 +52,8 @@ public ConcreteFunction(Func<Tensor, Tensor> func, TF_DataType dtype)
2852

2953
var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
3054
_handle = graph.ToGraph(opers,
31-
new Operation[] { input },
32-
new Operation[] { output },
55+
new[] { input },
56+
new[] { output },
3357
null);
3458
}
3559
}
@@ -48,8 +72,8 @@ public ConcreteFunction(Func<Tensor, IDatasetV2> func, TF_DataType dtype)
4872

4973
var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
5074
_handle = graph.ToGraph(opers,
51-
new Operation[] { input },
52-
new Operation[] { output.variant_tensor.op },
75+
new[] { input },
76+
new[] { output.variant_tensor },
5377
null);
5478
}
5579
}
@@ -72,12 +96,38 @@ public ConcreteFunction(Func<Tensor, (Tensor, Tensor), (Tensor, Tensor)> func,
7296

7397
var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
7498
_handle = graph.ToGraph(opers,
75-
new Operation[] { input1, input2, input3 },
76-
new Operation[] { outputs.Item1.op, outputs.Item2.op },
99+
new[] { input1, input2, input3 },
100+
new[] { outputs.Item1, outputs.Item2 },
77101
null);
78102
}
79103
}
80104

105+
public void ToGraph(Tensors inputs, Tensors outputs)
106+
{
107+
var opers = func_graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
108+
_handle = func_graph.ToGraph(opers,
109+
inputs,
110+
outputs,
111+
null);
112+
}
113+
114+
public Tensors Invoke(Tensors inputs)
115+
{
116+
var forward_backward = SelectForwardAndBackwardFunctions(inputs, 1, tf.Context.executing_eagerly());
117+
var (forward_function, args_with_tangents) = forward_backward.Forward();
118+
Tensors flat_outputs = null;
119+
if (tf.Context.executing_eagerly())
120+
flat_outputs = forward_function.Call(args_with_tangents);
121+
forward_backward.Record(flat_outputs);
122+
return flat_outputs;
123+
}
124+
125+
ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly)
126+
{
127+
var functions = new FirstOrderTapeGradientFunctions(func_graph, false);
128+
return new ForwardBackwardCall(functions, args, tape_watching: true);
129+
}
130+
81131
public void Dispose()
82132
{
83133
c_api.TFE_ContextRemoveFunction(tf.Context.Handle, Name, tf.Status.Handle);
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
using Google.Protobuf;
2+
using System;
3+
using System.Collections.Generic;
4+
using System.Linq;
5+
using System.Text;
6+
using Tensorflow.Graphs;
7+
using static Tensorflow.Binding;
8+
9+
namespace Tensorflow.Functions
10+
{
11+
public class EagerDefinedFunction
12+
{
13+
public int _num_outputs;
14+
public string Name => _func_graph.FuncName;
15+
16+
FuncGraph _func_graph;
17+
public EagerDefinedFunction(string name, FuncGraph graph,
18+
Tensors inputs, Tensors outputs,
19+
Dictionary<string, string> attrs)
20+
{
21+
_num_outputs = outputs.Length;
22+
23+
var input_ops = inputs.Select(x => x.op).ToArray();
24+
var operations = graph.get_operations().Where(x => !input_ops.Contains(x.op))
25+
.Select(x => x as Operation).ToArray();
26+
var output_names = new string[0];
27+
28+
_func_graph = new FuncGraph(graph, name, attrs);
29+
_func_graph.ToGraph(operations, inputs, outputs, output_names);
30+
}
31+
32+
public Tensors Call(Tensors args)
33+
{
34+
var results = tf.Runner.TFE_Execute(tf.Context,
35+
tf.Context.DeviceName,
36+
_func_graph.FuncName,
37+
args,
38+
null,
39+
_num_outputs);
40+
41+
return results;
42+
}
43+
}
44+
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Graphs;
5+
6+
namespace Tensorflow.Functions
7+
{
8+
public class FirstOrderTapeGradientFunctions : TapeGradientFunctions
9+
{
10+
public FirstOrderTapeGradientFunctions(FuncGraph func_graph,
11+
bool need_gradients_for_jvps) : base(func_graph,
12+
need_gradients_for_jvps)
13+
{
14+
15+
}
16+
17+
public override EagerDefinedFunction ForwardAndBackwardFunctions(Tensors inference_args)
18+
{
19+
var outputs = _func_graph.Outputs;
20+
(_forward, _forward_graph, _backward, _forwardprop_output_indices, _num_forwardprop_outputs)
21+
= BuildFunctionsForOutputs(outputs, inference_args);
22+
return _forward;
23+
}
24+
}
25+
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Functions
6+
{
7+
/// <summary>
8+
/// Holds the state of a function call between execution and recording.
9+
/// </summary>
10+
public class ForwardBackwardCall
11+
{
12+
TapeGradientFunctions _functions;
13+
Tensors _inference_args;
14+
Tensors _input_tangents;
15+
bool _tape_watching;
16+
17+
public ForwardBackwardCall(TapeGradientFunctions functions,
18+
Tensors inference_args,
19+
bool tape_watching)
20+
{
21+
_functions = functions;
22+
_inference_args = inference_args;
23+
_tape_watching = tape_watching;
24+
}
25+
26+
public (EagerDefinedFunction, Tensors) Forward()
27+
{
28+
var forward_function = _functions.Forward(_inference_args);
29+
return (forward_function, _inference_args);
30+
}
31+
32+
public void Record(Tensors flat_outputs)
33+
{
34+
if (_tape_watching && flat_outputs != null)
35+
_functions.Record(flat_outputs, _inference_args);
36+
}
37+
}
38+
}
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Graphs;
5+
using static Tensorflow.Binding;
6+
using static Tensorflow.tensorflow;
7+
8+
namespace Tensorflow.Functions
9+
{
10+
/// <summary>
11+
/// Caches forward and backward functions compatible with eager gradients.
12+
/// </summary>
13+
public abstract class TapeGradientFunctions
14+
{
15+
string FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name";
16+
string BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name";
17+
string _FORWARD_PREFIX = "__forward_";
18+
string _BACKWARD_PREFIX = "__backward_";
19+
string _INFERENCE_PREFIX = "__inference_";
20+
21+
protected FuncGraph _func_graph;
22+
protected EagerDefinedFunction _forward;
23+
protected FuncGraph _forward_graph;
24+
protected List<int> _forwardprop_output_indices;
25+
protected int _num_forwardprop_outputs;
26+
protected ConcreteFunction _backward;
27+
28+
public TapeGradientFunctions(FuncGraph func_graph,
29+
bool need_gradients_for_jvps)
30+
{
31+
_func_graph = func_graph;
32+
}
33+
34+
public EagerDefinedFunction Forward(Tensors inference_args)
35+
{
36+
return ForwardAndBackwardFunctions(inference_args);
37+
}
38+
39+
/// <summary>
40+
/// Record the function call operation.
41+
/// </summary>
42+
/// <param name="flat_outputs"></param>
43+
/// <param name="inference_args"></param>
44+
public void Record(Tensors flat_outputs, Tensors inference_args)
45+
{
46+
var (backward_function, to_record) = _wrap_backward_function(_forward_graph, _backward, flat_outputs);
47+
tf.Runner.RecordGradient(_forward.Name, flat_outputs, new object[0], inference_args,
48+
getBackwardFunction: () => backward_function);
49+
}
50+
51+
(BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors flat_outputs)
52+
{
53+
BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) =>
54+
{
55+
return new Tensor[0];
56+
57+
/*var gradients = ops.gradientFunctions[op_name](new EagerOperation
58+
{
59+
Name = op_name,
60+
NumInputs = op_inputs.Length,
61+
Inputs = op_inputs,
62+
NumOutputs = op_outputs.Length,
63+
Outputs = op_outputs,
64+
SkipInputIndices = unneeded_gradients,
65+
Attrs = attrs
66+
}, output_grads);
67+
68+
return gradients;*/
69+
};
70+
71+
return (_backward_function_wrapper, flat_outputs);
72+
}
73+
74+
protected (EagerDefinedFunction, FuncGraph, ConcreteFunction, List<int>, int)
75+
BuildFunctionsForOutputs(Tensors outputs, Tensors inference_args)
76+
{
77+
var trainable_outputs = new List<Tensor>();
78+
var trainable_indices = new List<int>();
79+
foreach(var (index, output) in enumerate(outputs))
80+
{
81+
if (gradients_util.IsTrainable(output))
82+
{
83+
trainable_outputs.Add(output);
84+
trainable_indices.Add(index);
85+
}
86+
}
87+
88+
var gradients_wrt_outputs = new List<Tensor>();
89+
var backwards_graph = new FuncGraph($"{_BACKWARD_PREFIX}{_func_graph.FuncName}_{ops.uid()}");
90+
foreach (var output in trainable_outputs)
91+
gradients_wrt_outputs.Add(tf.placeholder(output.dtype, output.shape));
92+
var gradients_wrt_inputs = gradients_util._GradientsHelper(trainable_outputs.ToArray(),
93+
_func_graph.Inputs,
94+
grad_ys: gradients_wrt_outputs.ToArray(),
95+
src_graph: _func_graph);
96+
97+
tf.Context.restore_mode();
98+
99+
var forward_function_name = $"{_FORWARD_PREFIX}{_func_graph.FuncName}_{ops.uid()}";
100+
var backward_function_attr = new Dictionary<string, string>();
101+
backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name;
102+
backwards_graph.Inputs = gradients_wrt_outputs;
103+
backwards_graph.Outputs = gradients_wrt_inputs;
104+
105+
var backward_function = new ConcreteFunction(backwards_graph, backward_function_attr);
106+
107+
var forward_function_attr = new Dictionary<string, string>();
108+
forward_function_attr[BACKWARD_FUNCTION_ATTRIBUTE_NAME] = backward_function.Name;
109+
var forward_function = new EagerDefinedFunction(forward_function_name, _func_graph,
110+
_func_graph.Inputs, _func_graph.Outputs, forward_function_attr);
111+
112+
return (forward_function, _func_graph, backward_function, null, 0);
113+
}
114+
115+
public virtual EagerDefinedFunction ForwardAndBackwardFunctions(Tensors inference_args)
116+
{
117+
throw new NotImplementedException("");
118+
}
119+
}
120+
}

src/TensorFlowNET.Core/Functions/c_api.function.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ public static extern IntPtr TF_GraphToFunction(IntPtr fn_body, string fn_name,
4747
string description,
4848
SafeStatusHandle status);
4949

50+
[DllImport(TensorFlowLibName)]
51+
public static extern IntPtr TF_FunctionSetAttrValueProto(IntPtr func, string attr_name, byte[] proto, int proto_len, SafeStatusHandle status);
52+
5053
[DllImport(TensorFlowLibName)]
5154
public static extern IntPtr TF_FunctionName(IntPtr func);
5255

src/TensorFlowNET.Core/Gradients/ITape.cs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@ public interface ITape
1313
void RecordOperation(string op_type,
1414
Tensor[] input_tensors,
1515
TapeTensor[] output_tensors,
16-
long[] input_tensor_id,
17-
TF_DataType[] input_dtypes,
1816
Func<BackwardFunction> backward_function_getter);
1917

2018
void VariableAccessed(ResourceVariable variable);

src/TensorFlowNET.Core/Gradients/Tape.RecordOperation.cs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using Tensorflow.Util;
44
using static Tensorflow.tensorflow;
55
using static Tensorflow.Binding;
6+
using System.Linq;
67

78
namespace Tensorflow.Gradients
89
{
@@ -14,18 +15,19 @@ public partial class Tape
1415
public void RecordOperation(string op_type,
1516
Tensor[] input_tensors,
1617
TapeTensor[] output_tensors,
17-
long[] input_tensor_id,
18-
TF_DataType[] input_dtypes,
1918
Func<BackwardFunction> backward_function_getter)
2019
{
21-
if (!ShouldRecord(input_tensor_id, input_dtypes))
20+
var input_ids = input_tensors.Select(x => x.Id).ToArray();
21+
var input_dtypes = input_tensors.Select(x => x.dtype).ToArray();
22+
23+
if (!ShouldRecord(input_ids, input_dtypes))
2224
{
2325
return;
2426
}
2527

2628
long op_id = next_op_id_++;
27-
var ids = new List<long>(input_tensor_id.Length);
28-
foreach (var i in input_tensor_id)
29+
var ids = new List<long>(input_ids.Length);
30+
foreach (var i in input_ids)
2931
{
3032
tensor_usage_[i]++;
3133
ids.Add(i);

0 commit comments

Comments
 (0)