Skip to content

Commit b0b341e

Browse files
committed
init film to identity
1 parent da5ef5e commit b0b341e

File tree

2 files changed

+15
-9
lines changed

2 files changed

+15
-9
lines changed

meshgpt_pytorch/meshgpt_pytorch.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -251,18 +251,24 @@ class FiLM(Module):
251251
def __init__(self, dim, dim_out = None):
252252
super().__init__()
253253
dim_out = default(dim_out, dim)
254-
linear = nn.Linear(dim, dim_out * 2)
255254

256-
self.to_gamma_beta = nn.Sequential(
257-
linear,
258-
Rearrange('b (gb d) -> gb b 1 d', gb = 2)
259-
)
255+
self.to_gamma = nn.Linear(dim, dim_out, bias = False)
256+
self.to_beta = nn.Linear(dim, dim_out)
260257

261-
nn.init.zeros_(linear.weight)
262-
nn.init.constant_(linear.bias, 1.)
258+
self.gamma_mult = nn.Parameter(torch.zeros(1,))
259+
self.beta_mult = nn.Parameter(torch.zeros(1,))
263260

264261
def forward(self, x, cond):
265-
gamma, beta = self.to_gamma_beta(cond)
262+
gamma, beta = self.to_gamma(cond), self.to_beta(cond)
263+
gamma, beta = tuple(rearrange(t, 'b d -> b 1 d') for t in (gamma, beta))
264+
265+
# for initializing to identity
266+
267+
gamma = (1 + self.gamma_mult * gamma)
268+
beta = beta * self.beta_mult
269+
270+
# classic film
271+
266272
return x * gamma + beta
267273

268274
class PixelNorm(Module):

meshgpt_pytorch/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '1.2.11'
1+
__version__ = '1.2.12'

0 commit comments

Comments
 (0)