diff --git a/medsegpy/config.py b/medsegpy/config.py
index 44d0e9ca..9f5ae137 100644
--- a/medsegpy/config.py
+++ b/medsegpy/config.py
@@ -106,6 +106,14 @@ class Config(object):
     EARLY_STOPPING_PATIENCE = 0
     EARLY_STOPPING_CRITERION = "val_loss"
 
+    # Dropout rate forwarded to the model's convolution blocks
+    DROPOUT_RATE = 0.0
+
+    # Use Monte Carlo dropout to generate predictive uncertainty values
+    MC_DROPOUT = False
+    # Number of stochastic forward passes (T) per batch when MC_DROPOUT is enabled
+    MC_DROPOUT_T = 100
+
     # Batch sizes
     TRAIN_BATCH_SIZE = 12
     VALID_BATCH_SIZE = 35
@@ -589,6 +597,10 @@ def summary(self, additional_vars=None):
             "EARLY_STOPPING_PATIENCE" if self.USE_EARLY_STOPPING else "",
             "EARLY_STOPPING_CRITERION" if self.USE_EARLY_STOPPING else "",
             "",
+            "DROPOUT_RATE",
+            "MC_DROPOUT",
+            "MC_DROPOUT_T" if self.MC_DROPOUT else "",
+            "",
             "KERNEL_INITIALIZER",
             "SEED" if self.SEED else "",
             "" "INIT_WEIGHTS",
diff --git a/medsegpy/data/data_loader.py b/medsegpy/data/data_loader.py
index 3208192a..c60dc47d 100644
--- a/medsegpy/data/data_loader.py
+++ b/medsegpy/data/data_loader.py
@@ -337,6 +337,10 @@ def inference(self, model: Model, **kwargs):
 
         workers = kwargs.pop("workers", self._cfg.NUM_WORKERS)
         use_multiprocessing = kwargs.pop("use_multiprocessing", workers > 1)
+
+        kwargs["mc_dropout"] = self._cfg.MC_DROPOUT
+        kwargs["mc_dropout_T"] = self._cfg.MC_DROPOUT_T
+
         for scan_id in scan_ids:
             self._dataset_dicts = scan_to_dict_mapping[scan_id]
 
@@ -353,6 +357,13 @@ def inference(self, model: Model, **kwargs):
             )
             time_elapsed = time.perf_counter() - start
 
+            preds_mc_dropout = None
+            if isinstance(preds, dict):
+                if preds['preds_mc_dropout'] is not None:
+                    preds_mc_dropout = np.squeeze(preds['preds_mc_dropout']).transpose((1, 2, 3, 0))
+
+                preds = preds['preds']
+
             x, y, preds = self._restructure_data((x, y, preds))
 
             input = {"x": x, "scan_id": scan_id}
@@ -363,7 +374,7 @@ def inference(self, model: Model, **kwargs):
             }
             input.update(scan_params)
 
-            output = {"y_pred": preds, "y_true": y, "time_elapsed": time_elapsed}
+            output = {"y_pred": preds, "y_mc_dropout": preds_mc_dropout, "y_true": y, "time_elapsed": time_elapsed}
 
             yield input, output
 
diff --git a/medsegpy/evaluation/sem_seg_evaluation.py b/medsegpy/evaluation/sem_seg_evaluation.py
index 6b8c192a..4780a520 100644
--- a/medsegpy/evaluation/sem_seg_evaluation.py
+++ b/medsegpy/evaluation/sem_seg_evaluation.py
@@ -134,6 +134,7 @@ def process(self, inputs, outputs):
             if includes_bg:
                 y_true = output["y_true"][..., 1:]
                 y_pred = output["y_pred"][..., 1:]
+                y_mc_dropout = output.get("y_mc_dropout")
                 labels = labels[..., 1:]
                 # if y_true.ndim == 3:
                 #     y_true = y_true[..., np.newaxis]
@@ -141,6 +142,7 @@ def process(self, inputs, outputs):
                 #     labels = labels[..., np.newaxis]
                 output["y_true"] = y_true
                 output["y_pred"] = y_pred
+                output["y_mc_dropout"] = None if y_mc_dropout is None else y_mc_dropout[..., 1:]
 
             time_elapsed = output["time_elapsed"]
             if self.stream_evaluation:
@@ -178,6 +180,8 @@ def eval_single_scan(self, input, output, labels, time_elapsed):
             with h5py.File(save_name, "w") as h5f:
                 h5f.create_dataset("probs", data=output["y_pred"])
                 h5f.create_dataset("labels", data=labels)
