Skip to content

Commit ae9a161

Browse files
committed
Add assign_lazy_load for ResourceVariable.
1 parent db7aac2 commit ae9a161

File tree

5 files changed

+56
-11
lines changed

5 files changed

+56
-11
lines changed

src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,15 +82,22 @@ public Tensor assign<T>(T value, bool use_locking = false, string name = null, b
8282
var value_tensor = ops.convert_to_tensor(value, dtype: dtype);
8383
var assign_op = gen_resource_variable_ops.assign_variable_op(
8484
handle, value_tensor, name: name);
85+
8586
if (read_value)
86-
{
8787
return gen_resource_variable_ops.read_variable_op(handle, dtype);
88-
// var variable = _lazy_read(assign_op, value_tensor);
89-
// return variable;
90-
}
88+
9189
return assign_op;
9290
}
9391

92+
public IVariableV1 assign_lazy_load(Tensor value, string name = null)
93+
{
94+
var value_tensor = ops.convert_to_tensor(value, dtype: dtype);
95+
var assign_op = gen_resource_variable_ops.assign_variable_op(
96+
handle, value_tensor, name: name);
97+
var variable = _lazy_read(assign_op, value_tensor);
98+
return variable;
99+
}
100+
94101
public Tensor value()
95102
=> GraphElement ?? _read_variable_op();
96103

@@ -157,6 +164,25 @@ public Tensor assign_add<T>(T delta, bool use_locking = false, string name = nul
157164
return assign_add_op;
158165
}
159166

167+
public Tensor assign_sub<T>(T delta, bool use_locking = false, string name = null, bool read_value = true)
168+
{
169+
var assign_sub_op = gen_resource_variable_ops.assign_sub_variable_op(Handle,
170+
ops.convert_to_tensor(delta, dtype: dtype), name: name);
171+
172+
if (read_value)
173+
return gen_resource_variable_ops.read_variable_op(handle, dtype);
174+
// return _lazy_read(assign_add_op);
175+
return assign_sub_op;
176+
}
177+
178+
public IVariableV1 assign_sub_lazy_load(Tensor delta, string name = null)
179+
{
180+
var assign_sub_op = gen_resource_variable_ops.assign_sub_variable_op(Handle,
181+
ops.convert_to_tensor(delta, dtype: dtype), name: name);
182+
183+
return _lazy_read(assign_sub_op, delta);
184+
}
185+
160186
public override string ToString()
161187
{
162188
if (tf.Context.executing_eagerly())

src/TensorFlowNET.Core/Variables/IVariableV1.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,10 @@ public interface IVariableV1
4747
TF_DataType dtype { get; }
4848
TensorShape shape { get; }
4949
Tensor assign_add<T>(T delta, bool use_locking = false, string name = null, bool read_value = true);
50+
Tensor assign_sub<T>(T delta, bool use_locking = false, string name = null, bool read_value = true);
51+
IVariableV1 assign_sub_lazy_load(Tensor delta, string name = null);
5052
Tensor assign<T>(T value, bool use_locking = false, string name = null, bool read_value = true);
53+
IVariableV1 assign_lazy_load(Tensor value, string name = null);
5154
Tensor AsTensor(TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false);
5255
NDArray numpy();
5356
}

src/TensorFlowNET.Core/Variables/RefVariable.cs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ limitations under the License.
2323

2424
namespace Tensorflow
2525
{
26+
[Obsolete]
2627
public partial class RefVariable : IVariableV1, IProtoBuf<VariableDef, RefVariable>
2728
{
2829
protected string _name;
@@ -428,5 +429,20 @@ public Tensor assign_add<T>(T value, bool use_locking = false, string name = nul
428429

429430
public NDArray numpy()
430431
=> throw new RuntimeError("Graph mode can't use numpy().");
432+
433+
public Tensor assign_sub<T>(T delta, bool use_locking = false, string name = null, bool read_value = true)
434+
{
435+
throw new NotImplementedException();
436+
}
437+
438+
public IVariableV1 assign_sub_lazy_load(Tensor delta, string name = null)
439+
{
440+
throw new NotImplementedException();
441+
}
442+
443+
public IVariableV1 assign_lazy_load(Tensor value, string name = null)
444+
{
445+
throw new NotImplementedException();
446+
}
431447
}
432448
}

src/TensorFlowNET.Core/Variables/state_ops.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ public static Tensor assign_sub(IVariableV1 @ref,
9090
value,
9191
use_locking: use_locking,
9292
name: name) :
93-
@ref.assign(value, name: name) as Tensor;
93+
@ref.assign_sub(value, name: name);
9494

9595
//"""Update 'ref' by adding 'value' to it.
9696
//

src/TensorFlowNET.Keras/Layers/BatchNormalization.cs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -209,23 +209,23 @@ private Tensor _fused_batch_norm(Tensor inputs, Tensor training)
209209
return output;
210210
}
211211

212-
Tensor _assign_new_value(IVariableV1 variable, Tensor value)
212+
void _assign_new_value(IVariableV1 variable, Tensor value)
213213
{
214-
return tf_with(ops.name_scope("AssignNewValue", null, new { variable, value, momentum }), scope =>
214+
tf_with(ops.name_scope("AssignNewValue", null, new { variable, value, momentum }), scope =>
215215
{
216216
// var cm = ops.colocate_with(variable);
217-
return state_ops.assign_sub(variable, value, name: scope);
217+
variable.assign_lazy_load(value, name: scope);
218218
});
219219
}
220220

221-
Tensor _assign_moving_average(IVariableV1 variable, Tensor value, Tensor momentum)
221+
void _assign_moving_average(IVariableV1 variable, Tensor value, Tensor momentum)
222222
{
223-
return tf_with(ops.name_scope("AssignMovingAvg", null, new { variable, value, momentum }), scope =>
223+
tf_with(ops.name_scope("AssignMovingAvg", null, new { variable, value, momentum }), scope =>
224224
{
225225
// var cm = ops.colocate_with(variable);
226226
var decay = ops.convert_to_tensor(1.0f - momentum, name: "decay");
227227
var update_delta = (variable.AsTensor() - math_ops.cast(value, variable.dtype)) * decay;
228-
return state_ops.assign_sub(variable, update_delta, name: scope);
228+
variable.assign_sub_lazy_load(update_delta, name: scope);
229229
});
230230
}
231231
}

0 commit comments

Comments
 (0)