Skip to content

Commit 9ac8cc4

Browse files
committed
-- fixed an issue with MDN
1 parent c4f1583 commit 9ac8cc4

File tree

2 files changed

+7
-6
lines changed

2 files changed

+7
-6
lines changed

pytorch_tabular/models/mixture_density/mdn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def generate_point_predictions(self, pi, sigma, mu, n_samples=None):
129129
y_hat = torch.mean(samples, dim=-1)
130130
elif self.hparams.central_tendency == "median":
131131
y_hat = torch.median(samples, dim=-1).values
132-
return y_hat
132+
return y_hat.unsqueeze(1)
133133

134134

135135
class BaseMDN(BaseModel, metaclass=ABCMeta):

tests/test_mdn.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def test_classification(
152152

153153
# test_regression(
154154
# regression_data(),
155-
# multi_target=True,
155+
# multi_target=False,
156156
# continuous_cols=[
157157
# "AveRooms",
158158
# "AveBedrms",
@@ -161,10 +161,11 @@ def test_classification(
161161
# "Latitude",
162162
# "Longitude",
163163
# ],
164-
# categorical_cols=[],
165-
# continuous_feature_transform="yeo-johnson",
166-
# normalize_continuous_features=False,
167-
# target_range=True,
164+
# categorical_cols=["HouseAgeBin"],
165+
# continuous_feature_transform=None,
166+
# normalize_continuous_features=True,
167+
# variant=CategoryEmbeddingMDNConfig,
168+
# num_gaussian=2
168169
# )
169170
# test_embedding_transformer(regression_data())
170171

0 commit comments

Comments
 (0)