@@ -251,18 +251,24 @@ class FiLM(Module):
251251 def __init__ (self , dim , dim_out = None ):
252252 super ().__init__ ()
253253 dim_out = default (dim_out , dim )
254- linear = nn .Linear (dim , dim_out * 2 )
255254
256- self .to_gamma_beta = nn .Sequential (
257- linear ,
258- Rearrange ('b (gb d) -> gb b 1 d' , gb = 2 )
259- )
255+ self .to_gamma = nn .Linear (dim , dim_out , bias = False )
256+ self .to_beta = nn .Linear (dim , dim_out )
260257
261- nn .init . zeros_ ( linear . weight )
262- nn .init . constant_ ( linear . bias , 1. )
258+ self . gamma_mult = nn .Parameter ( torch . zeros ( 1 ,) )
259+ self . beta_mult = nn .Parameter ( torch . zeros ( 1 ,) )
263260
264261 def forward (self , x , cond ):
265- gamma , beta = self .to_gamma_beta (cond )
262+ gamma , beta = self .to_gamma (cond ), self .to_beta (cond )
263+ gamma , beta = tuple (rearrange (t , 'b d -> b 1 d' ) for t in (gamma , beta ))
264+
265+ # for initializing to identity
266+
267+ gamma = (1 + self .gamma_mult * gamma )
268+ beta = beta * self .beta_mult
269+
270+ # classic film
271+
266272 return x * gamma + beta
267273
268274class PixelNorm (Module ):
0 commit comments