
Commit 02822c9

- added plotly as a requirement
- changed the wandb logits histogram to a plotly plot
- added MDN and made changes to prediction layer

1 parent e37a1bd commit 02822c9

File tree

7 files changed (+108,780, -82 lines)


examples/regression_with_MDN.ipynb

Lines changed: 108484 additions & 0 deletions
Large diffs are not rendered by default.

examples/to_test_regression.py

Lines changed: 9 additions & 9 deletions
@@ -54,14 +54,14 @@
     normalize_continuous_features=True,
 )
 
-mdn_config = MixtureDensityHeadConfig(num_gaussian=2)
-model_config = NODEMDNConfig(
-    task="regression",
-    # initialization="blah",
-    mdn_config = mdn_config
-)
-# model_config.validate()
-# model_config = NodeConfig(task="regression", depth=2, embed_categorical=False)
+# mdn_config = MixtureDensityHeadConfig(num_gaussian=2)
+# model_config = NODEMDNConfig(
+#     task="regression",
+#     # initialization="blah",
+#     mdn_config = mdn_config
+# )
+# # model_config.validate()
+model_config = NodeConfig(task="regression", depth=2, embed_categorical=False)
 trainer_config = TrainerConfig(checkpoints=None, max_epochs=5, gpus=1, profiler=None)
 # experiment_config = ExperimentConfig(
 #     project_name="DeepGMM_test",
@@ -84,6 +84,6 @@
 result = tabular_model.evaluate(test)
 # print(result)
 # # print(result[0]['train_loss'])
-pred_df = tabular_model.predict(test, quantiles=[0.25])
+pred_df = tabular_model.predict(test, quantiles=[0.25], ret_logits=True)
 print(pred_df.head())
 # pred_df.to_csv("output/temp2.csv")
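
For context, a minimal sketch of how the MDN path this script toggles would look when enabled, reconstructed from the commented-out lines in the diff above. Here `tabular_model` and `test` are the model and dataframe already set up earlier in the script, and the exact argument names are assumptions taken from this diff rather than a documented API:

# Sketch only: mirrors the commented-out configuration above.
# `tabular_model` and `test` come from the surrounding script and are not defined here.
from pytorch_tabular.models.mixture_density.config import (
    MixtureDensityHeadConfig,
    NODEMDNConfig,
)

mdn_config = MixtureDensityHeadConfig(num_gaussian=2)  # two-component mixture head
model_config = NODEMDNConfig(task="regression", mdn_config=mdn_config)

# New in this commit: ret_logits=True also returns the raw model outputs
# alongside the point and quantile predictions.
pred_df = tabular_model.predict(test, quantiles=[0.25], ret_logits=True)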

pytorch_tabular/models/base_model.py

Lines changed: 55 additions & 14 deletions
@@ -10,8 +10,11 @@
 import torch
 import torch.nn as nn
 from omegaconf import DictConfig
+
 try:
     import wandb
+    import plotly.graph_objects as go
+
     WANDB_INSTALLED = True
 except ImportError:
     WANDB_INSTALLED = False
@@ -101,14 +104,21 @@ def calculate_loss(self, y, y_hat, tag):
 
     def calculate_metrics(self, y, y_hat, tag):
         metrics = []
-        for metric, metric_str, metric_params in zip(self.metrics, self.hparams.metrics, self.hparams.metrics_params):
+        for metric, metric_str, metric_params in zip(
+            self.metrics, self.hparams.metrics, self.hparams.metrics_params
+        ):
             if (self.hparams.task == "regression") and (self.hparams.output_dim > 1):
                 _metrics = []
                 for i in range(self.hparams.output_dim):
-                    if metric.__name__==pl.metrics.functional.mean_squared_log_error.__name__:
+                    if (
+                        metric.__name__
+                        == pl.metrics.functional.mean_squared_log_error.__name__
+                    ):
                         # MSLE should only be used in strictly positive targets. It is undefined otherwise
                         _metric = metric(
-                            torch.clamp(y_hat[:, i], min=0), torch.clamp(y[:, i], min=0), **metric_params
+                            torch.clamp(y_hat[:, i], min=0),
+                            torch.clamp(y[:, i], min=0),
+                            **metric_params,
                         )
                     else:
                         _metric = metric(y_hat[:, i], y[:, i], **metric_params)
@@ -139,33 +149,37 @@ def calculate_metrics(self, y, y_hat, tag):
     def forward(self, x: Dict):
         pass
 
-    def predict(self, x: Dict):
-        return self.forward(x).get("logits")
+    def predict(self, x: Dict, ret_model_output: bool = False):
+        ret_value = self.forward(x)
+        if ret_model_output:
+            return ret_value.get("logits"), ret_value
+        else:
+            return ret_value.get("logits")
 
     def training_step(self, batch, batch_idx):
         y = batch["target"]
-        y_hat = self(batch)['logits']
+        y_hat = self(batch)["logits"]
         loss = self.calculate_loss(y, y_hat, tag="train")
         _ = self.calculate_metrics(y, y_hat, tag="train")
         return loss
 
     def validation_step(self, batch, batch_idx):
         y = batch["target"]
-        y_hat = self(batch)['logits']
+        y_hat = self(batch)["logits"]
         _ = self.calculate_loss(y, y_hat, tag="valid")
         _ = self.calculate_metrics(y, y_hat, tag="valid")
         return y_hat, y
 
     def test_step(self, batch, batch_idx):
         y = batch["target"]
-        y_hat = self(batch)['logits']
+        y_hat = self(batch)["logits"]
         _ = self.calculate_loss(y, y_hat, tag="test")
         _ = self.calculate_metrics(y, y_hat, tag="test")
         return y_hat, y
 
     def configure_optimizers(self):
         if self.custom_optimizer is None:
-            #Loading from the config
+            # Loading from the config
             try:
                 self._optimizer = getattr(torch.optim, self.hparams.optimizer)
                 opt = self._optimizer(
@@ -179,7 +193,7 @@ def configure_optimizers(self):
                 )
                 raise e
         else:
-            #Loading from custom fit arguments
+            # Loading from custom fit arguments
             self._optimizer = self.custom_optimizer
 
         opt = self._optimizer(
@@ -215,15 +229,42 @@ def configure_optimizers(self):
         else:
             return opt
 
+    def create_plotly_histogram(self, arr, name, bin_dict=None):
+        fig = go.Figure()
+        for i in range(arr.shape[-1]):
+            fig.add_trace(
+                go.Histogram(
+                    x=arr[:, i],
+                    histnorm="probability",
+                    name=f"{name}_{i}",
+                    xbins=bin_dict,  # dict(start=0.0, end=1.0, size=0.1), # bins used for histogram
+                )
+            )
+        # Overlay both histograms
+        fig.update_layout(
+            barmode="overlay",
+            legend=dict(
+                orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1
+            ),
+        )
+        # Reduce opacity to see both histograms
+        fig.update_traces(opacity=0.5)
+        return fig
+
     def validation_epoch_end(self, outputs) -> None:
-        do_log_logits = self.hparams.log_logits and self.hparams.log_target == "wandb" and WANDB_INSTALLED
+        do_log_logits = (
+            self.hparams.log_logits
+            and self.hparams.log_target == "wandb"
+            and WANDB_INSTALLED
+        )
         if do_log_logits:
             logits = [output[0] for output in outputs]
-            flattened_logits = torch.flatten(torch.cat(logits))
+            logits = torch.cat(logits).detach().cpu()
+            fig = self.create_plotly_histogram(logits.unsqueeze(1), "logits")
             wandb.log(
                 {
-                    "valid_logits": wandb.Histogram(flattened_logits.to("cpu")),
+                    "valid_logits": fig,
                     "global_step": self.global_step,
                 },
-                commit=False
+                commit=False,
             )
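
The gist of this change is that validation logits are now logged to W&B as an overlaid Plotly histogram figure instead of a wandb.Histogram. A self-contained sketch of that pattern outside the model class, assuming plotly and wandb are installed; the plotly_histogram helper name is illustrative and not part of the library:

# Standalone sketch of the logging pattern used by create_plotly_histogram /
# validation_epoch_end above; not part of the library itself.
import numpy as np
import plotly.graph_objects as go
import wandb


def plotly_histogram(arr: np.ndarray, name: str, bin_dict=None) -> go.Figure:
    fig = go.Figure()
    # One trace per column, e.g. per target dimension or mixture component
    for i in range(arr.shape[-1]):
        fig.add_trace(
            go.Histogram(x=arr[:, i], histnorm="probability", name=f"{name}_{i}", xbins=bin_dict)
        )
    fig.update_layout(barmode="overlay")  # overlay traces instead of stacking them
    fig.update_traces(opacity=0.5)        # keep overlapping traces visible
    return fig


# Example: log a (N, 1) array of validation logits as a figure
logits = np.random.randn(1024, 1)  # stand-in for torch.cat(logits).detach().cpu().numpy()
wandb.init(project="demo", mode="offline")  # offline run keeps the sketch self-contained
wandb.log({"valid_logits": plotly_histogram(logits, "logits"), "global_step": 0}, commit=False)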

pytorch_tabular/models/mixture_density/config.py

Lines changed: 66 additions & 5 deletions
@@ -18,7 +18,21 @@ class MixtureDensityHeadConfig:
     Args:
         num_gaussian (int): Number of Gaussian Distributions in the mixture model. Defaults to 1
         n_samples (int): Number of samples to draw from the posterior to get prediction. Defaults to 100
-        central_tendency (str): Which measure to use to get the point prediction. Choices are 'mean', 'median'. Defaults to `mean`
+        central_tendency (str): Which measure to use to get the point prediction.
+            Choices are 'mean', 'median'. Defaults to `mean`
+        sigma_bias_flag (bool): Whether to have a bias term in the sigma layer. Defaults to False
+        mu_bias_init (Optional[List]): To initialize the bias parameter of the mu layer to predefined cluster centers.
+            Should be a list with the same length as number of gaussians in the mixture model.
+            It is highly recommended to set the parameter to combat mode collapse. Defaults to None
+        weight_regularization (Optional[int]): Whether to apply L1 or L2 Norm to the MDN layers.
+            It is highly recommended to use this to avoid mode collapse. Choices are [1,2]. Defaults to L2
+        lambda_sigma (Optional[float]): The regularization constant for weight regularization of sigma layer. Defaults to 0.1
+        lambda_pi (Optional[float]): The regularization constant for weight regularization of pi layer. Defaults to 0.1
+        lambda_mu (Optional[float]): The regularization constant for weight regularization of mu layer. Defaults to 0.1
+        speedup_training (bool): Turning on this parameter does away with sampling during training which speeds up training,
+            but also doesn't give you visibility on train metrics. Defaults to False
+        log_debug_plot (bool): Turning on this parameter plots histograms of the mu, sigma, and pi layers in addition to the logits
+            (if log_logits is turned on in the experiment config). Defaults to False
 
     """
 
@@ -28,6 +42,45 @@ class MixtureDensityHeadConfig:
             "help": "Number of Gaussian Distributions in the mixture model. Defaults to 1",
         },
     )
+    sigma_bias_flag: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to have a bias term in the sigma layer. Defaults to False",
+        },
+    )
+    mu_bias_init: Optional[List] = field(
+        default=None,
+        metadata={
+            "help": "To initialize the bias parameter of the mu layer to predefined cluster centers. Should be a list with the same length as number of gaussians in the mixture model. It is highly recommended to set the parameter to combat mode collapse. Defaults to None",
+        },
+    )
+
+    weight_regularization: Optional[int] = field(
+        default=2,
+        metadata={
+            "help": "Whether to apply L1 or L2 Norm to the MDN layers. Defaults to L2",
+            "choices": [1, 2],
+        },
+    )
+
+    lambda_sigma: Optional[float] = field(
+        default=0.1,
+        metadata={
+            "help": "The regularization constant for weight regularization of sigma layer. Defaults to 0.1",
+        },
+    )
+    lambda_pi: Optional[float] = field(
+        default=0.1,
+        metadata={
+            "help": "The regularization constant for weight regularization of pi layer. Defaults to 0.1",
+        },
+    )
+    lambda_mu: Optional[float] = field(
+        default=0,
+        metadata={
+            "help": "The regularization constant for weight regularization of mu layer. Defaults to 0",
+        },
+    )
     n_samples: int = field(
         default=100,
         metadata={
@@ -41,10 +94,16 @@ class MixtureDensityHeadConfig:
             "choices": ["mean", "median"],
         },
    )
-    fast_training: bool = field(
+    speedup_training: bool = field(
+        default=False,
+        metadata={
+            "help": "Turning on this parameter does away with sampling during training which speeds up training, but also doesn't give you visibility on train metrics. Defaults to False",
+        },
+    )
+    log_debug_plot: bool = field(
         default=False,
         metadata={
-            "help": "Turning onthis parameter does away with sampling during training which speeds up training, but also doesn't give you visibility on training metrics. Defaults to True",
+            "help": "Turning on this parameter plots histograms of the mu, sigma, and pi layers in addition to the logits (if log_logits is turned on in the experiment config). Defaults to False",
         },
     )
     _module_src: str = field(default="mixture_density")
@@ -87,7 +146,8 @@ class CategoryEmbeddingMDNConfig(CategoryEmbeddingModelConfig):
     """
 
     mdn_config: MixtureDensityHeadConfig = field(
-        default=None, metadata={"help": "The config for defining the Mixed Density Network Head"}
+        default=None,
+        metadata={"help": "The config for defining the Mixed Density Network Head"},
     )
     _module_src: str = field(default="mixture_density")
     _model_name: str = field(default="CategoryEmbeddingMDN")
@@ -159,7 +219,8 @@ class NODEMDNConfig(NodeConfig):
     """
 
     mdn_config: MixtureDensityHeadConfig = field(
-        default=None, metadata={"help": "The config for defining the Mixed Density Network Head"}
+        default=None,
+        metadata={"help": "The config for defining the Mixed Density Network Head"},
     )
     _module_src: str = field(default="mixture_density")
     _model_name: str = field(default="NODEMDN")
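
Taken together, the new fields give the mixture density head explicit bias, initialization, and regularization controls. A minimal construction sketch using only the field names defined in the dataclass above; the specific values are illustrative assumptions, not recommended defaults:

# Illustrative construction of the expanded config; field names come from the
# dataclass in this diff, the values are assumptions for demonstration only.
from pytorch_tabular.models.mixture_density.config import MixtureDensityHeadConfig

mdn_head = MixtureDensityHeadConfig(
    num_gaussian=3,                  # three-component mixture
    sigma_bias_flag=False,           # no bias term in the sigma layer
    mu_bias_init=[0.0, 5.0, 10.0],   # one initial center per Gaussian, to combat mode collapse
    weight_regularization=2,         # L2 penalty on the MDN layers
    lambda_sigma=0.1,                # regularization constants for the sigma / pi / mu layers
    lambda_pi=0.1,
    lambda_mu=0.1,
    speedup_training=False,          # keep sampling during training so train metrics stay visible
    log_debug_plot=True,             # also log mu/sigma/pi histograms when log_logits is enabled
)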
