@@ -516,6 +516,49 @@ def call(self, x1, x2):
516516 layer (x1 = backend .KerasTensor ((3 , 4 )), x2 = backend .KerasTensor ((3 , 4 )))
517517 self .assertLen (layer .weights , 4 )
518518
519+ class DictLayerWithUnbuiltState (layers .Layer ):
520+ def __init__ (self , units ):
521+ super ().__init__ ()
522+ self .dense = layers .Dense (units )
523+
524+ def call (self , xs ):
525+ result = self .dense (xs ["x1" ])
526+ if xs .get ("x2" , None ) is not None :
527+ result += self .dense (xs ["x2" ])
528+ return result
529+
530+ layer = DictLayerWithUnbuiltState (2 )
531+ layer (
532+ {
533+ "x1" : backend .KerasTensor ((3 , 4 )),
534+ "x2" : backend .KerasTensor ((3 , 4 )),
535+ }
536+ )
537+ self .assertLen (layer .weights , 2 )
538+
539+ layer = DictLayerWithUnbuiltState (2 )
540+ layer ({"x1" : backend .KerasTensor ((3 , 4 )), "x2" : None })
541+ self .assertLen (layer .weights , 2 )
542+
543+ class ListLayerWithUnbuiltState (layers .Layer ):
544+ def __init__ (self , units ):
545+ super ().__init__ ()
546+ self .dense = layers .Dense (units )
547+
548+ def call (self , xs ):
549+ result = self .dense (xs [0 ])
550+ if xs [1 ] is not None :
551+ result += self .dense (xs [1 ])
552+ return result
553+
554+ layer = ListLayerWithUnbuiltState (2 )
555+ layer ([backend .KerasTensor ((3 , 4 )), backend .KerasTensor ((3 , 4 ))])
556+ self .assertLen (layer .weights , 2 )
557+
558+ layer = ListLayerWithUnbuiltState (2 )
559+ layer ([backend .KerasTensor ((3 , 4 )), None ])
560+ self .assertLen (layer .weights , 2 )
561+
519562 def test_activity_regularization (self ):
520563 class ActivityRegularizer (layers .Layer ):
521564 def call (self , x ):
0 commit comments