@@ -30,6 +30,9 @@ limitations under the License.
3030using Tensorflow.Training.Saving.SavedModel;
3131using Tensorflow.Util;
3232using static Tensorflow.Binding;
33+ using Tensorflow.Framework;
34+ using Tensorflow.Sessions;
35+
3336
3437namespace Tensorflow.Keras.Engine
3538{
@@ -134,21 +137,53 @@ public virtual List<IVariableV1> Weights
134137 }
135138 }
136139
137- public virtual void set_weights(List <NDArray> weights)
140+ public virtual void set_weights(IEnumerable <NDArray> weights)
138141 {
139142 if (Weights.Count() != weights.Count()) throw new ValueError(
140143 $"You called `set_weights` on layer \"{this.name}\"" +
141144 $"with a weight list of length {len(weights)}, but the layer was " +
142145 $"expecting {len(Weights)} weights.");
143- for (int i = 0; i < weights.Count(); i++)
146+
147+
148+
149+ // check if the shapes are compatible
150+ var weight_index = 0;
151+ foreach(var w in weights)
144152 {
145- if (weights[i].shape != Weights[i].shape )
153+ if (! Weights[weight_index].AsTensor().is_compatible_with(w) )
146154 {
147- throw new ValueError($"Layer weight shape {weights[i] .shape} not compatible with provided weight shape {Weights[i ].shape}");
155+ throw new ValueError($"Layer weight shape {w .shape} not compatible with provided weight shape {Weights[weight_index ].shape}");
148156 }
157+ weight_index++;
158+ }
159+
160+ if (tf.executing_eagerly())
161+ {
162+ foreach (var (this_w, v_w) in zip(Weights, weights))
163+ this_w.assign(v_w, read_value: true);
164+ }
165+ else
166+ {
167+ // TODO(Wanglongzhi2001):seems like there exist some bug in graph mode when define model, so uncomment the following when it fixed.
168+
169+ //Tensors assign_ops = new Tensors();
170+ //var feed_dict = new FeedDict();
171+
172+ //Graph g = tf.Graph().as_default();
173+ //foreach (var (this_w, v_w) in zip(Weights, weights))
174+ //{
175+ // var tf_dtype = this_w.dtype;
176+ // var placeholder_shape = v_w.shape;
177+ // var assign_placeholder = tf.placeholder(tf_dtype, placeholder_shape);
178+ // var assign_op = this_w.assign(assign_placeholder);
179+ // assign_ops.Add(assign_op);
180+ // feed_dict.Add(assign_placeholder, v_w);
181+ //}
182+ //var sess = tf.Session().as_default();
183+ //sess.run(assign_ops, feed_dict);
184+
185+ //g.Exit();
149186 }
150- foreach (var (this_w, v_w) in zip(Weights, weights))
151- this_w.assign(v_w, read_value: true);
152187 }
153188
154189 public List<NDArray> get_weights()
0 commit comments