Commit 8c6df5d

-- debugging
1 parent 4c34ee5 commit 8c6df5d

File tree

1 file changed: +2, -21 lines
  • pytorch_tabular/models/mixture_density


pytorch_tabular/models/mixture_density/mdn.py

Lines changed: 2 additions & 21 deletions
@@ -92,32 +92,13 @@ def sample(self, pi, sigma, mu):
         sample = sample * sigma.gather(1, pis) + mu.gather(1, pis)
         return sample
 
-    def calculate_loss(self, y, pi, sigma, mu, tag="train"):
-        # NLL Loss
-        log_prob = self.log_prob(pi, sigma, mu, y)
-        loss = torch.mean(-log_prob)
-        # pi1, pi2 = torch.mean(pi, dim=0)
-        # loss = torch.mean(-log_prob) + torch.abs(pi1-pi2)
-        # log_sigma = torch.log(sigma)
-        # kl_div = log_sigma[:,0] - log_sigma[:,1] + torch.pow(sigma[:,0],2) + torch.pow((mu[:,0]-mu[:,1]),2) / 2*torch.pow(sigma[:,1],2) - 0.5
-        # kl_div = torch.mean(kl_div)
-        # loss = torch.mean(-log_prob) - 1e-8*kl_div
-        self.log(
-            f"{tag}_loss",
-            loss,
-            on_epoch=(tag == "valid"),
-            on_step=(tag == "train"),
-            # on_step=False,
-            logger=True,
-            prog_bar=True,
-        )
-        return loss
-
     def generate_samples(self, pi, sigma, mu, n_samples=None):
         if n_samples is None:
             n_samples = self.hparams.n_samples
         samples = []
         softmax_pi = nn.functional.gumbel_softmax(pi, tau=1, dim=-1)
+        if (softmax_pi < 0).sum().item() > 0:
+            print("pi has negative")
         for _ in range(n_samples):
             samples.append(self.sample(softmax_pi, sigma, mu))
         samples = torch.cat(samples, dim=1)
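For context, here is a minimal standalone sketch (assuming only PyTorch; not part of the repository) of what the added debug check is probing. F.gumbel_softmax with the default hard=False returns a softmax over Gumbel-perturbed logits, so its entries should be non-negative by construction; the check flags an unexpected numerical state. The tensor names below are hypothetical.

# Standalone sketch of the commit's debug check (assumes PyTorch only).
import torch
import torch.nn.functional as F

pi_logits = torch.randn(4, 3)  # hypothetical mixture logits (batch=4, k=3 components)
softmax_pi = F.gumbel_softmax(pi_logits, tau=1, dim=-1)

# Equivalent, slightly more idiomatic form of the check added in generate_samples.
if torch.any(softmax_pi < 0):
    print("pi has negative")

print(softmax_pi.sum(dim=-1))  # each row should sum to ~1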
