@@ -92,32 +92,13 @@ def sample(self, pi, sigma, mu):
9292 sample = sample * sigma .gather (1 , pis ) + mu .gather (1 , pis )
9393 return sample
9494
95- def calculate_loss (self , y , pi , sigma , mu , tag = "train" ):
96- # NLL Loss
97- log_prob = self .log_prob (pi , sigma , mu , y )
98- loss = torch .mean (- log_prob )
99- # pi1, pi2 = torch.mean(pi, dim=0)
100- # loss = torch.mean(-log_prob) + torch.abs(pi1-pi2)
101- # log_sigma = torch.log(sigma)
102- # kl_div = log_sigma[:,0] - log_sigma[:,1]+ torch.pow(sigma[:,0],2) + torch.pow((mu[:,0]-mu[:,1]),2)/ 2*torch.pow(sigma[:,1],2) - 0.5
103- # kl_div = torch.mean(kl_div)
104- # loss = torch.mean(-log_prob) - 1e-8*kl_div
105- self .log (
106- f"{ tag } _loss" ,
107- loss ,
108- on_epoch = (tag == "valid" ),
109- on_step = (tag == "train" ),
110- # on_step=False,
111- logger = True ,
112- prog_bar = True ,
113- )
114- return loss
115-
11695 def generate_samples (self , pi , sigma , mu , n_samples = None ):
11796 if n_samples is None :
11897 n_samples = self .hparams .n_samples
11998 samples = []
12099 softmax_pi = nn .functional .gumbel_softmax (pi , tau = 1 , dim = - 1 )
100+ if (softmax_pi < 0 ).sum ().item ()> 0 :
101+ print ("pi has negative" )
121102 for _ in range (n_samples ):
122103 samples .append (self .sample (softmax_pi , sigma , mu ))
123104 samples = torch .cat (samples , dim = 1 )
0 commit comments