1- using System ;
1+ using Google . Protobuf ;
2+ using System ;
23using System . Collections . Generic ;
34using System . Linq ;
5+ using Tensorflow . Eager ;
6+ using Tensorflow . Exceptions ;
47using static Tensorflow . Binding ;
58
69namespace Tensorflow . Graphs
@@ -21,7 +24,9 @@ public class FuncGraph : Graph
2124
2225 public Tensors Inputs { get ; set ; }
2326 public Tensors Outputs { get ; set ; }
27+ public Dictionary < string , string > Attrs { get ; set ; }
2428
29+ Dictionary < long , ( Tensor , Tensor ) > _captures = new Dictionary < long , ( Tensor , Tensor ) > ( ) ;
2530 /// <summary>
2631 /// Construct a new FuncGraph.
2732 /// </summary>
@@ -34,10 +39,14 @@ public FuncGraph(string name) : base()
3439 as_default ( ) ;
3540 }
3641
37- public FuncGraph ( IntPtr handle , string name )
42+ public FuncGraph ( IntPtr handle , string name , Dictionary < string , string > attrs ) : base ( )
3843 {
3944 outer_graph = ops . get_default_graph ( ) ;
4045 func_name = name ;
46+ Attrs = attrs ;
47+ // Will to test if FuncGraph has memory leak
48+ // c_api.TF_DeleteGraph(_handle);
49+ _handle = handle ;
4150
4251 tf . Context . graph_mode ( ) ;
4352 as_default ( ) ;
@@ -63,6 +72,8 @@ public IntPtr ToGraph(Operation[] opers,
6372 status . Handle ) ;
6473 status . Check ( true ) ;
6574
75+ SetAttrs ( ) ;
76+
6677 c_api . TF_GraphCopyFunction ( outer_graph , func_handle , IntPtr . Zero , status . Handle ) ;
6778 status . Check ( true ) ;
6879
@@ -73,13 +84,122 @@ public IntPtr ToGraph(Operation[] opers,
7384
7485 Inputs = inputs ;
7586 // mark_as_return
76- Outputs = outputs . Select ( x => array_ops . identity ( x ) ) . ToArray ( ) ;
87+ Outputs = outputs ; // .Select(x => array_ops.identity(x)).ToArray();
7788
7889 tf . Context . restore_mode ( ) ;
7990
8091 return func_handle ;
8192 }
8293
94+ public override Operation create_op ( string op_type , Tensor [ ] inputs , TF_DataType [ ] dtypes , TF_DataType [ ] input_types = null , string name = null , Dictionary < string , AttrValue > attrs = null , OpDef op_def = null , bool compute_device = true )
95+ {
96+ foreach ( var ( i , inp ) in enumerate ( inputs ) )
97+ inputs [ i ] = capture ( inp ) ;
98+
99+ return base . create_op ( op_type , inputs , dtypes , input_types , name , attrs , op_def , compute_device ) ;
100+ }
101+
102+ Tensor capture ( Tensor tensor , string name = null , TF_DataType shape = TF_DataType . DtInvalid )
103+ {
104+ if ( tensor is EagerTensor )
105+ {
106+ throw new NotImplementedException ( "" ) ;
107+ }
108+
109+ if ( tensor . graph != this )
110+ {
111+ if ( name == null )
112+ name = tensor . op . name ;
113+ var inner_graph = tensor . graph ;
114+ while ( inner_graph != null && inner_graph is FuncGraph inner_func_graph )
115+ {
116+ if ( inner_graph == this )
117+ throw new InaccessibleTensorError ( $ "The tensor '{ tensor . name } ' cannot be accessed here: it is defined" +
118+ " in another function or code block. Use return values," +
119+ " explicit Python locals or TensorFlow collections to access" +
120+ $ " it. Defined in: { tensor . graph . graph_key } ; accessed from: { graph_key } .") ;
121+ inner_graph = inner_func_graph . outer_graph ;
122+ }
123+ return _capture_helper ( tensor , name ) ;
124+ }
125+
126+ return tensor ;
127+ }
128+
129+ Tensor _capture_helper ( Tensor tensor , string name , TensorShape shape = null )
130+ {
131+ Tensor placeholder = null ;
132+ if ( ! _captures . ContainsKey ( tensor . Id ) )
133+ {
134+ placeholder = _create_substitute_placeholder ( tensor ,
135+ name : name ,
136+ dtype : tensor . dtype ,
137+ shape : shape ) ;
138+ add_capture ( tensor , placeholder ) ;
139+ }
140+ else
141+ {
142+ placeholder = _captures [ tensor . Id ] . Item1 ;
143+ }
144+
145+ BackwardFunction _backward_function_wrapper = ( output_grads , unneeded_gradients ) =>
146+ {
147+ return output_grads ;
148+ } ;
149+
150+ tf . Runner . RecordGradient ( "captured_value" ,
151+ new [ ] { placeholder } , null ,
152+ new [ ] { tensor } ,
153+ getBackwardFunction : ( ) => _backward_function_wrapper
154+ /*getForwardFunction: forward_function*/ ) ;
155+
156+ return placeholder ;
157+ }
158+
159+ void add_capture ( Tensor tensor , Tensor placeholder )
160+ {
161+ _captures [ tensor . Id ] = ( tensor , placeholder ) ;
162+ if ( Inputs == null )
163+ Inputs = new Tensors ( placeholder ) ;
164+ else
165+ {
166+ var inputs = Inputs . ToList ( ) ;
167+ inputs . Add ( placeholder ) ;
168+ Inputs = new Tensors ( inputs . ToArray ( ) ) ;
169+ }
170+ }
171+
172+ Tensor _create_substitute_placeholder ( Tensor value ,
173+ string name = null ,
174+ TF_DataType dtype = TF_DataType . DtInvalid ,
175+ TensorShape shape = null )
176+ {
177+ if ( shape is null )
178+ shape = value . shape ;
179+ if ( dtype == TF_DataType . DtInvalid )
180+ dtype = value . dtype ;
181+
182+ var placeholder = tf_with ( ops . control_dependencies ( null ) , ctl => array_ops . placeholder ( dtype , shape : shape , name : name ) ) ;
183+ // custom_gradient.copy_handle_data(value, placeholder)
184+ return placeholder ;
185+ }
186+
187+ void SetAttrs ( )
188+ {
189+ if ( Attrs == null )
190+ return ;
191+
192+ foreach ( var ( _name , attr_value ) in enumerate ( Attrs ) )
193+ {
194+ var serialized = new AttrValue
195+ {
196+ S = ByteString . CopyFromUtf8 ( attr_value )
197+ } . ToByteArray ( ) ;
198+ c_api . TF_FunctionSetAttrValueProto ( func_handle , _name , serialized , serialized . Length , tf . Status . Handle ) ;
199+ tf . Status . Check ( true ) ;
200+ }
201+ }
202+
83203 protected override void DisposeManagedResources ( )
84204 {
85205 base . DisposeManagedResources ( ) ;
0 commit comments