File tree Expand file tree Collapse file tree 2 files changed +3
-3
lines changed
Expand file tree Collapse file tree 2 files changed +3
-3
lines changed Original file line number Diff line number Diff line change @@ -104,7 +104,7 @@ def __init__(
104104 for i , metric_params in enumerate (config .metrics_params ):
105105 if "task" not in metric_params :
106106 # For classification task, output_dim == number of classses
107- config .metrics_params [i ]["task" ] = "binary" if inferred_config . output_dim == 2 else " multiclass"
107+ config .metrics_params [i ]["task" ] = "multiclass"
108108 if "num_classes" not in metric_params :
109109 config .metrics_params [i ]["num_classes" ] = inferred_config .output_dim
110110
Original file line number Diff line number Diff line change @@ -838,15 +838,15 @@ def create_finetune_model(
838838 metrics = [torchmetrics .functional .accuracy ]
839839 metrics_params = [
840840 {
841- "task" : "binary" if inferred_config . output_dim == 2 else " multiclass" ,
841+ "task" : "multiclass" ,
842842 "num_classes" : inferred_config .output_dim ,
843843 }
844844 ]
845845 else :
846846 for i , mp in enumerate (metrics_params ):
847847 if "task" not in mp :
848848 # For classification task, output_dim == number of classses
849- metrics_params [i ]["task" ] = "binary" if inferred_config . output_dim == 2 else " multiclass"
849+ metrics_params [i ]["task" ] = "multiclass"
850850 if "num_classes" not in mp :
851851 metrics_params [i ]["num_classes" ] = inferred_config .output_dim
852852 # Forming partial callables using metrics and metric params
You can’t perform that action at this time.
0 commit comments