Skip to content

Commit 6b3828a

Browse files
committed
-- added embedding dim to categorymdn
1 parent d266f94 commit 6b3828a

File tree

1 file changed

+1
-0
lines changed
  • pytorch_tabular/models/mixture_density

1 file changed

+1
-0
lines changed

pytorch_tabular/models/mixture_density/mdn.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ def validation_epoch_end(self, outputs) -> None:
253253

254254
class CategoryEmbeddingMDN(BaseMDN):
255255
def __init__(self, config: DictConfig, **kwargs):
256+
self.embedding_cat_dim = sum([y for x, y in config.embedding_dims])
256257
super().__init__(config, **kwargs)
257258

258259
def _build_network(self):

0 commit comments

Comments
 (0)