Commit 306ad38

committed
-- added pi, mu, sigma logging
1 parent 8c6df5d commit 306ad38

File tree

3 files changed: +443 -10 lines changed
Lines changed: 388 additions & 0 deletions
@@ -0,0 +1,388 @@
from pytorch_tabular.models.node.config import NodeConfig
from sklearn.datasets import fetch_california_housing
from torch.utils import data
from pytorch_tabular.config import (
    DataConfig,
    ExperimentConfig,
    ExperimentRunManager,
    ModelConfig,
    OptimizerConfig,
    TrainerConfig,
)
from pytorch_tabular.models.category_embedding.config import (
    CategoryEmbeddingModelConfig,
)

from pytorch_tabular.models.mixture_density import (
    CategoryEmbeddingMDNConfig, MixtureDensityHeadConfig, NODEMDNConfig
)
from pytorch_tabular.models.node import NODEBackbone
# from pytorch_tabular.models.deep_gmm import (
#     DeepGaussianMixtureModelConfig,
# )
from pytorch_tabular.models.category_embedding.category_embedding_model import (
    CategoryEmbeddingModel,
)
from pytorch_tabular.models import BaseModel
import pandas as pd
from omegaconf import OmegaConf
from pytorch_tabular.tabular_datamodule import TabularDatamodule
from pytorch_tabular.tabular_model import TabularModel
import pytorch_lightning as pl
from sklearn.preprocessing import PowerTransformer

import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import DictConfig
from typing import Dict
from dataclasses import dataclass, field
from typing import List, Optional
import pytorch_tabular.models as models
from pytorch_tabular.models.node import utils
import logging
import math
from torch.autograd import Variable
from pytorch_tabular.utils import _initialize_layers


@dataclass
class MultiStageModelConfig(ModelConfig):

    num_layers: int = field(
        default=1,
        metadata={
            "help": "Number of Oblivious Decision Tree layers in the dense architecture"
        },
    )
    num_trees: int = field(
        default=2048,
        metadata={"help": "Number of Oblivious Decision Trees in each layer"},
    )
    additional_tree_output_dim: int = field(
        default=3,
        metadata={
            "help": "Additional output dimensions that are only passed through the layers of the architecture. Only the first output_dim outputs are used for prediction"
        },
    )
    depth: int = field(
        default=6,
        metadata={"help": "The depth of the individual Oblivious Decision Trees"},
    )
    choice_function: str = field(
        default="entmax15",
        metadata={
            "help": "Generates a sparse probability distribution to be used as feature weights (i.e. soft feature selection)",
            "choices": ["entmax15", "sparsemax"],
        },
    )
    bin_function: str = field(
        default="entmoid15",
        metadata={
            "help": "Generates a sparse probability distribution to be used as tree leaf weights",
            "choices": ["entmoid15", "sparsemoid"],
        },
    )
    max_features: Optional[int] = field(
        default=None,
        metadata={
            "help": "If not None, sets a maximum limit on the number of features carried forward from layer to layer in the dense architecture"
        },
    )
    input_dropout: float = field(
        default=0.0,
        metadata={
            "help": "Dropout applied to the inputs between layers of the dense architecture"
        },
    )
    initialize_response: str = field(
        default="normal",
        metadata={
            "help": "Initialization of the response variable in the Oblivious Decision Trees. By default, a standard normal distribution",
            "choices": ["normal", "uniform"],
        },
    )
    initialize_selection_logits: str = field(
        default="uniform",
        metadata={
            "help": "Initialization of the feature selector. By default, a uniform distribution across the features",
            "choices": ["uniform", "normal"],
        },
    )
    threshold_init_beta: float = field(
        default=1.0,
        metadata={
            "help": """
                Used in the data-aware initialization of thresholds, where each threshold is initialized randomly
                (with a Beta distribution) to a feature value in the first batch.
                It initializes the threshold to the q-th quantile of the data points,
                where q ~ Beta(:threshold_init_beta:, :threshold_init_beta:).
                If this parameter is set to 1, initial thresholds have the same distribution as the data points.
                If greater than 1 (e.g. 10), thresholds are closer to the median data value.
                If less than 1 (e.g. 0.1), thresholds approach the min/max data values.
            """
        },
    )
    threshold_init_cutoff: float = field(
        default=1.0,
        metadata={
            "help": """
                Used in the data-aware initialization of the scales (used in the scaling ODTs).
                Threshold log-temperature initializer, in (0, inf).
                By default (1.0), log-temperatures are initialized so that all samples in the first batch fall
                in the linear region of the entmoid/sparsemoid (bin selectors) and therefore have non-zero gradients.
                The temperatures are then scaled by this parameter.
                Setting this value > 1.0 results in some margin between the data points and the sparse-sigmoid cutoff value.
                Setting this value < 1.0 causes a (1 - value) fraction of data points to end up in the flat sparse-sigmoid region;
                for instance, threshold_init_cutoff = 0.9 sets 10% of the points to exactly 0.0 or 1.0.
                All points will lie between (0.5 - 0.5 / threshold_init_cutoff) and (0.5 + 0.5 / threshold_init_cutoff).
            """
        },
    )
    embed_categorical: bool = field(
        default=False,
        metadata={
            "help": "Flag to embed categorical columns using an Embedding layer. If turned off, the categorical columns are encoded using a LeaveOneOutEncoder"
        },
    )
    embedding_dims: Optional[List[int]] = field(
        default=None,
        metadata={
            "help": "The dimensions of the embedding for each categorical column, as a list of tuples (cardinality, embedding_dim). If left empty, the dimension is inferred from the cardinality of the categorical column using the rule min(50, (x + 1) // 2)"
        },
    )
    embedding_dropout: float = field(
        default=0.0,
        metadata={"help": "Probability of an embedding element being zeroed"},
    )


class MultiStageModel(BaseModel):
    def __init__(self, config: DictConfig, **kwargs):
        if config.embed_categorical:
            self.embedding_cat_dim = sum([y for x, y in config.embedding_dims])
        else:
            # unpack_input checks this attribute even when categorical embeddings are disabled
            self.embedding_cat_dim = 0
        super().__init__(config, **kwargs)

    def _build_network(self):
        if self.hparams.embed_categorical:
            self.embedding_layers = nn.ModuleList(
                [nn.Embedding(x, y) for x, y in self.hparams.embedding_dims]
            )
            if self.hparams.embedding_dropout != 0 and self.embedding_cat_dim != 0:
                self.embedding_dropout = nn.Dropout(self.hparams.embedding_dropout)
            self.hparams.node_input_dim = (
                self.hparams.continuous_dim + self.embedding_cat_dim
            )
        else:
            self.hparams.node_input_dim = (
                self.hparams.continuous_dim + self.hparams.categorical_dim
            )
        self.backbone = NODEBackbone(self.hparams)

        # Average the first n channels of every tree, where n is the number of output
        # targets for regression and the number of classes for classification.
        def subset_clf(x):
            return x[..., :2].mean(dim=-2)

        def subset_rg(x):
            return x[..., 2:4].mean(dim=-2)

        self.clf_out = utils.Lambda(subset_clf)
        self.rg_out = utils.Lambda(subset_rg)
        self.classification_loss = nn.CrossEntropyLoss()

    def unpack_input(self, x: Dict):
        continuous_data, categorical_data = x["continuous"], x["categorical"]
        if self.embedding_cat_dim != 0:
            x = [
                embedding_layer(categorical_data[:, i])
                for i, embedding_layer in enumerate(self.embedding_layers)
            ]
            x = torch.cat(x, 1)

        if self.hparams.continuous_dim != 0:
            if self.hparams.batch_norm_continuous_input:
                continuous_data = self.normalizing_batch_norm(continuous_data)

            if self.embedding_cat_dim != 0:
                x = torch.cat([x, continuous_data], 1)
            else:
                x = continuous_data
        return x

    def forward(self, x: Dict):
        x = self.unpack_input(x)
        if self.hparams.embed_categorical:
            if self.hparams.embedding_dropout != 0 and self.embedding_cat_dim != 0:
                x = self.embedding_dropout(x)
        x = self.backbone(x)
        clf_logits = self.clf_out(x)
        clf_prob = nn.functional.gumbel_softmax(clf_logits, tau=1, dim=-1)

        rg_out = self.rg_out(x)

        # Soft selection: weight each bin's regression output by its Gumbel-softmax probability
        y_hat = torch.sum(clf_prob * rg_out, dim=-1)
        if (self.hparams.task == "regression") and (
            self.hparams.target_range is not None
        ):
            for i in range(self.hparams.output_dim):
                y_min, y_max = self.hparams.target_range[i]
                y_hat[:, i] = y_min + nn.Sigmoid()(y_hat[:, i]) * (y_max - y_min)
        return {"logits": y_hat, "clf_logits": clf_logits}

    def training_step(self, batch, batch_idx):
        y = batch["target"]
        ret_value = self(batch)
        loss = self.calculate_loss(
            y, ret_value["clf_logits"], ret_value["logits"], tag="train"
        )
        _ = self.calculate_metrics(y, ret_value["logits"], tag="train")
        return loss

    def validation_step(self, batch, batch_idx):
        y = batch["target"]
        ret_value = self(batch)
        _ = self.calculate_loss(
            y, ret_value["clf_logits"], ret_value["logits"], tag="valid"
        )
        _ = self.calculate_metrics(y, ret_value["logits"], tag="valid")
        return ret_value["logits"], y

    def test_step(self, batch, batch_idx):
        y = batch["target"]
        ret_value = self(batch)
        _ = self.calculate_loss(
            y, ret_value["clf_logits"], ret_value["logits"], tag="test"
        )
        _ = self.calculate_metrics(y, ret_value["logits"], tag="test")
        return ret_value["logits"], y

    def calculate_loss(self, y, classification_logits, y_hat, tag):
        # target column 0 is the binned class label, column 1 the regression target
        cl_loss = self.classification_loss(
            classification_logits.squeeze(), y[:, 0].squeeze().long()
        )
        rg_loss = self.loss(y_hat, y[:, 1])
        self.log(
            f"{tag}_classification_loss",
            cl_loss,
            on_epoch=True,
            on_step=False,
            logger=True,
            prog_bar=False,
        )
        self.log(
            f"{tag}_regression_loss",
            rg_loss,
            on_epoch=True,
            on_step=False,
            logger=True,
            prog_bar=False,
        )
        computed_loss = cl_loss + rg_loss
        self.log(
            f"{tag}_loss",
            computed_loss,
            on_epoch=(tag == "valid"),
            on_step=(tag == "train"),
            logger=True,
            prog_bar=True,
        )
        return computed_loss

    def calculate_metrics(self, y, y_hat, tag):
        for metric, metric_str, metric_params in zip(
            self.metrics, self.hparams.metrics, self.hparams.metrics_params
        ):
            if metric.__name__ == pl.metrics.functional.mean_squared_log_error.__name__:
                # MSLE should only be used with strictly positive targets; it is undefined otherwise
                metrics = metric(
                    torch.clamp(y_hat, min=0), torch.clamp(y[:, 1], min=0), **metric_params
                )
            else:
                metrics = metric(y_hat, y[:, 1], **metric_params)
            self.log(
                f"{tag}_{metric_str}",
                metrics,
                on_epoch=True,
                on_step=False,
                logger=True,
                prog_bar=True,
            )
        return metrics


dataset = fetch_california_housing(data_home="data", as_frame=True)
dataset.frame["HouseAgeBin"] = pd.qcut(dataset.frame["HouseAge"], q=4)
dataset.frame.HouseAgeBin = "age_" + dataset.frame.HouseAgeBin.cat.codes.astype(str)

test_idx = dataset.frame.sample(int(0.2 * len(dataset.frame)), random_state=42).index
test = dataset.frame[dataset.frame.index.isin(test_idx)]
train = dataset.frame[~dataset.frame.index.isin(test_idx)]

epochs = 15
batch_size = 128
steps_per_epoch = int((len(train) // batch_size) * 0.9)
data_config = DataConfig(
    target=["HouseAgeBin"] + dataset.target_names,
    continuous_cols=[
        "AveRooms",
        "AveBedrms",
        "Population",
        "AveOccup",
        "Latitude",
        "Longitude",
    ],
    # continuous_cols=[],
    categorical_cols=["HouseAgeBin"],
    continuous_feature_transform="quantile_uniform",  # alternative: "yeo-johnson"
    normalize_continuous_features=True,
)
trainer_config = TrainerConfig(
    auto_lr_find=False,  # Runs the LRFinder to automatically derive a learning rate
    batch_size=batch_size,
    max_epochs=epochs,
    early_stopping_patience=5,
    checkpoints=None,
    # fast_dev_run=True,
    gpus=1,  # index of the GPU to use; 0 means CPU
)
optimizer_config = OptimizerConfig(
    lr_scheduler="OneCycleLR",
    lr_scheduler_params={
        "max_lr": 0.005,
        "epochs": epochs,
        "steps_per_epoch": steps_per_epoch,
    },
)
# optimizer_config = OptimizerConfig(lr_scheduler="ReduceLROnPlateau", lr_scheduler_params={"patience": 3})
model_config = MultiStageModelConfig(
    task="regression",
    num_layers=1,  # Number of dense layers
    num_trees=2048,  # Number of trees in each layer
    depth=6,  # Depth of each tree
    embed_categorical=True,  # If True, uses a learned embedding; otherwise LeaveOneOutEncoding for categorical columns
    learning_rate=0.02,
    additional_tree_output_dim=25,
)
# model_config.validate()
# model_config = NodeConfig(task="regression", depth=2, embed_categorical=False)
# trainer_config = TrainerConfig(checkpoints=None, max_epochs=5, gpus=1, profiler=None)
# experiment_config = ExperimentConfig(
#     project_name="DeepGMM_test",
#     run_name="wand_debug",
#     log_target="wandb",
#     exp_watch="gradients",
#     log_logits=True,
# )
# optimizer_config = OptimizerConfig()

tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
    # experiment_config=experiment_config,
    model_callable=MultiStageModel,
)
tabular_model.fit(train=train, test=test)

result = tabular_model.evaluate(test)
# print(result)
# print(result[0]['train_loss'])
pred_df = tabular_model.predict(test, quantiles=[0.25])
print(pred_df.head())
# pred_df.to_csv("output/temp2.csv")
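
The forward pass above combines the two heads as a soft mixture: the classification head's Gumbel-softmax probabilities weight the per-bin regression outputs, which are then summed into a single prediction. The snippet below is a minimal, standalone sketch of just that combination on made-up tensors (batch of 4, two age bins, as in this script); it is an illustration, not part of the commit.

import torch
import torch.nn.functional as F

batch_size, num_bins = 4, 2
clf_logits = torch.randn(batch_size, num_bins)  # stand-in for self.clf_out(x)
rg_out = torch.randn(batch_size, num_bins)      # stand-in for self.rg_out(x), one regression value per bin

# Differentiable selection of a bin, then a weighted sum of the per-bin predictions
clf_prob = F.gumbel_softmax(clf_logits, tau=1, dim=-1)
y_hat = torch.sum(clf_prob * rg_out, dim=-1)    # shape: (batch_size,)
print(y_hat)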

pytorch_tabular/models/base_model.py

Lines changed: 0 additions & 1 deletion

@@ -101,7 +101,6 @@ def calculate_loss(self, y, y_hat, tag):

     def calculate_metrics(self, y, y_hat, tag):
         metrics = []
-        y_hat = torch.clamp(y_hat, min=0)
         for metric, metric_str, metric_params in zip(self.metrics, self.hparams.metrics, self.hparams.metrics_params):
             if (self.hparams.task == "regression") and (self.hparams.output_dim > 1):
                 _metrics = []