Skip to content

Commit cef6ae3

Browse files
authored
bugfix: default to multiclass accuracy (#146)
1 parent 4886978 commit cef6ae3

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

src/pytorch_tabular/models/base_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def __init__(
104104
for i, metric_params in enumerate(config.metrics_params):
105105
if "task" not in metric_params:
106106
# For classification task, output_dim == number of classses
107-
config.metrics_params[i]["task"] = "binary" if inferred_config.output_dim == 2 else "multiclass"
107+
config.metrics_params[i]["task"] = "multiclass"
108108
if "num_classes" not in metric_params:
109109
config.metrics_params[i]["num_classes"] = inferred_config.output_dim
110110

src/pytorch_tabular/tabular_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -838,15 +838,15 @@ def create_finetune_model(
838838
metrics = [torchmetrics.functional.accuracy]
839839
metrics_params = [
840840
{
841-
"task": "binary" if inferred_config.output_dim == 2 else "multiclass",
841+
"task": "multiclass",
842842
"num_classes": inferred_config.output_dim,
843843
}
844844
]
845845
else:
846846
for i, mp in enumerate(metrics_params):
847847
if "task" not in mp:
848848
# For classification task, output_dim == number of classses
849-
metrics_params[i]["task"] = "binary" if inferred_config.output_dim == 2 else "multiclass"
849+
metrics_params[i]["task"] = "multiclass"
850850
if "num_classes" not in mp:
851851
metrics_params[i]["num_classes"] = inferred_config.output_dim
852852
# Forming partial callables using metrics and metric params

0 commit comments

Comments
 (0)