Skip to content

Commit aea62f6

Browse files
committed
Use safe AttrValue to invoke TF_SetAttrValueProto.
1 parent 4d86da6 commit aea62f6

File tree

4 files changed

+7
-8
lines changed

4 files changed

+7
-8
lines changed

src/TensorFlowNET.Core/Attributes/c_api.ops.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ public partial class c_api
6161
public static extern void TF_SetAttrBool(IntPtr desc, string attr_name, bool value);
6262

6363
[DllImport(TensorFlowLibName)]
64-
public static extern void TF_SetAttrValueProto(IntPtr desc, string attr_name, IntPtr proto, uint proto_len, SafeStatusHandle status);
64+
public static extern void TF_SetAttrValueProto(IntPtr desc, string attr_name, byte[] proto, int proto_len, SafeStatusHandle status);
6565

6666
/// <summary>
6767
/// Set `num_dims` to -1 to represent "unknown rank".

src/TensorFlowNET.Core/Contexts/ContextSwitch.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,5 +33,8 @@ public class ContextSwitch
3333
public Action EnterContextFn { get; set; }
3434

3535
public string DeviceStack { get; set; }
36+
37+
public override string ToString()
38+
=> $"EagerMode: {EagerMode}, IsBuildingFunction: {IsBuildingFunction}";
3639
}
3740
}

src/TensorFlowNET.Core/Operations/Operation.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ public Operation(NodeDef node_def, Graph g, Tensor[] inputs = null, TF_DataType[
168168
if (op_def == null)
169169
op_def = g.GetOpDef(node_def.Op);
170170

171-
(_handle, OpDesc) = ops._create_c_op(g, node_def, inputs, control_input_ops.ToArray());
171+
(_handle, OpDesc) = ops._create_c_op(g, node_def, inputs, control_input_ops.ToArray(), op_def);
172172
_is_stateful = op_def.IsStateful;
173173

174174
// Initialize self._outputs.

src/TensorFlowNET.Core/ops.cs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -190,13 +190,9 @@ public static (IntPtr, OperationDescription) _create_c_op(Graph graph, NodeDef n
190190
// Add attrs
191191
foreach (var attr in node_def.Attr)
192192
{
193-
var bytes = attr.Value.ToByteArray(); //TODO: we can use attr.Value.WriteTo with a memory stream.
194-
var protoHandle = Marshal.AllocHGlobal(bytes.Length);
195-
Marshal.Copy(bytes, 0, protoHandle, bytes.Length);
196-
uint len = (uint)bytes.Length;
197-
c_api.TF_SetAttrValueProto(op_desc, attr.Key, protoHandle, proto_len: len, status: status.Handle);
193+
var bytes = attr.Value.ToByteArray();
194+
c_api.TF_SetAttrValueProto(op_desc, attr.Key, bytes, proto_len: bytes.Length, status: status.Handle);
198195
status.Check(true);
199-
Marshal.FreeHGlobal(protoHandle);
200196
}
201197

202198
var c_op = c_api.TF_FinishOperation(op_desc, status.Handle);

0 commit comments

Comments
 (0)