Skip to content

Commit 3ad3845

Browse files
authored
Improve error message when layer/model input validation fails. (#21869)
Provide the input name and the input path when available.
1 parent a40ddf6 commit 3ad3845

File tree

3 files changed

+40
-20
lines changed

3 files changed

+40
-20
lines changed

keras/src/layers/input_spec.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -185,24 +185,24 @@ def assert_input_compatibility(input_spec, inputs, layer_name):
185185
if spec.ndim is not None and not spec.allow_last_axis_squeeze:
186186
if ndim != spec.ndim:
187187
raise ValueError(
188-
f'Input {input_index} of layer "{layer_name}" '
189-
"is incompatible with the layer: "
188+
f"Input {input_index} with name '{spec.name}' of layer "
189+
f"'{layer_name}' is incompatible with the layer: "
190190
f"expected ndim={spec.ndim}, found ndim={ndim}. "
191191
f"Full shape received: {shape}"
192192
)
193193
if spec.max_ndim is not None:
194194
if ndim is not None and ndim > spec.max_ndim:
195195
raise ValueError(
196-
f'Input {input_index} of layer "{layer_name}" '
197-
"is incompatible with the layer: "
196+
f"Input {input_index} with name '{spec.name}' of layer "
197+
f"'{layer_name}' is incompatible with the layer: "
198198
f"expected max_ndim={spec.max_ndim}, "
199199
f"found ndim={ndim}"
200200
)
201201
if spec.min_ndim is not None:
202202
if ndim is not None and ndim < spec.min_ndim:
203203
raise ValueError(
204-
f'Input {input_index} of layer "{layer_name}" '
205-
"is incompatible with the layer: "
204+
f"Input {input_index} with name '{spec.name}' of layer "
205+
f"'{layer_name}' is incompatible with the layer: "
206206
f"expected min_ndim={spec.min_ndim}, "
207207
f"found ndim={ndim}. "
208208
f"Full shape received: {shape}"
@@ -212,8 +212,8 @@ def assert_input_compatibility(input_spec, inputs, layer_name):
212212
dtype = backend.standardize_dtype(x.dtype)
213213
if dtype != spec.dtype:
214214
raise ValueError(
215-
f'Input {input_index} of layer "{layer_name}" '
216-
"is incompatible with the layer: "
215+
f"Input {input_index} with name '{spec.name}' of layer "
216+
f"'{layer_name}' is incompatible with the layer: "
217217
f"expected dtype={spec.dtype}, "
218218
f"found dtype={dtype}"
219219
)
@@ -226,11 +226,10 @@ def assert_input_compatibility(input_spec, inputs, layer_name):
226226
None,
227227
}:
228228
raise ValueError(
229-
f'Input {input_index} of layer "{layer_name}" is '
230-
f"incompatible with the layer: expected axis {axis} "
231-
f"of input shape to have value {value}, "
232-
"but received input with "
233-
f"shape {shape}"
229+
f"Input {input_index} with name '{spec.name}' of layer "
230+
f"'{layer_name}' is incompatible with the layer: "
231+
f"expected axis {axis} of input shape to have value "
232+
f"{value}, but received input with shape {shape}"
234233
)
235234
# Check shape.
236235
if spec.shape is not None:
@@ -244,8 +243,8 @@ def assert_input_compatibility(input_spec, inputs, layer_name):
244243
if spec_dim is not None and dim is not None:
245244
if spec_dim != dim:
246245
raise ValueError(
247-
f'Input {input_index} of layer "{layer_name}" is '
248-
"incompatible with the layer: "
249-
f"expected shape={spec.shape}, "
250-
f"found shape={shape}"
246+
f"Input {input_index} with name '{spec.name}' of "
247+
f"layer '{layer_name}' is incompatible with the "
248+
f"layer: expected shape={spec.shape}, found "
249+
f"shape={shape}"
251250
)

keras/src/models/functional.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -254,9 +254,9 @@ def _convert_inputs_to_tensors(self, flat_inputs):
254254
return converted
255255

256256
def _adjust_input_rank(self, flat_inputs):
257-
flat_ref_shapes = [x.shape for x in self._inputs]
258257
adjusted = []
259-
for x, ref_shape in zip(flat_inputs, flat_ref_shapes):
258+
for i, x in enumerate(flat_inputs):
259+
ref_shape = self._inputs[i].shape
260260
if x is None:
261261
adjusted.append(x)
262262
continue
@@ -273,8 +273,11 @@ def _adjust_input_rank(self, flat_inputs):
273273
if ref_shape[-1] == 1:
274274
adjusted.append(ops.expand_dims(x, axis=-1))
275275
continue
276+
flat_paths_and_inputs = tree.flatten_with_path(self._inputs_struct)
277+
path = ".".join(str(p) for p in flat_paths_and_inputs[i][0])
276278
raise ValueError(
277-
f"Invalid input shape for input {x}. Expected shape "
279+
f"Invalid input shape for input {x} with name "
280+
f"'{self._inputs[i].name}' and path '{path}'. Expected shape "
278281
f"{ref_shape}, but input has incompatible shape {x.shape}"
279282
)
280283
# Add back metadata.

keras/src/models/functional_test.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,24 @@ def test_rank_standardization(self):
373373
out_val = model(np.random.random((2, 3)))
374374
self.assertEqual(out_val.shape, (2, 3, 3))
375375

376+
@pytest.mark.requires_trainable_backend
377+
def test_rank_standardization_failure(self):
378+
# Simple input and rank too high
379+
inputs = Input(shape=(3,), name="foo")
380+
outputs = layers.Dense(3)(inputs)
381+
model = Functional(inputs, outputs)
382+
with self.assertRaisesRegex(ValueError, "name 'foo' .* path ''"):
383+
model(np.random.random((2, 3, 4)))
384+
385+
# Deeply nested input and rank too low
386+
inputs = [{"foo": Input(shape=(3,), name="my_input")}]
387+
outputs = layers.Dense(3)(inputs[0]["foo"])
388+
model = Functional(inputs, outputs)
389+
with self.assertRaisesRegex(
390+
ValueError, "name 'my_input' .* path '0.foo'"
391+
):
392+
model(np.random.random(()))
393+
376394
@pytest.mark.requires_trainable_backend
377395
def test_dtype_standardization(self):
378396
float_input = Input(shape=(2,), dtype="float16")

0 commit comments

Comments
 (0)