Skip to content

Commit 4aef886

Browse files
authored
Allow None inputs in Layer.build. (#21866)
If `call` takes a structure as inputs and some of the inputs are `None`, currently, building on call will fail.
1 parent 9bcdbc7 commit 4aef886

File tree

2 files changed

+49
-2
lines changed

2 files changed

+49
-2
lines changed

keras/src/layers/layer.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1893,6 +1893,10 @@ def get_shapes_dict(call_spec):
18931893
{"input_a_shape": (2, 3)}
18941894
```
18951895
"""
1896+
1897+
def standardize_shape_or_none(x):
1898+
return None if x is None else backend.standardize_shape(x.shape)
1899+
18961900
shapes_dict = {}
18971901
for k, v in call_spec.tensor_arguments_dict.items():
18981902
if k == "mask" or k.endswith("_mask"):
@@ -1903,10 +1907,10 @@ def get_shapes_dict(call_spec):
19031907
continue
19041908
if k in call_spec.nested_tensor_argument_names:
19051909
shapes_dict[f"{k}_shape"] = tree.map_structure(
1906-
lambda x: backend.standardize_shape(x.shape), v
1910+
standardize_shape_or_none, v
19071911
)
19081912
else:
1909-
shapes_dict[f"{k}_shape"] = backend.standardize_shape(v.shape)
1913+
shapes_dict[f"{k}_shape"] = standardize_shape_or_none(v)
19101914
return shapes_dict
19111915

19121916

keras/src/layers/layer_test.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,49 @@ def call(self, x1, x2):
516516
layer(x1=backend.KerasTensor((3, 4)), x2=backend.KerasTensor((3, 4)))
517517
self.assertLen(layer.weights, 4)
518518

519+
class DictLayerWithUnbuiltState(layers.Layer):
520+
def __init__(self, units):
521+
super().__init__()
522+
self.dense = layers.Dense(units)
523+
524+
def call(self, xs):
525+
result = self.dense(xs["x1"])
526+
if xs.get("x2", None) is not None:
527+
result += self.dense(xs["x2"])
528+
return result
529+
530+
layer = DictLayerWithUnbuiltState(2)
531+
layer(
532+
{
533+
"x1": backend.KerasTensor((3, 4)),
534+
"x2": backend.KerasTensor((3, 4)),
535+
}
536+
)
537+
self.assertLen(layer.weights, 2)
538+
539+
layer = DictLayerWithUnbuiltState(2)
540+
layer({"x1": backend.KerasTensor((3, 4)), "x2": None})
541+
self.assertLen(layer.weights, 2)
542+
543+
class ListLayerWithUnbuiltState(layers.Layer):
544+
def __init__(self, units):
545+
super().__init__()
546+
self.dense = layers.Dense(units)
547+
548+
def call(self, xs):
549+
result = self.dense(xs[0])
550+
if xs[1] is not None:
551+
result += self.dense(xs[1])
552+
return result
553+
554+
layer = ListLayerWithUnbuiltState(2)
555+
layer([backend.KerasTensor((3, 4)), backend.KerasTensor((3, 4))])
556+
self.assertLen(layer.weights, 2)
557+
558+
layer = ListLayerWithUnbuiltState(2)
559+
layer([backend.KerasTensor((3, 4)), None])
560+
self.assertLen(layer.weights, 2)
561+
519562
def test_activity_regularization(self):
520563
class ActivityRegularizer(layers.Layer):
521564
def call(self, x):

0 commit comments

Comments
 (0)