@@ -67,6 +67,9 @@ def default(v, d):
6767def first (it ):
6868 return it [0 ]
6969
70+ def identity (t , * args , ** kwargs ):
71+ return t
72+
7073def divisible_by (num , den ):
7174 return (num % den ) == 0
7275
@@ -264,8 +267,8 @@ def forward(self, x, cond):
264267
265268 # for initializing to identity
266269
267- gamma = (1 + self .gamma_mult * gamma )
268- beta = beta * self .beta_mult
270+ gamma = (1 + self .gamma_mult * gamma . tanh () )
271+ beta = beta . tanh () * self .beta_mult
269272
270273 # classic film
271274
@@ -1067,6 +1070,7 @@ def __init__(
10671070 pad_id = - 1 ,
10681071 num_sos_tokens = None ,
10691072 condition_on_text = False ,
1073+ text_cond_with_film = False ,
10701074 text_condition_model_types = ('t5' ,),
10711075 text_condition_cond_drop_prob = 0.25 ,
10721076 quads = False ,
@@ -1124,8 +1128,8 @@ def __init__(
11241128 dim_text = self .conditioner .dim_latent
11251129 cross_attn_dim_context = dim_text
11261130
1127- self .text_coarse_film_cond = FiLM (dim_text , dim )
1128- self .text_fine_film_cond = FiLM (dim_text , dim_fine )
1131+ self .text_coarse_film_cond = FiLM (dim_text , dim ) if text_cond_with_film else identity
1132+ self .text_fine_film_cond = FiLM (dim_text , dim_fine ) if text_cond_with_film else identity
11291133
11301134 # for summarizing the vertices of each face
11311135
0 commit comments