From 66406eff46af0762e0945ed3d516762dba680d38 Mon Sep 17 00:00:00 2001
From: Alex Boyd
Date: Sat, 2 Aug 2025 23:19:21 -0700
Subject: [PATCH] Correct scaling of output intensities in FullyNN.

---
 easy_tpp/model/torch_model/torch_fullynn.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/easy_tpp/model/torch_model/torch_fullynn.py b/easy_tpp/model/torch_model/torch_fullynn.py
index 11af6da..52ae351 100644
--- a/easy_tpp/model/torch_model/torch_fullynn.py
+++ b/easy_tpp/model/torch_model/torch_fullynn.py
@@ -64,13 +64,13 @@ def forward(self, hidden_states, time_delta_seqs):
             derivative_integral_lambdas = []
             for i in range(integral_lambda.shape[-1]):  # iterate over marks
                 derivative_integral_lambdas.append(grad(
-                    integral_lambda[..., i].mean(),
+                    integral_lambda[..., i].sum(),
                     time_delta_seqs,
                     create_graph=True, retain_graph=True)[0])
             derivative_integral_lambda = torch.stack(derivative_integral_lambdas, dim=-1)  # TODO: Check that it is okay to iterate over marks like this
         else:
-            derivative_integral_lambda = grad(
-                integral_lambda.sum(dim=-1).mean(),
+            derivative_integral_lambda = grad(
+                integral_lambda.sum(),
                 time_delta_seqs,
                 create_graph=True, retain_graph=True)[0]
             derivative_integral_lambda = derivative_integral_lambda.unsqueeze(-1).expand(*derivative_integral_lambda.shape, self.num_event_types) / self.num_event_types
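
The scaling issue this patch fixes is easy to reproduce outside the model. Below is a
minimal sketch, not EasyTPP code: a toy compensator Lambda(t) = t^2 stands in for the
network's integral_lambda, and a random tensor t stands in for time_delta_seqs. Because
each position's output depends only on its own time delta, grad of a .sum() returns the
exact elementwise derivative dLambda/dt = lambda(t), while grad of a .mean() returns
that derivative divided by the number of elements, which is the factor the old code lost.

import torch
from torch.autograd import grad

t = torch.rand(4, 3, requires_grad=True)  # stand-in for time_delta_seqs
integral_lambda = t ** 2                  # toy Lambda(t); true lambda(t) = 2t

# Each element of t affects only its own output, so grad of the sum is
# the exact elementwise derivative 2t.
lam_sum = grad(integral_lambda.sum(), t, retain_graph=True)[0]
assert torch.allclose(lam_sum, 2 * t)

# grad of the mean is the same derivative scaled by 1 / t.numel() -- the
# bug this patch removes, which shrank every recovered intensity by a
# factor of batch_size * seq_len.
lam_mean = grad(integral_lambda.mean(), t)[0]
assert torch.allclose(lam_mean, 2 * t / t.numel())

The same reasoning covers the per-mark branch: each slice integral_lambda[..., i]
is summed before differentiation, so the stacked result is the per-mark intensity
without a spurious 1 / (batch_size * seq_len) factor.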