+                if output.get("y_mc_dropout") is not None:
+                    h5f.create_dataset("mc_dropout", data=output["y_mc_dropout"])
 
     def evaluate(self):
         """Evaluates popular medical segmentation metrics specified in config.
diff --git a/medsegpy/modeling/meta_arch/unet.py b/medsegpy/modeling/meta_arch/unet.py
index 07c70149..047e3a91 100644
--- a/medsegpy/modeling/meta_arch/unet.py
+++ b/medsegpy/modeling/meta_arch/unet.py
@@ -145,6 +145,7 @@ def build_model(self, input_tensor=None) -> Model:
         seed = cfg.SEED
         depth = cfg.DEPTH
         kernel_size = self.kernel_size
+        dropout_rate = cfg.DROPOUT_RATE
         self.use_attention = cfg.USE_ATTENTION
         self.use_deep_supervision = cfg.USE_DEEP_SUPERVISION
 
@@ -178,7 +179,7 @@ def build_model(self, input_tensor=None) -> Model:
                 num_conv=2,
                 activation="relu",
                 kernel_initializer=kernel_initializer,
-                dropout=0.0,
+                dropout=dropout_rate,
             )
 
             # Maxpool until penultimate depth.
@@ -220,7 +221,7 @@ def build_model(self, input_tensor=None) -> Model:
                 num_conv=2,
                 activation="relu",
                 kernel_initializer=kernel_initializer,
-                dropout=0.0,
+                dropout=dropout_rate,
             )
 
             if self.use_deep_supervision:
diff --git a/medsegpy/modeling/model.py b/medsegpy/modeling/model.py
index d514653d..5ac31e9f 100644
--- a/medsegpy/modeling/model.py
+++ b/medsegpy/modeling/model.py
@@ -6,6 +6,7 @@
 from keras.models import Model as _Model
 from keras.utils.data_utils import GeneratorEnqueuer, OrderedEnqueuer
 from keras.utils.generic_utils import Progbar
+import random
 
 from medsegpy.utils import env
 
@@ -42,10 +43,12 @@ def inference_generator(
         max_queue_size=10,
         workers=1,
         use_multiprocessing=False,
         verbose=0,
+        mc_dropout=False,
+        mc_dropout_T=100,
     ):
         return self.inference_generator_static(
-            self, generator, steps, max_queue_size, workers, use_multiprocessing, verbose
+            self, generator, steps, max_queue_size, workers, use_multiprocessing, verbose, mc_dropout, mc_dropout_T
         )
 
     @classmethod
@@ -57,7 +60,9 @@ def inference_generator_static(
         max_queue_size=10,
         workers=1,
         use_multiprocessing=False,
         verbose=0,
+        mc_dropout=False,
+        mc_dropout_T=100,
     ):
         """Generates predictions for the input samples from a data generator
         and returns inputs, ground truth, and predictions.
@@ -115,6 +120,8 @@ def inference_generator_static(
                 max_queue_size=max_queue_size,
                 workers=workers,
                 use_multiprocessing=use_multiprocessing,
+                mc_dropout=mc_dropout,
+                mc_dropout_T=mc_dropout_T,
                 verbose=verbose,
             )
         else:
@@ -252,9 +259,13 @@ def _inference_generator_tf2(
         max_queue_size=10,
         workers=1,
         use_multiprocessing=False,
+        mc_dropout=False,
+        mc_dropout_T=100,
     ):
         """Inference generator for TensorFlow 2."""
+        random.seed(0)
         outputs = []
+        outputs_mc_dropout = []
         xs = []
         ys = []
         with model.distribute_strategy.scope():
