pytorch_tabular/models/mixture_density/config.py
66 additions & 5 deletions
@@ -18,7 +18,21 @@ class MixtureDensityHeadConfig:
     Args:
         num_gaussian (int): Number of Gaussian Distributions in the mixture model. Defaults to 1
         n_samples (int): Number of samples to draw from the posterior to get prediction. Defaults to 100
-        central_tendency (str): Which measure to use to get the point prediction. Choices are 'mean', 'median'. Defaults to `mean`
+        central_tendency (str): Which measure to use to get the point prediction.
+            Choices are 'mean', 'median'. Defaults to `mean`
+        sigma_bias_flag (bool): Whether to have a bias term in the sigma layer. Defaults to False
+        mu_bias_init (Optional[List]): To initialize the bias parameter of the mu layer to predefined cluster centers.
+            Should be a list with the same length as the number of gaussians in the mixture model.
+            It is highly recommended to set this parameter to combat mode collapse. Defaults to None
+        weight_regularization (Optional[int]): Whether to apply L1 or L2 Norm to the MDN layers.
+            It is highly recommended to use this to avoid mode collapse. Choices are [1, 2]. Defaults to L2
+        lambda_sigma (Optional[float]): The regularization constant for weight regularization of the sigma layer. Defaults to 0.1
+        lambda_pi (Optional[float]): The regularization constant for weight regularization of the pi layer. Defaults to 0.1
+        lambda_mu (Optional[float]): The regularization constant for weight regularization of the mu layer. Defaults to 0
+        speedup_training (bool): Turning on this parameter does away with sampling during training, which speeds up
+            training but also doesn't give you visibility on train metrics. Defaults to False
+        log_debug_plot (bool): Turning on this parameter plots histograms of the mu, sigma, and pi layers in addition
+            to the logits (if log_logits is turned on in experiment config). Defaults to False


     """
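To make the new surface area concrete, here is a minimal sketch of how the fields documented above could be filled in. The values are illustrative, and the import path is inferred from the file being edited rather than confirmed by this diff:

    # Illustrative values only; the import path is an assumption inferred
    # from the file edited in this diff.
    from pytorch_tabular.models.mixture_density.config import MixtureDensityHeadConfig

    mdn_head = MixtureDensityHeadConfig(
        num_gaussian=3,                 # three-component mixture
        n_samples=100,                  # posterior draws per point prediction
        central_tendency="mean",        # or "median"
        sigma_bias_flag=False,          # no bias term in the sigma layer
        mu_bias_init=[-1.0, 0.0, 1.0],  # one predefined center per gaussian
        weight_regularization=2,        # L2 norm on the MDN layers
        lambda_sigma=0.1,
        lambda_pi=0.1,
        lambda_mu=0,
        speedup_training=False,         # keep sampling so train metrics stay visible
    )

Setting mu_bias_init together with weight_regularization is the combination the docstring recommends against mode collapse.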
@@ -28,6 +42,45 @@ class MixtureDensityHeadConfig:
             "help": "Number of Gaussian Distributions in the mixture model. Defaults to 1",
         },
     )
+    sigma_bias_flag: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to have a bias term in the sigma layer. Defaults to False",
+        },
+    )
+    mu_bias_init: Optional[List] = field(
+        default=None,
+        metadata={
+            "help": "To initialize the bias parameter of the mu layer to predefined cluster centers. Should be a list with the same length as the number of gaussians in the mixture model. It is highly recommended to set this parameter to combat mode collapse. Defaults to None",
+        },
+    )
+
+    weight_regularization: Optional[int] = field(
+        default=2,
+        metadata={
+            "help": "Whether to apply L1 or L2 Norm to the MDN layers. Defaults to L2",
+            "choices": [1, 2],
+        },
+    )
+
+    lambda_sigma: Optional[float] = field(
+        default=0.1,
+        metadata={
+            "help": "The regularization constant for weight regularization of the sigma layer. Defaults to 0.1",
+        },
+    )
+    lambda_pi: Optional[float] = field(
+        default=0.1,
+        metadata={
+            "help": "The regularization constant for weight regularization of the pi layer. Defaults to 0.1",
+        },
+    )
+    lambda_mu: Optional[float] = field(
+        default=0,
+        metadata={
+            "help": "The regularization constant for weight regularization of the mu layer. Defaults to 0",
+        },
+    )
     n_samples: int = field(
         default=100,
         metadata={
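This hunk only adds the configuration knobs; the model-side code that consumes them is not part of this file. As a rough sketch of how the lambda-scaled penalties could enter the loss, assuming the head exposes pi, sigma, and mu as linear layers (the helper and attribute names are hypothetical, not code from this PR):

    import torch

    def mdn_weight_penalty(head, config):
        # Hypothetical helper: apply the configured norm (1 -> L1, 2 -> L2)
        # to each MDN sub-layer, scaled by its lambda from the config.
        p = config.weight_regularization
        penalty = torch.tensor(0.0)
        for layer, lam in ((head.pi, config.lambda_pi),
                           (head.sigma, config.lambda_sigma),
                           (head.mu, config.lambda_mu)):
            penalty = penalty + lam * layer.weight.norm(p)
        return penalty  # added to the negative log-likelihood during training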
@@ -41,10 +94,16 @@ class MixtureDensityHeadConfig:
             "choices": ["mean", "median"],
         },
     )
-    fast_training: bool = field(
+    speedup_training: bool = field(
+        default=False,
+        metadata={
+            "help": "Turning on this parameter does away with sampling during training, which speeds up training but also doesn't give you visibility on train metrics. Defaults to False",
+        },
+    )
+    log_debug_plot: bool = field(
         default=False,
         metadata={
-            "help": "Turning onthis parameter does away with sampling during training which speeds up training, but also doesn't give you visibility on training metrics. Defaults to True",
+            "help": "Turning on this parameter plots histograms of the mu, sigma, and pi layers in addition to the logits (if log_logits is turned on in experiment config). Defaults to False",
         },
     )
     _module_src: str = field(default="mixture_density")
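central_tendency and n_samples interact at prediction time: samples are drawn from the mixture and reduced to a point estimate. A hypothetical sketch of that reduction (the function name and tensor layout are assumptions, not code from this PR):

    import torch

    def point_prediction(samples: torch.Tensor, central_tendency: str = "mean") -> torch.Tensor:
        # samples: [n_samples, batch] draws from the mixture posterior.
        if central_tendency == "mean":
            return samples.mean(dim=0)
        return samples.median(dim=0).values  # "median"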
@@ -87,7 +146,8 @@ class CategoryEmbeddingMDNConfig(CategoryEmbeddingModelConfig):
     """

     mdn_config: MixtureDensityHeadConfig = field(
-        default=None, metadata={"help": "The config for defining the Mixed Density Network Head"}
+        default=None,
+        metadata={"help": "The config for defining the Mixture Density Network Head"},
0 commit comments