@@ -7,6 +7,7 @@
 import os
 import warnings
 from collections import defaultdict
+from functools import partial
 from pathlib import Path
 from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
 
@@ -370,7 +371,7 @@ def load_model(cls, dir: str, map_location=None, strict=True):
             model_args["loss"] = "MSELoss"  # For compatibility. Not Used
         if custom_params.get("custom_metrics") is not None:
             model_args["metrics"] = ["mean_squared_error"]  # For compatibility. Not Used
-            model_args["metric_params"] = [{}]  # For compatibility. Not Used
+            model_args["metrics_params"] = [{}]  # For compatibility. Not Used
         if custom_params.get("custom_optimizer") is not None:
             model_args["optimizer"] = "Adam"  # For compatibility. Not Used
         if custom_params.get("custom_optimizer_params") is not None:
@@ -815,6 +816,14 @@ def create_finetune_model(
         datamodule.target = config.target
         datamodule.batch_size = config.batch_size
         datamodule.seed = config.seed
+        model_callable = _GenericModel
+        inferred_config = self.datamodule.update_config(config)
+        inferred_config = OmegaConf.structured(inferred_config)
+        # Adding dummy attributes for compatibility. Not used because custom metrics are provided
+        if not hasattr(config, "metrics"):
+            config.metrics = "dummy"
+        if not hasattr(config, "metrics_params"):
+            config.metrics_params = {}
         if metrics is not None:
             assert len(metrics) == len(metrics_params), "Number of metrics and metrics_params should be same"
             metrics = [getattr(torchmetrics.functional, m) if isinstance(m, str) else m for m in metrics]
@@ -827,11 +836,21 @@ def create_finetune_model(
827836 loss = loss if loss is not None else torch .nn .CrossEntropyLoss ()
828837 if metrics is None :
829838 metrics = [torchmetrics .functional .accuracy ]
830- metrics_params = [{}]
831-
832- model_callable = _GenericModel
833- inferred_config = self .datamodule .update_config (config )
834- inferred_config = OmegaConf .structured (inferred_config )
839+ metrics_params = [
840+ {
841+ "task" : "binary" if inferred_config .output_dim == 2 else "multiclass" ,
842+ "num_classes" : inferred_config .output_dim ,
843+ }
844+ ]
845+ else :
846+ for i , mp in enumerate (metrics_params ):
847+ if "task" not in mp :
848+ # For classification task, output_dim == number of classses
849+ metrics_params [i ]["task" ] = "binary" if inferred_config .output_dim == 2 else "multiclass"
850+ if "num_classes" not in mp :
851+ metrics_params [i ]["num_classes" ] = inferred_config .output_dim
852+ # Forming partial callables using metrics and metric params
853+ metrics = [partial (m , ** mp ) for m , mp in zip (metrics , metrics_params )]
835854 self .model .mode = "finetune"
836855 if learning_rate is not None :
837856 config .learning_rate = learning_rate
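
Below is a minimal sketch (not part of the diff) of the pattern the new lines rely on: newer torchmetrics releases (0.11+) make `task` a required argument of functional classification metrics such as `torchmetrics.functional.accuracy`, and multiclass metrics also need `num_classes`, so the diff binds both up front with `functools.partial`. The `num_classes = 3` value is a hypothetical stand-in for `inferred_config.output_dim`.

from functools import partial

import torch
import torchmetrics

num_classes = 3  # stand-in for inferred_config.output_dim
metric = partial(
    torchmetrics.functional.accuracy,
    task="binary" if num_classes == 2 else "multiclass",
    num_classes=num_classes,
)

# The bound callable can now be invoked with just (preds, target),
# which is how the finetune model calls its metric functions.
preds = torch.tensor([0, 2, 1, 2])
target = torch.tensor([0, 1, 1, 2])
print(metric(preds, target))  # accuracy over the batch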