diff --git a/GETTING_STARTED.md b/GETTING_STARTED.md
index 5125d88a..d69b5b40 100644
--- a/GETTING_STARTED.md
+++ b/GETTING_STARTED.md
@@ -76,7 +76,7 @@ machineB.save()
 
 ### Training & Evaluation in Command Line
 
-We provide a script in "medsegpy/train_net.py", that is made to train
+We provide a script in "tools/train_net.py", that is made to train
 all the configs provided in medsegpy.
 You may want to use it as a reference to write your own training script for
 new research.
diff --git a/medsegpy/config.py b/medsegpy/config.py
index 44d0e9ca..3809d42d 100644
--- a/medsegpy/config.py
+++ b/medsegpy/config.py
@@ -70,6 +70,10 @@ class Config(object):
     # Class name for robust loss computation
     ROBUST_LOSS_NAME = ""
     ROBUST_LOSS_STEP_SIZE = 1e-1
+    # Additional loss functions to track as metrics during training.
+    # Format: [[(id_1, output_mode_1), class_weights_1],
+    #          [(id_2, output_mode_2), class_weights_2], ...]
+    LOSS_METRICS = []
 
     # PIDS to include, None = all pids
     PIDS = None
diff --git a/medsegpy/engine/trainer.py b/medsegpy/engine/trainer.py
index e7948ece..690121d9 100644
--- a/medsegpy/engine/trainer.py
+++ b/medsegpy/engine/trainer.py
@@ -176,7 +176,14 @@ def _train_model(self):
         # TODO: Add more options for metrics.
         optimizer = solver.build_optimizer(cfg)
         loss_func = self.build_loss()
-        metrics = [lr_callback(optimizer), dice_loss]
+
+        loss_metrics = []
+        if cfg.LOSS_METRICS:
+            for loss_idx, loss_metric in enumerate(cfg.LOSS_METRICS):
+                new_metric = build_loss(cfg, build_additional_metric=True, additional_metric=loss_metric)
+                new_metric.name = f"{loss_metric[0][0]}_{loss_idx}"
+                loss_metrics.append(new_metric)
+        metrics = [lr_callback(optimizer), dice_loss] + loss_metrics
 
         callbacks = self.build_callbacks()
         if isinstance(loss_func, kc.Callback):
diff --git a/medsegpy/loss/utils.py b/medsegpy/loss/utils.py
index c2f5d69d..5406d3d3 100644
--- a/medsegpy/loss/utils.py
+++ b/medsegpy/loss/utils.py
@@ -65,6 +65,10 @@ def reduce_tensor(x, reduction="mean", axis=None, weights=None):
     use_weights = weights is not None
     if use_weights:
         x *= weights
+        if (reduction in ("none", None)) and (len(tf.where(weights == 0)) == (len(weights) - 1)):
+            # If exactly one weight is 1 and the rest are 0, keep only the loss of
+            # that entry, scaled by len(weights) because the final reduction is a mean.
+            return x * len(weights)
 
     if reduction == "mean" and use_weights:
         ndim = K.ndim(x)
diff --git a/medsegpy/losses.py b/medsegpy/losses.py
index 7ee28f02..ff0d43a0 100755
--- a/medsegpy/losses.py
+++ b/medsegpy/losses.py
@@ -17,6 +17,7 @@
 AVG_DICE_LOSS = ("avg_dice", "sigmoid")
 AVG_DICE_LOSS_SOFTMAX = ("avg_dice", "softmax")
 AVG_DICE_NO_REDUCE = ("avg_dice_no_reduce", "sigmoid")
+AVG_DICE_NO_REDUCE_SOFTMAX = ("avg_dice_no_reduce", "softmax")
 WEIGHTED_CROSS_ENTROPY_LOSS = ("weighted_cross_entropy", "softmax")
 WEIGHTED_CROSS_ENTROPY_SIGMOID_LOSS = ("weighted_cross_entropy_sigmoid", "sigmoid")
 
@@ -36,6 +37,7 @@
     "AVG_DICE_LOSS",
     "AVG_DICE_LOSS_SOFTMAX",
     "AVG_DICE_NO_REDUCE",
+    "AVG_DICE_NO_REDUCE_SOFTMAX",
     "WEIGHTED_CROSS_ENTROPY_LOSS",
     "WEIGHTED_CROSS_ENTROPY_SIGMOID_LOSS",
     "BINARY_CROSS_ENTROPY_LOSS",
@@ -46,11 +48,30 @@
 ]
 
 
-def build_loss(cfg):
-    loss = cfg.LOSS
+def build_loss(cfg, build_additional_metric=False, additional_metric: list = None):
+    if not build_additional_metric:
+        loss = cfg.LOSS
+        robust_loss_cls = cfg.ROBUST_LOSS_NAME
+        robust_step_size = cfg.ROBUST_LOSS_STEP_SIZE
+        class_weights = cfg.CLASS_WEIGHTS
+    else:
+        loss = additional_metric[0]
+        # YAML cannot serialize tuples, so the loss id may arrive as a list - convert it manually.
+        if isinstance(loss, list):
+            loss = tuple(loss)
+        class_weights = additional_metric[1]
+        # Robust loss is not supported for additional metrics (for now).
+        robust_loss_cls = False
+        robust_step_size = None
+
     num_classes = len(cfg.CATEGORIES)
-    robust_loss_cls = cfg.ROBUST_LOSS_NAME
-    robust_step_size = cfg.ROBUST_LOSS_STEP_SIZE
+
+    # Allow the config to specify class weights as a single integer,
+    # indicating that only one of the classes should be evaluated.
+    if isinstance(class_weights, (list, tuple)):
+        pass
+    elif isinstance(class_weights, int):
+        class_weights = get_class_weights_from_int(class_weights, num_classes)
 
     if robust_loss_cls:
         reduction = "class"
@@ -64,7 +85,7 @@ def build_loss(cfg):
         pass
     loss = get_training_loss(
         loss,
-        weights=cfg.CLASS_WEIGHTS,
+        weights=class_weights,
         # Remove computation on the background class.
         remove_background=cfg.INCLUDE_BACKGROUND,
         reduce=reduction,
@@ -79,6 +100,12 @@ def build_loss(cfg):
     else:
         raise ValueError(f"{robust_loss_cls} not supported")
 
+def get_class_weights_from_int(label, num_classes):
+    """Returns one-hot class_weights for an integer label."""
+    class_weights = [0] * num_classes
+    class_weights[label] = 1
+    return class_weights
+
 
 # TODO (arjundd): Add ability to exclude specific indices from loss function.
 def get_training_loss_from_str(loss_str: str):
@@ -91,6 +118,8 @@ def get_training_loss_from_str(loss_str: str):
         return AVG_DICE_LOSS
     elif loss_str == "AVG_DICE_NO_REDUCE":
         return AVG_DICE_NO_REDUCE
+    elif loss_str == "AVG_DICE_NO_REDUCE_SOFTMAX":
+        return AVG_DICE_NO_REDUCE_SOFTMAX
     elif loss_str == "WEIGHTED_CROSS_ENTROPY_LOSS":
         return WEIGHTED_CROSS_ENTROPY_LOSS
     elif loss_str == "WEIGHTED_CROSS_ENTROPY_SIGMOID_LOSS":
@@ -134,6 +163,15 @@ def get_training_loss(loss, **kwargs):
         kwargs.pop("reduce", None)
         kwargs["reduction"] = "none"
         return DiceLoss(**kwargs)
+    elif loss == AVG_DICE_NO_REDUCE_SOFTMAX:
+        # This branch is identical to the AVG_DICE_NO_REDUCE case above; the two could be consolidated.
+        kwargs.pop("reduce", None)
+        kwargs["reduction"] = "none"
+        # The softmax activation does not need to be added here -
+        # it is already applied in the model itself:
+        # https://github.com/ad12/MedSegPy/blob/0c316baaf040c22d562940a198a0e48eef2d36a8/medsegpy/modeling/meta_arch/unet.py#L152
+        # kwargs["activation"] = "softmax"
+        return DiceLoss(**kwargs)
     else:
         raise ValueError("Loss type not supported")
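
For reviewers, a minimal usage sketch of the new option (illustrative only, not part of the patch). The entry format follows the `LOSS_METRICS` comment added to `medsegpy/config.py`; the specific loss ids, class count, and config values below are assumptions.

```python
# Hypothetical config: track per-class Dice as extra Keras metrics while
# training with the usual loss. YAML stores each (id, output_mode) tuple
# as a list; build_loss() converts it back to a tuple before matching it
# against the constants in medsegpy/losses.py. `cfg` is assumed to be a
# medsegpy Config instance.
cfg.LOSS_METRICS = [
    [["avg_dice_no_reduce", "softmax"], 0],  # Dice for class 0 only
    [["avg_dice_no_reduce", "softmax"], 1],  # Dice for class 1 only
]

# An integer class_weights entry is expanded to a one-hot list, e.g. with
# three classes get_class_weights_from_int(1, 3) == [0, 1, 0]. In
# reduce_tensor(), a one-hot weight vector with reduction "none" returns
# the selected class's loss scaled by len(weights), so the final mean
# recovers the unscaled per-class value. The trainer names each metric
# "<loss_id>_<index>", e.g. "avg_dice_no_reduce_0".
```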