Skip to content

Commit bed82b9

Browse files
committed
Fix in TrainerPerfectClassification.
1 parent 79aba0d commit bed82b9

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

neural_structured_learning/research/gam/trainer/trainer_classification.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -763,4 +763,8 @@ def train(self, unused_data, unused_session=None, **unused_kwargs):
763763
return 1.0, 1.0
764764

765765
def predict(self, unused_session, indices_unlabeled):
766-
return self.data.get_original_labels(indices_unlabeled)
766+
labels = self.data.get_original_labels(indices_unlabeled)
767+
num_samples = len(indices_unlabeled)
768+
predictions = np.zeros((num_samples, self.data.num_classes))
769+
predictions[np.arange(num_samples), labels] = 1.0
770+
return predictions

0 commit comments

Comments
 (0)