Skip to content

Commit 4d86da6

Browse files
committed
Override create_op in FuncGraph.
1 parent fafed7d commit 4d86da6

File tree

3 files changed

+138
-4
lines changed

3 files changed

+138
-4
lines changed
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Exceptions
6+
{
7+
public class InaccessibleTensorError : TensorflowException
8+
{
9+
public InaccessibleTensorError(string message) : base(message)
10+
{
11+
12+
}
13+
}
14+
}

src/TensorFlowNET.Core/Graphs/FuncGraph.cs

Lines changed: 123 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1-
using System;
1+
using Google.Protobuf;
2+
using System;
23
using System.Collections.Generic;
34
using System.Linq;
5+
using Tensorflow.Eager;
6+
using Tensorflow.Exceptions;
47
using static Tensorflow.Binding;
58

69
namespace Tensorflow.Graphs
@@ -21,7 +24,9 @@ public class FuncGraph : Graph
2124

2225
public Tensors Inputs { get; set; }
2326
public Tensors Outputs { get; set; }
27+
public Dictionary<string, string> Attrs { get; set; }
2428

29+
Dictionary<long, (Tensor, Tensor)> _captures = new Dictionary<long, (Tensor, Tensor)>();
2530
/// <summary>
2631
/// Construct a new FuncGraph.
2732
/// </summary>
@@ -34,10 +39,14 @@ public FuncGraph(string name) : base()
3439
as_default();
3540
}
3641

37-
public FuncGraph(IntPtr handle, string name)
42+
public FuncGraph(IntPtr handle, string name, Dictionary<string, string> attrs) : base()
3843
{
3944
outer_graph = ops.get_default_graph();
4045
func_name = name;
46+
Attrs = attrs;
47+
// Will to test if FuncGraph has memory leak
48+
// c_api.TF_DeleteGraph(_handle);
49+
_handle = handle;
4150

4251
tf.Context.graph_mode();
4352
as_default();
@@ -63,6 +72,8 @@ public IntPtr ToGraph(Operation[] opers,
6372
status.Handle);
6473
status.Check(true);
6574

75+
SetAttrs();
76+
6677
c_api.TF_GraphCopyFunction(outer_graph, func_handle, IntPtr.Zero, status.Handle);
6778
status.Check(true);
6879

@@ -73,13 +84,122 @@ public IntPtr ToGraph(Operation[] opers,
7384

7485
Inputs = inputs;
7586
// mark_as_return
76-
Outputs = outputs.Select(x => array_ops.identity(x)).ToArray();
87+
Outputs = outputs;// .Select(x => array_ops.identity(x)).ToArray();
7788

7889
tf.Context.restore_mode();
7990

8091
return func_handle;
8192
}
8293

94+
public override Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes, TF_DataType[] input_types = null, string name = null, Dictionary<string, AttrValue> attrs = null, OpDef op_def = null, bool compute_device = true)
95+
{
96+
foreach(var (i, inp) in enumerate(inputs))
97+
inputs[i] = capture(inp);
98+
99+
return base.create_op(op_type, inputs, dtypes, input_types, name, attrs, op_def, compute_device);
100+
}
101+
102+
Tensor capture(Tensor tensor, string name = null, TF_DataType shape = TF_DataType.DtInvalid)
103+
{
104+
if(tensor is EagerTensor)
105+
{
106+
throw new NotImplementedException("");
107+
}
108+
109+
if(tensor.graph != this)
110+
{
111+
if (name == null)
112+
name = tensor.op.name;
113+
var inner_graph = tensor.graph;
114+
while(inner_graph != null && inner_graph is FuncGraph inner_func_graph)
115+
{
116+
if (inner_graph == this)
117+
throw new InaccessibleTensorError($"The tensor '{tensor.name}' cannot be accessed here: it is defined" +
118+
" in another function or code block. Use return values," +
119+
" explicit Python locals or TensorFlow collections to access" +
120+
$" it. Defined in: {tensor.graph.graph_key}; accessed from: {graph_key}.");
121+
inner_graph = inner_func_graph.outer_graph;
122+
}
123+
return _capture_helper(tensor, name);
124+
}
125+
126+
return tensor;
127+
}
128+
129+
Tensor _capture_helper(Tensor tensor, string name, TensorShape shape = null)
130+
{
131+
Tensor placeholder = null;
132+
if (!_captures.ContainsKey(tensor.Id))
133+
{
134+
placeholder = _create_substitute_placeholder(tensor,
135+
name: name,
136+
dtype: tensor.dtype,
137+
shape: shape);
138+
add_capture(tensor, placeholder);
139+
}
140+
else
141+
{
142+
placeholder = _captures[tensor.Id].Item1;
143+
}
144+
145+
BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) =>
146+
{
147+
return output_grads;
148+
};
149+
150+
tf.Runner.RecordGradient("captured_value",
151+
new[] { placeholder }, null,
152+
new[] { tensor },
153+
getBackwardFunction: () => _backward_function_wrapper
154+
/*getForwardFunction: forward_function*/);
155+
156+
return placeholder;
157+
}
158+
159+
void add_capture(Tensor tensor, Tensor placeholder)
160+
{
161+
_captures[tensor.Id] = (tensor, placeholder);
162+
if (Inputs == null)
163+
Inputs = new Tensors(placeholder);
164+
else
165+
{
166+
var inputs = Inputs.ToList();
167+
inputs.Add(placeholder);
168+
Inputs = new Tensors(inputs.ToArray());
169+
}
170+
}
171+
172+
Tensor _create_substitute_placeholder(Tensor value,
173+
string name = null,
174+
TF_DataType dtype = TF_DataType.DtInvalid,
175+
TensorShape shape = null)
176+
{
177+
if (shape is null)
178+
shape = value.shape;
179+
if (dtype == TF_DataType.DtInvalid)
180+
dtype = value.dtype;
181+
182+
var placeholder = tf_with(ops.control_dependencies(null), ctl => array_ops.placeholder(dtype, shape: shape, name: name));
183+
// custom_gradient.copy_handle_data(value, placeholder)
184+
return placeholder;
185+
}
186+
187+
void SetAttrs()
188+
{
189+
if (Attrs == null)
190+
return;
191+
192+
foreach (var (_name, attr_value) in enumerate(Attrs))
193+
{
194+
var serialized = new AttrValue
195+
{
196+
S = ByteString.CopyFromUtf8(attr_value)
197+
}.ToByteArray();
198+
c_api.TF_FunctionSetAttrValueProto(func_handle, _name, serialized, serialized.Length, tf.Status.Handle);
199+
tf.Status.Check(true);
200+
}
201+
}
202+
83203
protected override void DisposeManagedResources()
84204
{
85205
base.DisposeManagedResources();

src/TensorFlowNET.Core/Graphs/Graph.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ private void _check_not_finalized()
262262
throw new RuntimeError("Graph is finalized and cannot be modified.");
263263
}
264264

265-
public Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes,
265+
public virtual Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes,
266266
TF_DataType[] input_types = null, string name = null,
267267
Dictionary<string, AttrValue> attrs = null, OpDef op_def = null,
268268
bool compute_device = true)

0 commit comments

Comments
 (0)