@@ -295,14 +306,23 @@ def _inference_generator_tf2(
                         batch_x, batch_y, batch_x_raw = _extract_inference_inputs(next(iterator))
                         # tmp_batch_outputs = predict_function(iterator)
                         tmp_batch_outputs = model.predict(batch_x)
+
+                        tmp_batch_outputs_mc_dropout = None
+                        if mc_dropout:
+                            tmp_batch_outputs_mc_dropout = np.stack(
+                                [model(batch_x, training=True) for _ in range(mc_dropout_T)]
+                            )
+
                         if data_handler.should_sync:
                             context.async_wait()  # noqa: F821
                         batch_outputs = tmp_batch_outputs  # No error, now safe to assign.
+                        batch_outputs_mc_dropout = tmp_batch_outputs_mc_dropout
 
                         if batch_x_raw is not None:
                             batch_x = batch_x_raw
                         for batch, running in zip(
-                            [batch_x, batch_y, batch_outputs], [xs, ys, outputs]
+                            [batch_x, batch_y, batch_outputs, batch_outputs_mc_dropout],
+                            [xs, ys, outputs, outputs_mc_dropout],
                         ):
                             nest.map_structure_up_to(
                                 batch, lambda x, batch_x: x.append(batch_x), running, batch
@@ -318,7 +338,15 @@ def _inference_generator_tf2(
         all_xs = nest.map_structure_up_to(batch_x, np.concatenate, xs)
         all_ys = nest.map_structure_up_to(batch_y, np.concatenate, ys)
         all_outputs = nest.map_structure_up_to(batch_outputs, np.concatenate, outputs)
-        return all_xs, all_ys, all_outputs
+        all_outputs_mc_dropout = None
+        if mc_dropout:
+            all_outputs_mc_dropout = nest.map_structure_up_to(
+                batch_outputs_mc_dropout, np.concatenate, outputs_mc_dropout
+            )
+
+        outputs = {"preds": all_outputs, "preds_mc_dropout": all_outputs_mc_dropout}
+
+        return all_xs, all_ys, outputs
 
         # all_xs = nest.map_structure_up_to(batch_x, concat, xs)
         # all_ys = nest.map_structure_up_to(batch_y, concat, ys)
diff --git a/tests/modeling/test_model.py b/tests/modeling/test_model.py
new file mode 100644
index 00000000..e7b0125b
--- /dev/null
+++ b/tests/modeling/test_model.py
@@ -0,0 +1,90 @@
+"""Test model output reproducibility.
+
+These tests check that Monte Carlo dropout during inference produces
+reproducible results.
+"""
+
+import os
+import shutil
+import unittest
+
+import h5py
+import numpy as np
+from fvcore.common.file_io import PathManager
+
+from medsegpy.config import UNetConfig
+from medsegpy.data import DefaultDataLoader
+from medsegpy.modeling.meta_arch import build_model
+from medsegpy.modeling.model import Model
+
+
+class TestMCDropout(unittest.TestCase):
+    """Run Monte Carlo dropout inference twice on one mock scan and compare."""
+
+    IMG_SIZE = (512, 512, 1)
+    NUM_CLASSES = 4
+    FILE_PATH = "mock_data://temp_data/scan.h5"
+
+    @classmethod
+    def setUpClass(cls):
+        """Write a random image and a random binary mask to the mock HDF5 file."""
+        np.random.seed(0)
+        img = np.random.rand(*cls.IMG_SIZE).astype(np.float32)
+        seg = (np.random.rand(*cls.IMG_SIZE, cls.NUM_CLASSES) >= 0.5).astype(np.uint8)
+
+        file_path = PathManager.get_local_path(cls.FILE_PATH)
+        os.makedirs(os.path.dirname(file_path), exist_ok=True)
+        with h5py.File(file_path, "w") as f:
+            f.create_dataset("data", data=img)
+            f.create_dataset("volume", data=img)
+            f.create_dataset("seg", data=seg)
+
+    @classmethod
+    def tearDownClass(cls):
+        """Remove the directory that holds the mock file."""
+        file_path = PathManager.get_local_path(cls.FILE_PATH)
+        shutil.rmtree(os.path.dirname(file_path))
+
+    def get_dataset_dicts(self):
+        """Return a one-element list of dataset dicts pointing at the mock file."""
+        file_path = PathManager.get_local_path(self.FILE_PATH)
+        return [
+            {
+                "file_name": file_path,
+                "sem_seg_file": file_path,
+                "scan_id": os.path.splitext(os.path.basename(file_path))[0],
+                "image_size": self.IMG_SIZE,
+            }
+        ]
+
+    def test_inference_with_mc_dropout(self):
+        """Two inference passes over the same data must return equal arrays."""
+        cfg = UNetConfig()
+        cfg.MC_DROPOUT = True
+        cfg.MC_DROPOUT_T = 10
+        cfg.IMG_SIZE = self.IMG_SIZE
+        # NOTE(review): DROPOUT_RATE is not set here and Config defaults it to 0.0;
+        # confirm dropout is actually active, otherwise the Monte Carlo samples are
+        # trivially identical and the `random.seed(0)` call in the model is untested.
+        model = build_model(cfg)
+
+        dataset_dicts = self.get_dataset_dicts()
+        data_loader = DefaultDataLoader(cfg, dataset_dicts, is_test=True, shuffle=False)
+
+        # Feed same data to inference generator twice
+        kwargs = {
+            "mc_dropout": data_loader._cfg.MC_DROPOUT,
+            "mc_dropout_T": data_loader._cfg.MC_DROPOUT_T,
+        }
+        x1, y1, preds1 = Model.inference_generator_static(model, data_loader, **kwargs)
+        x2, y2, preds2 = Model.inference_generator_static(model, data_loader, **kwargs)
+
+        # Outputs should be the same
+        self.assertTrue(np.array_equal(x1, x2))
+        self.assertTrue(np.array_equal(y1, y2))
+        self.assertTrue(np.array_equal(preds1["preds"], preds2["preds"]))
+        self.assertTrue(np.array_equal(preds1["preds_mc_dropout"], preds2["preds_mc_dropout"]))
+
+
+if __name__ == "__main__":
+    unittest.main()