diff --git a/CorpusCallosum/README.md b/CorpusCallosum/README.md new file mode 100644 index 000000000..9e945e35d --- /dev/null +++ b/CorpusCallosum/README.md @@ -0,0 +1,16 @@ +# Corpus Callosum Pipeline + +A deep learning-based pipeline for automated segmentation, analysis, and shape analysis of the corpus callosum in brain MRI scans. +Also segments the fornix, localizes the anterior and posterior commissure (AC and PC) and standardizes the orientation of the brain. + +For detailed documentation, please refer to: +- [Module Overview](../doc/overview/modules/CC.md): Detailed description of the pipeline, workflow, and analysis options. +- [Output Files](../doc/overview/OUTPUT_FILES.md#corpus-callosum-module): List of output files and their descriptions. + +## Quickstart + +```bash +python3 fastsurfer_cc.py --sd /path/to/fastsurfer/output --sid test-case --verbose +``` + +Gives all standard outputs. The corpus callosum morphometry can be found at `stats/callosum.CC.midslice.json` including 100 thickness measurements and the areas of sub-segments. diff --git a/CorpusCallosum/__init__.py b/CorpusCallosum/__init__.py new file mode 100644 index 000000000..63db725af --- /dev/null +++ b/CorpusCallosum/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__all__ = [ + "data", + "segmentation", + "transforms", + "utils", +] diff --git a/CorpusCallosum/cc_visualization.py b/CorpusCallosum/cc_visualization.py new file mode 100644 index 000000000..a793ba440 --- /dev/null +++ b/CorpusCallosum/cc_visualization.py @@ -0,0 +1,254 @@ +import argparse +import sys +from pathlib import Path +from typing import Literal + +import numpy as np + +from CorpusCallosum.data.fsaverage_cc_template import load_fsaverage_cc_template +from CorpusCallosum.shape.contour import CCContour +from CorpusCallosum.shape.mesh import CCMesh +from FastSurferCNN.utils.logging import get_logger, setup_logging + +logger = get_logger(__name__) + + +def make_parser() -> argparse.ArgumentParser: + """Create a command line parser for the visualization pipeline.""" + parser = argparse.ArgumentParser(description="Visualize corpus callosum from template files.") + parser.add_argument( + "--template_dir", + type=str, + required=True, + help=( + "Path to a template directory containing per-slice files named " + "thickness_values_.txt, and optionally contour_.txt " + "and thickness_measurement_points_.txt. If contour_.txt " + "and thickness_measurement_points_.txt are not provided, " + "uses fsaverage template." + ), + metavar="TEMPLATE_DIR", + default=None, + ) + parser.add_argument("--output_dir", + type=str, + required=True, + help="Directory for output files. Writes: " + "cc_mesh.html - Interactive 3D mesh visualization (HTML file) " + "midslice_2d.png - 2D midslice visualization of the corpus callosum " + "cc_mesh.vtk - VTK mesh file format " + "cc_mesh.fssurf - FreeSurfer surface file " + "cc_mesh_overlay.curv - FreeSurfer curvature overlay file " + "cc_mesh_snap.png - Screenshot/snapshot of the 3D mesh (requires whippersnappy>=1.3.1)", + metavar="OUTPUT_DIR" + ) + parser.add_argument( + "--resolution", + type=float, + default=1.0, + help="Resolution in mm for the mesh.", + metavar="RESOLUTION" + ) + parser.add_argument( + "--smoothing_window", + type=int, + default=5, + help="Window size for smoothing the contour.", + metavar="SMOOTHING_WINDOW" + ) + parser.add_argument( + "--colormap", + type=str, + default="red_to_yellow", + choices=["red_to_blue", "blue_to_red", "red_to_yellow", "yellow_to_red"], + help="Colormap to use for thickness visualization, lower to higher values.", + ) + parser.add_argument( + "--color_range", + type=float, + nargs=2, + default=None, + metavar=("MIN", "MAX"), + required=False, + help="Specify the range for the colorbar (2 values: min max). Defaults to automatic choice. \ + (e.g. --color_range 0 10).", + ) + parser.add_argument( + "--legend", + type=str, + default="Thickness (mm)", + help="Legend for the colorbar.", + metavar="LEGEND") + parser.add_argument( + "--twoD", + action="store_true", + help="Generate 2D visualization instead of 3D mesh.", + ) + parser.add_argument( + "-v", + "--verbose", + action="count", + default=0, + help="Enable verbose (pass twice for debug-output).", + ) + return parser + + +def options_parse() -> argparse.Namespace: + """Parse command line arguments for the pipeline.""" + parser = make_parser() + args = parser.parse_args() + + # Create output directory if it doesn't exist + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + return args + + +def load_contours_from_template_dir( + template_dir: Path, resolution: float, smoothing_window: int +) -> list[CCContour]: + """Load all contours and thickness data from a template directory.""" + thickness_files = sorted(template_dir.glob("thickness_values_*.txt")) + if not thickness_files: + raise FileNotFoundError( + f"No thickness files found in template directory {template_dir}. " + "Expected files named thickness_values_.txt and " + "optionally contour_.txt and thickness_measurement_points_.txt." + ) + + fsaverage_contour = None + contours: list[CCContour] = [] + # First pass: collect all indices to determine the range + indices = [] + for thickness_file in thickness_files: + try: + idx = int(thickness_file.stem.split("_")[-1]) + indices.append(idx) + except ValueError: + # skip files that do not follow the expected naming + continue + + # Calculate z_positions centered around the middle slice + num_slices = len(indices) + middle_idx = num_slices // 2 + + for thickness_file in thickness_files: + try: + idx = int(thickness_file.stem.split("_")[-1]) + except ValueError: + # skip files that do not follow the expected naming + continue + + # Calculate z_position: use the index offset from middle, scaled by resolution + z_position = (idx - indices[middle_idx]) * resolution + + contour_file = template_dir / f"contour_{idx}.txt" + + if not contour_file.exists(): + # get length of thickness values + thickness_values = np.loadtxt(thickness_file, dtype=str) + # get the non nan thickness values (excluding header), so we know how many points to sample + num_thickness_values = np.sum(~np.isnan(np.array(thickness_values[1:],dtype=float))) + if fsaverage_contour is None: + fsaverage_contour = load_fsaverage_cc_template() + # create measurement points (points = 2 x levelpaths) according to number of thickness values + fsaverage_contour.create_levelpaths(num_points=num_thickness_values // 2, inplace=True) + current_contour = fsaverage_contour.copy() + current_contour.z_position = z_position + current_contour.load_thickness_values(thickness_file) + + else: + current_contour = CCContour.from_contour_file(contour_file, thickness_file, z_position=z_position) + + if smoothing_window > 0: + current_contour.smooth_contour(window_size=smoothing_window) + + current_contour.fill_thickness_values() + contours.append(current_contour) + + if not contours: + raise ValueError(f"No valid contours could be loaded from {template_dir}") + return contours + + +def main( + template_dir: str | Path, + output_dir: str | Path, + resolution: float = 1.0, + smoothing_window: int = 5, + colormap: str = "red_to_yellow", + color_range: tuple[float, float] | None = None, + legend: str | None = None, + twoD: bool = False, +) -> Literal[0] | str: + """Visualize corpus callosum templates in 2D or 3D.""" + output_dir = Path(output_dir) + color_range = tuple(color_range) if color_range is not None else None + + contours = load_contours_from_template_dir( + Path(template_dir), resolution=resolution, smoothing_window=smoothing_window, + ) + + # 2D visualization + mid_contour = contours[len(contours) // 2] + + # for now, we only support thickness visualization, this is preparing to plot also p-values and icc values + mode = "thickness" + logger.info(f"Writing output to {output_dir / 'cc_thickness_2d.png'}") + + if mode == "thickness": + raw_thickness_values = mid_contour.thickness_values[~np.isnan(mid_contour.thickness_values)] + # values are duplicated because they have two measurement points per levelpath + raw_thickness_values = raw_thickness_values[len(raw_thickness_values) // 2:] + mid_contour.plot_contour_colorfill( + plot_values=raw_thickness_values, + title=None, + save_path=str(output_dir / "cc_thickness_2d.png"), + colorbar=True, + mode=mode + ) + if twoD: + return 0 + + # 3D visualization + cc_mesh = CCMesh.from_contours(contours, smooth=0) + + plot_kwargs = dict( + colormap=colormap, + color_range=color_range, + thickness_overlay=True, + legend=legend or "", + ) + cc_mesh.plot_mesh(**plot_kwargs) + cc_mesh.plot_mesh(output_path=str(output_dir / "cc_mesh.html"), **plot_kwargs) + + logger.info(f"Writing vtk file to {output_dir / 'cc_mesh.vtk'}") + cc_mesh.write_vtk(str(output_dir / "cc_mesh.vtk")) + logger.info(f"Writing freesurfer surface file to {output_dir / 'cc_mesh.fssurf'}") + cc_mesh.write_fssurf(str(output_dir / "cc_mesh.fssurf")) + logger.info(f"Writing freesurfer overlay file to {output_dir / 'cc_mesh_overlay.curv'}") + cc_mesh.write_morph_data(str(output_dir / "cc_mesh_overlay.curv")) + try: + cc_mesh.snap_cc_picture(str(output_dir / "cc_mesh_snap.png")) + logger.info(f"Writing 3D snapshot image to {output_dir / 'cc_mesh_snap.png'}") + except RuntimeError: + logger.warning("The cc_visualization script requires whippersnappy>=1.3.1 to makes screenshots, install with " + "`pip install whippersnappy>=1.3.1` !") + return 0 + +if __name__ == "__main__": + options = options_parse() + + # Set up logging if verbose mode is enabled + setup_logging(None, options.verbose) # Log to stdout only + + sys.exit(main( + template_dir=options.template_dir, + output_dir=options.output_dir, + resolution=options.resolution, + smoothing_window=options.smoothing_window, + colormap=options.colormap, + color_range=options.color_range, + legend=options.legend, + twoD=options.twoD, + )) diff --git a/CorpusCallosum/config/checkpoint_paths.yaml b/CorpusCallosum/config/checkpoint_paths.yaml new file mode 100644 index 000000000..ca78b7da2 --- /dev/null +++ b/CorpusCallosum/config/checkpoint_paths.yaml @@ -0,0 +1,7 @@ +url: +- "https://zenodo.org/records/17141933/files" +- "https://b2share.fz-juelich.de/api/files/e4eb699c-ba68-4470-9f3d-89ceeee1a334" + +checkpoint: + segmentation: "checkpoints/FastSurferCC_segmentation_v1.0.0.pkl" + localization: "checkpoints/FastSurferCC_localization_v1.0.0.pkl" diff --git a/CorpusCallosum/data/__init__.py b/CorpusCallosum/data/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/CorpusCallosum/data/constants.py b/CorpusCallosum/data/constants.py new file mode 100644 index 000000000..745809313 --- /dev/null +++ b/CorpusCallosum/data/constants.py @@ -0,0 +1,57 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from FastSurferCNN.utils.parser_defaults import FASTSURFER_ROOT + +### Constants +WEIGHTS_PATH = FASTSURFER_ROOT / "checkpoints" +FSAVERAGE_CENTROIDS_PATH = FASTSURFER_ROOT / "CorpusCallosum" / "data" / "fsaverage_centroids.json" +# Contains both affine and header +FSAVERAGE_DATA_PATH = FASTSURFER_ROOT / "CorpusCallosum" / "data" / "fsaverage_data.json" +FSAVERAGE_MIDDLE = 128 # Middle slice index in fsaverage space +CC_LABEL = 192 # Label value for corpus callosum in segmentation +FORNIX_LABEL = 250 # Label value for fornix in segmentation +THIRD_VENTRICLE_LABEL = 4 # Label value for third ventricle in segmentation +SUBSEGMENT_LABELS = [251, 252, 253, 254, 255] # labels for subsegments in segmentation + + +DEFAULT_INPUT_PATHS = { + "conf_name": "mri/orig.mgz", + "aseg_name": "mri/aparc.DKTatlas+aseg.deep.mgz", +} + +DEFAULT_OUTPUT_PATHS = { + ## images + "upright_volume": None, # orig.mgz mapped to upright space + ## segmentations + "segmentation": "mri/callosum.CC.upright.mgz", # corpus callosum segmentation in upright space + "segmentation_in_orig": "mri/callosum.CC.orig.mgz", # cc segmentation in input segmentations space + "softlabels_cc": "mri/callosum.CC.soft.mgz", # cc softlabels in upright space + "softlabels_fn": "mri/fornix.CC.soft.mgz", # fornix softlabels in upright space + "softlabels_background": "mri/background.CC.soft.mgz", # background softlabels in upright space + ## stats + "cc_markers": "stats/callosum.CC.midslice.json", # cc metrics for middle slice + "cc_measures": "stats/callosum.CC.all_slices.json", # cc metrics for all slices + ## transforms + "upright_lta": "mri/transforms/cc_up.lta", # lta transform from orig to upright space + "orient_volume_lta": "mri/transforms/orient_volume.lta", # lta transform from orig to upright+acpc corrected space + ## qc + "qc_image": None, #"callosum.png", # debug image of cc contours + "thickness_image": None, # "callosum.thickness.png", # whippersnappy 3D image of cc thickness + "cc_html": None, # "corpus_callosum.html", # plotly cc visualization + ## surface + "cc_surf": "surf/callosum.surf", # cc surface file + "cc_thickness_overlay": "surf/callosum.thickness.w", # cc surface overlay file + "cc_surf_vtk": "surf/callosum.vtk", # vtk file of cc mesh +} \ No newline at end of file diff --git a/CorpusCallosum/data/fsaverage_cc_template.py b/CorpusCallosum/data/fsaverage_cc_template.py new file mode 100644 index 000000000..0b67b767d --- /dev/null +++ b/CorpusCallosum/data/fsaverage_cc_template.py @@ -0,0 +1,153 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from pathlib import Path +from typing import cast + +import nibabel as nib +import numpy as np +from scipy import ndimage + +from CorpusCallosum.data import constants +from CorpusCallosum.shape.contour import CCContour +from CorpusCallosum.shape.postprocessing import recon_cc_surf_measure +from FastSurferCNN.utils import nibabelImage +from FastSurferCNN.utils.brainvolstats import mask_in_array + +FSAVERAGE_PC_COORDINATE = np.array([131, 99]) +FSAVERAGE_AC_COORDINATE = np.array([135, 130]) + + +def smooth_contour(contour: tuple[np.ndarray, np.ndarray], window_size: int = 5) -> tuple[np.ndarray, np.ndarray]: + """Smooth a contour using a moving average filter. + + Parameters + ---------- + contour : tuple of arrays + The contour coordinates (x, y). + window_size : int + Size of the smoothing window. + + Returns + ------- + tuple of arrays + The smoothed contour coordinates (x, y). + + """ + x, y = contour + + # Ensure the window size is odd + if window_size % 2 == 0: + window_size += 1 + + # Create a padded version of the arrays to handle the edges + x_padded = np.pad(x, (window_size//2, window_size//2), mode='wrap') + y_padded = np.pad(y, (window_size//2, window_size//2), mode='wrap') + + # Apply moving average + x_smoothed = np.zeros_like(x) + y_smoothed = np.zeros_like(y) + + for i in range(len(x)): + x_smoothed[i] = np.mean(x_padded[i:i+window_size]) + y_smoothed[i] = np.mean(y_padded[i:i+window_size]) + + return (x_smoothed, y_smoothed) + + +def load_fsaverage_cc_template() -> CCContour: + """Load and process the fsaverage corpus callosum template. + + This function loads the fsaverage segmentation from FreeSurfer's data directory, + extracts the corpus callosum mask, and processes it to create a smooth template. + + Returns + ------- + CCContour + Object with all the contour information including: + - contour : tuple[np.ndarray, np.ndarray] : x and y coordinates of the contour points. + - anterior_endpoint_idx : np.ndarray : Index of the anterior endpoint. + - posterior_endpoint_idx : np.ndarray : Index of the posterior endpoint. + + Raises + ------ + OSError + If FREESURFER_HOME environment variable is not set correctly. + + """ + # smooth outside contour + # Apply smoothing to the outside contour using a moving average + + try: + freesurfer_home = Path(os.environ['FREESURFER_HOME']) + except KeyError as err: + raise OSError(f"FREESURFER_HOME environment variable is not set correctly or does not exist: " + f"{freesurfer_home}, either provide your own template or set the " + f"FREESURFER_HOME environment variable") from err + + fsaverage_seg_path = freesurfer_home / 'subjects' / 'fsaverage' / 'mri' / 'aparc+aseg.mgz' + fsaverage_seg = cast(nibabelImage, nib.load(fsaverage_seg_path)) + segmentation = np.asarray(fsaverage_seg.dataobj) + + midslice = segmentation.shape[0]//2 +1 + + cc_mask = mask_in_array(segmentation[midslice], constants.SUBSEGMENT_LABELS) + + # Smooth the CC mask to reduce noise and irregularities + + # Apply binary closing to fill small holes + cc_mask_smoothed = ndimage.binary_closing(cc_mask, structure=np.ones((3, 3))) + + # Apply binary opening to remove small isolated pixels + cc_mask_smoothed = ndimage.binary_opening(cc_mask_smoothed, structure=np.ones((2, 2))) + + # Apply Gaussian smoothing and threshold to get a binary mask again + cc_mask_smoothed = ndimage.gaussian_filter(cc_mask_smoothed.astype(float), sigma=0.8) + cc_mask_smoothed = cc_mask_smoothed > 0.5 + + # Use the smoothed mask for further processing + cc_mask = cc_mask_smoothed.astype(int) * 192 + + _, _fsaverage_contour = recon_cc_surf_measure( + segmentation=cc_mask[None], + slice_idx=0, + ac_coords_vox=FSAVERAGE_AC_COORDINATE, + pc_coords_vox=FSAVERAGE_PC_COORDINATE, + slice_lia_vox2midslice_ras=fsaverage_seg.affine, + num_thickness_points=100, + subdivisions=[1/6, 1/2, 2/3, 3/4], + subdivision_method="shape", + contour_smoothing=5, + ) + outside_contour = _fsaverage_contour.points.T + anterior_endpoint_idx, posterior_endpoint_idx = _fsaverage_contour.endpoint_idxs + + # make sure the CC stays in shape despite smoothing by moving endpoints outwards + outside_contour[0, anterior_endpoint_idx] -= 55 + outside_contour[0, posterior_endpoint_idx] += 30 + + # Apply smoothing to the outside contour + outside_contour_smoothed = smooth_contour(outside_contour, window_size=11) + outside_contour_smoothed = smooth_contour(outside_contour_smoothed, window_size=15) + outside_contour_smoothed = smooth_contour(outside_contour_smoothed, window_size=30) + outside_contour = outside_contour_smoothed + + fsaverage_contour = CCContour(np.array(outside_contour).T, + np.zeros(len(outside_contour[0])), + endpoint_idxs=(anterior_endpoint_idx, posterior_endpoint_idx), + z_position=0.0) + + + return fsaverage_contour diff --git a/CorpusCallosum/data/fsaverage_centroids.json b/CorpusCallosum/data/fsaverage_centroids.json new file mode 100644 index 000000000..bccf1189d --- /dev/null +++ b/CorpusCallosum/data/fsaverage_centroids.json @@ -0,0 +1,217 @@ +{ + "2": [ + -27.242888317659038, + -22.210776052870685, + 18.546657917012894 + ], + "3": [ + -32.18990180647074, + -16.863336561239265, + 16.015058654310195 + ], + "4": [ + -14.455663189269757, + -13.693461251862885, + 13.7136736214605 + ], + "5": [ + -33.906934306569354, + -22.284671532846716, + -15.821167883211672 + ], + "7": [ + -17.305372931308085, + -53.43157258369229, + -36.01715408448575 + ], + "8": [ + -22.265822784810126, + -64.36629649763144, + -37.674831094198964 + ], + "10": [ + -11.752497096399537, + -19.87584204413473, + 5.165737514518 + ], + "11": [ + -15.034188034188048, + 9.437551695616207, + 6.913427074717404 + ], + "12": [ + -26.366197183098592, + -0.15686274509803866, + -2.091549295774655 + ], + "13": [ + -20.91671388101983, + -5.188668555240795, + -2.4107648725212414 + ], + "14": [ + 0.5832045337454872, + -11.11695002575992, + -3.9433281813498127 + ], + "15": [ + 0.5413500223513665, + -46.56236030397854, + -33.21814930710772 + ], + "16": [ + 0.8273686582297444, + -31.946261594502232, + -31.003755304367417 + ], + "17": [ + -26.088480154888686, + -24.429622458857693, + -15.148886737657307 + ], + "18": [ + -23.90932509015971, + -7.339515713549716, + -20.63575476558475 + ], + "24": [ + 0.6026785714285694, + -20.70535714285714, + 8.040736607142861 + ], + "26": [ + -9.629820051413873, + 10.960154241645256, + -8.786632390745496 + ], + "28": [ + -11.456631660832358, + -16.84694671334111, + -10.32691559704395 + ], + "30": [ + -28.545454545454533, + -3.200000000000003, + -10.181818181818187 + ], + "31": [ + -12.502610966057432, + -12.218015665796344, + 6.30548302872063 + ], + "41": [ + 27.68021284305685, + -21.297671313867227, + 18.84475807220643 + ], + "42": [ + 32.70257488842361, + -15.910019860438453, + 16.482307738602415 + ], + "43": [ + 15.18157827962446, + -13.241715300685101, + 14.257802588175593 + ], + "44": [ + 33.10191082802548, + -17.921443736730367, + -16.980891719745216 + ], + "46": [ + 19.070892410341955, + -53.51368564713019, + -35.67336416710896 + ], + "47": [ + 23.65288732176549, + -64.41682904951904, + -37.19518418854969 + ], + "49": [ + 12.493538246594483, + -19.225986727209218, + 5.663872394923743 + ], + "50": [ + 16.15939771547248, + 9.458463136033231, + 8.239096573208727 + ], + "51": [ + 26.94455762514552, + 0.5477299185099014, + -2.249126891734562 + ], + "52": [ + 22.105321507760536, + -4.939024390243901, + -1.9539911308204125 + ], + "53": [ + 27.74364210135512, + -23.379431965843693, + -14.994987933914985 + ], + "54": [ + 24.942549371633746, + -6.010771992818675, + -20.737881508079 + ], + "58": [ + 9.986789960369876, + 10.424042272126826, + -8.705416116248358 + ], + "60": [ + 12.434200157604408, + -16.41252955082743, + -10.056737588652481 + ], + "62": [ + 30.558139534883722, + -2.581395348837205, + -10.441860465116292 + ], + "63": [ + 12.008567931456554, + -11.022031823745408, + 7.3671970624235 + ], + "77": [ + -13.714285714285722, + -15.714285714285708, + 0.9285714285714306 + ], + "85": [ + 1.466019417475735, + -0.2038834951456323, + -18.466019417475735 + ], + "251": [ + 0.5403535741737073, + -35.800153727901616, + 16.784780937740194 + ], + "252": [ + 0.6063829787234027, + -18.29361702127659, + 24.748936170212772 + ], + "253": [ + 0.5847299813780324, + -2.424581005586589, + 25.815642458100555 + ], + "254": [ + 0.7008849557522154, + 11.998230088495575, + 20.40530973451328 + ], + "255": [ + 0.8761467889908232, + 24.612844036697254, + 5.411009174311928 + ] +} \ No newline at end of file diff --git a/CorpusCallosum/data/fsaverage_data.json b/CorpusCallosum/data/fsaverage_data.json new file mode 100644 index 000000000..42cf562b1 --- /dev/null +++ b/CorpusCallosum/data/fsaverage_data.json @@ -0,0 +1,62 @@ +{ + "affine": [ + [ + -1.0, + 0.0, + 0.0, + 128.0 + ], + [ + 0.0, + 0.0, + 1.0, + -128.0 + ], + [ + 0.0, + -1.0, + 0.0, + 128.0 + ], + [ + 0.0, + 0.0, + 0.0, + 1.0 + ] + ], + "header": { + "dims": [ + 256, + 256, + 256 + ], + "delta": [ + 1.0, + 1.0, + 1.0 + ], + "Mdc": [ + [ + -1.0, + 0.0, + 0.0 + ], + [ + 0.0, + 0.0, + 10000000000.0 + ], + [ + 0.0, + -10000000000.0, + 0.0 + ] + ], + "Pxyz_c": [ + 128.0, + -128.0, + 128.0 + ] + } +} \ No newline at end of file diff --git a/CorpusCallosum/data/generate_fsaverage_centroids.py b/CorpusCallosum/data/generate_fsaverage_centroids.py new file mode 100644 index 000000000..b1ef7b19a --- /dev/null +++ b/CorpusCallosum/data/generate_fsaverage_centroids.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python3 +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Script to generate static fsaverage centroids file. + +This script extracts centroids from the fsaverage template segmentation +and saves them to a JSON file for fast loading during pipeline execution. +Run this script once to generate the centroids file. +""" + +import json +import os +from pathlib import Path + +import nibabel as nib +import numpy as np +from read_write import calc_ras_centroids_from_seg, convert_numpy_to_json_serializable + +import FastSurferCNN.utils.logging as logging + +logger = logging.get_logger(__name__) + + +def main() -> None: + """Generate and save fsaverage centroids to a static file. + + This script extracts centroids from the fsaverage template segmentation + and saves them to a JSON file for fast loading during pipeline execution. + + The script performs the following steps: + 1. Load fsaverage segmentation from FreeSurfer directory + 2. Extract centroids for all anatomical structures + 3. Save centroids to JSON file + 4. Extract and save affine matrix and header fields + + Raises + ------ + OSError + If FREESURFER_HOME environment variable is not set or invalid + FileNotFoundError + If required fsaverage files are not found + + Notes + ----- + The script saves two files: + - fsaverage_centroids.json : Contains centroids for each anatomical structure + - fsaverage_data.json : Contains affine matrix and header information + """ + + # Get fsaverage path from FreeSurfer environment + try: + fs_home = Path(os.environ['FREESURFER_HOME']) + if not fs_home.exists(): + raise OSError(f"FREESURFER_HOME environment variable is not set correctly or does not exist: {fs_home}") + + fsaverage_path = fs_home / 'subjects' / 'fsaverage' + if not fsaverage_path.exists(): + raise OSError(f"fsaverage path does not exist: {fsaverage_path}") + + fsaverage_aseg_path = fsaverage_path / 'mri' / 'aseg.mgz' + if not fsaverage_aseg_path.exists(): + raise FileNotFoundError(f"fsaverage aseg file does not exist: {fsaverage_aseg_path}") + + except KeyError as err: + raise OSError("FREESURFER_HOME environment variable is not set") from err + + logger.info(f"Loading fsaverage segmentation from: {fsaverage_aseg_path}") + + # Load fsaverage segmentation + fsaverage_nib = nib.load(fsaverage_aseg_path) + + # Extract centroids + logger.info("Extracting centroids from fsaverage...") + centroids_dst = calc_ras_centroids_from_seg(fsaverage_nib) + + logger.info(f"Found {len(centroids_dst)} anatomical structures with centroids") + + # Convert to JSON-serializable format + centroids_serializable = convert_numpy_to_json_serializable(centroids_dst) + + # Save centroids to JSON file + centroids_output_path = Path(__file__).parent / "fsaverage_centroids.json" + logger.info(f"Saving fsaverage centroids to {centroids_output_path}") + with open(centroids_output_path, 'w') as f: + json.dump(centroids_serializable, f, indent=2) + + logger.info(f"Fsaverage centroids saved to: {centroids_output_path}") + logger.info(f"Centroids file size: {centroids_output_path.stat().st_size} bytes") + + # Extract and save fsaverage affine matrix and header fields + logger.info("Extracting fsaverage affine matrix and header fields...") + fsaverage_affine = fsaverage_nib.affine.astype(float) # Convert to float for JSON serialization + + # Extract header fields needed for LTA + header = fsaverage_nib.header + dims = [int(x) for x in header.get_data_shape()[:3]] # Convert to int for JSON serialization + delta = [float(x) for x in header.get_zooms()[:3]] # Convert to float for JSON serialization + vox2ras = header.get_vox2ras() + + # Direction cosines matrix (Mdc) - extract rotation part without scaling + delta_diag = np.diag(delta) + # Avoid division by zero by using a small epsilon for zero values + delta_safe = np.where(delta_diag == 0, 1e-10, delta_diag) + Mdc = (vox2ras[:3, :3] / delta_safe).astype(float) # Convert to float for JSON serialization + + Pxyz_c = vox2ras[:3, 3].astype(float) # Convert to float for JSON serialization + + # Combine affine and header data + combined_data = { + "affine": fsaverage_affine.tolist(), # Convert numpy array to list for JSON serialization + "vox2ras_tkr": fsaverage_nib.header.get_vox2ras_tkr().tolist(), + "header": { + "dims": dims, + "delta": delta, + "Mdc": Mdc.tolist(), # Convert numpy array to list for JSON serialization + "Pxyz_c": Pxyz_c.tolist() # Convert numpy array to list for JSON serialization + } + } + + # Convert the entire structure to JSON-serializable format to handle any remaining numpy types + combined_data_serializable = convert_numpy_to_json_serializable(combined_data) + + # Save combined data to JSON file + combined_output_path = Path(__file__).parent / "fsaverage_data.json" + logger.info(f"Saving fsaverage affine and header data to {combined_output_path}") + with open(combined_output_path, 'w') as f: + json.dump(combined_data_serializable, f, indent=2) + + logger.info(f"Fsaverage affine and header data saved to: {combined_output_path}") + logger.info(f"Combined file size: {combined_output_path.stat().st_size} bytes") + logger.info(f"Affine matrix shape: {fsaverage_affine.shape}") + logger.info(f"Header dims: {dims}, delta: {delta}") + + # Print some statistics + label_ids = list(centroids_dst.keys()) + logger.info(f"Label IDs range: {min(label_ids)} to {max(label_ids)}") + logger.info("Sample centroids:") + for label_id in sorted(label_ids)[:5]: + centroid = centroids_dst[label_id] + logger.info(f" Label {label_id}: [{centroid[0]:.2f}, {centroid[1]:.2f}, {centroid[2]:.2f}]") + + logger.info("Fsaverage affine matrix:") + logger.info(fsaverage_affine) + + logger.info("Fsaverage header fields:") + logger.info(f" dims: {dims}") + logger.info(f" delta: {delta}") + logger.info(f" Mdc shape: {Mdc.shape}") + logger.info(f" Pxyz_c: {Pxyz_c}") + logger.info("Combined data structure created successfully") + + +if __name__ == "__main__": + main() diff --git a/CorpusCallosum/data/read_write.py b/CorpusCallosum/data/read_write.py new file mode 100644 index 000000000..0c20c4973 --- /dev/null +++ b/CorpusCallosum/data/read_write.py @@ -0,0 +1,221 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from pathlib import Path +from typing import TypedDict + +import numpy as np +from numpy import typing as npt + +import FastSurferCNN.utils.logging as logging +from FastSurferCNN.utils import AffineMatrix4x4, RotationMatrix3x3, Vector3d, nibabelImage +from FastSurferCNN.utils.parallel import thread_executor + +logger = logging.get_logger(__name__) + + +class MGHHeaderDict(TypedDict): + """A dictionary with the four required fields of a MGH Header""" + dims: Vector3d + delta: Vector3d + Mdc: RotationMatrix3x3 + Pxyz_c: Vector3d + + +def calc_ras_centroids_from_seg(seg_img: nibabelImage, label_ids: list[int] | None = None) \ + -> dict[int, np.ndarray | None]: + """Get centroids of segmentation labels in RAS coordinates, accepts any affine/data layout. + + Parameters + ---------- + seg_img : nibabel.analyze.SpatialImage + Input segmentation image. + label_ids : list[int], optional + List of label IDs to extract centroids for. If None, extracts all non-zero labels. + + Returns + ------- + dict[int, np.ndarray | None] + A dict mapping label IDs to their centroids (x,y,z) in RAS coordinates, None if label did not exist. + """ + # Get segmentation data and affine + seg_data: npt.NDArray[np.integer] = np.asarray(seg_img.dataobj) + vox2ras: AffineMatrix4x4 = seg_img.affine + + # Get unique labels + if label_ids is None: + labels = np.unique(seg_data) + labels = labels[labels > 0] # Exclude background + else: + labels = label_ids + + def _each_label(label): + # Get voxel indices for this label + if np.any(mask := seg_data == label): + # Calculate centroid in voxel space + vox_centroid = np.mean(np.where(mask), axis=1, dtype=float) + + # Convert to homogeneous coordinates + vox_centroid_hom = np.append(vox_centroid, 1) + + # Transform to RAS coordinates and return without homogeneous coordinate + return int(label), (vox2ras @ vox_centroid_hom)[:3] + else: + return int(label), None + + return dict(thread_executor().map(_each_label, labels)) + + +def convert_numpy_to_json_serializable(obj: object) -> object: + """Convert numpy types to JSON serializable types. + + Parameters + ---------- + obj : dict, list, array, number, serializable + Object to convert to JSON serializable type. + + Returns + ------- + object + JSON serializable version of the input object. + """ + if isinstance(obj, dict): + return {k: convert_numpy_to_json_serializable(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [convert_numpy_to_json_serializable(item) for item in obj] + elif isinstance(obj, np.ndarray): + return obj.tolist() + elif isinstance(obj, (np.integer, np.floating)): + # Handle numpy scalar types + return obj.item() + else: + return obj + + +def load_fsaverage_centroids(centroids_path: str | Path) -> dict[int, npt.NDArray[float]]: + """Load fsaverage centroids from static JSON file. + + Parameters + ---------- + centroids_path : str or Path + Path to the JSON file containing centroids. + + Returns + ------- + dict[int, np.ndarray] + Dictionary mapping label IDs to their centroids in RAS coordinates. + """ + + centroids_path = Path(centroids_path) + if not centroids_path.exists(): + raise FileNotFoundError(f"Fsaverage centroids file not found: {centroids_path}") + + with open(centroids_path) as f: + centroids_data = json.load(f) + + # Convert string keys back to integers and lists back to numpy arrays + return {int(label): np.array(centroid) for label, centroid in centroids_data.items()} + + +def load_fsaverage_affine(affine_path: str | Path) -> AffineMatrix4x4: + """Load fsaverage affine matrix from static text file. + + Parameters + ---------- + affine_path : str or Path + Path to the text file containing affine matrix. + + Returns + ------- + np.ndarray + 4x4 affine transformation matrix. + """ + + affine_path = Path(affine_path) + if not affine_path.exists(): + raise FileNotFoundError(f"Fsaverage affine file not found: {affine_path}") + + affine_matrix = np.loadtxt(affine_path).astype(float) + + if affine_matrix.shape != (4, 4): + raise ValueError(f"Expected 4x4 affine matrix, got shape {affine_matrix.shape}") + + return affine_matrix + + +def load_fsaverage_data(data_path: str | Path) -> tuple[AffineMatrix4x4, MGHHeaderDict]: + """Load fsaverage affine matrix and header fields from static JSON file. + + Parameters + ---------- + data_path : str or Path + Path to the JSON file containing combined data. + + Returns + ------- + affine_matrix : AffineMatrix4x4 + 4x4 affine transformation matrix. + header_fields : dict + Header fields needed for LTA: + - dims : list[int] + Volume dimensions [x,y,z]. + - delta : list[float] + Voxel size in mm [x,y,z]. + - Mdc : np.ndarray + 3x3 direction cosines matrix. + - Pxyz_c : np.ndarray + RAS center coordinates [x,y,z]. + + Raises + ------ + FileNotFoundError + If the data file doesn't exist. + json.JSONDecodeError + If the file is not valid JSON. + ValueError + If required fields are missing. + """ + data_path = Path(data_path) + if not data_path.exists(): + raise FileNotFoundError(f"Fsaverage data file not found: {data_path}") + + with open(data_path) as f: + data = json.load(f) + + # Verify required fields + if "affine" not in data: + raise ValueError("Required field 'affine' missing from data file") + if "header" not in data: + raise ValueError("Required field 'header' missing from data file") + + required_header_fields = ["dims", "delta", "Mdc", "Pxyz_c"] + for field in required_header_fields: + if field not in data["header"]: + raise ValueError(f"Required header field missing: {field}") + + # Convert lists back to numpy arrays + affine_matrix = np.array(data["affine"]) + header_data = MGHHeaderDict( + dims=data["header"]["dims"], + delta=data["header"]["delta"], + Mdc=np.array(data["header"]["Mdc"]), + Pxyz_c=np.array(data["header"]["Pxyz_c"]), + ) + + # Validate affine matrix shape + if affine_matrix.shape != (4, 4): + raise ValueError(f"Expected 4x4 affine matrix, got shape {affine_matrix.shape}") + + return affine_matrix, header_data diff --git a/CorpusCallosum/fastsurfer_cc.py b/CorpusCallosum/fastsurfer_cc.py new file mode 100644 index 000000000..7de15e711 --- /dev/null +++ b/CorpusCallosum/fastsurfer_cc.py @@ -0,0 +1,1099 @@ +#!/usr/bin/env python3 +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json +from collections.abc import Iterable +from pathlib import Path +from time import perf_counter_ns +from typing import Literal, TypeVar, cast, get_args + +import nibabel as nib +import numpy as np +import torch +from monai.networks.nets import DenseNet +from nibabel.freesurfer.mghformat import MGHHeader +from scipy.ndimage import affine_transform + +from CorpusCallosum.data.constants import ( + CC_LABEL, + DEFAULT_INPUT_PATHS, + DEFAULT_OUTPUT_PATHS, + FSAVERAGE_CENTROIDS_PATH, + FSAVERAGE_DATA_PATH, + FSAVERAGE_MIDDLE, + THIRD_VENTRICLE_LABEL, +) +from CorpusCallosum.data.read_write import ( + MGHHeaderDict, + calc_ras_centroids_from_seg, + convert_numpy_to_json_serializable, + load_fsaverage_centroids, + load_fsaverage_data, +) +from CorpusCallosum.localization import inference as localization_inference +from CorpusCallosum.segmentation import inference as segmentation_inference +from CorpusCallosum.segmentation import segmentation_postprocessing +from CorpusCallosum.shape.contour import calculate_volume as calculate_cc_volume_contour +from CorpusCallosum.shape.postprocessing import ( + check_area_changes, + make_subdivision_mask, + offset_affine, + recon_cc_surf_measures_multi, +) +from CorpusCallosum.utils.mapping_helpers import ( + apply_transform_to_pt, + apply_transform_to_volume, + calc_mapping_to_standard_space, + map_softlabels_to_orig, +) +from CorpusCallosum.utils.types import CCMeasuresDict, SliceSelection, SubdivisionMethod +from FastSurferCNN.data_loader.conform import conform, is_conform +from FastSurferCNN.segstats import HelpFormatter +from FastSurferCNN.utils import ( + AffineMatrix4x4, + Image3d, + Image4d, + Mask3d, + Shape3d, + Vector2d, + logging, + nibabelHeader, + nibabelImage, +) +from FastSurferCNN.utils.arg_types import path_or_none +from FastSurferCNN.utils.common import SubjectDirectory, find_device +from FastSurferCNN.utils.lta import write_lta +from FastSurferCNN.utils.parallel import get_num_threads, serial_executor, shutdown_executors, thread_executor +from FastSurferCNN.utils.parser_defaults import modify_argument +from recon_surf.align_points import find_rigid + +logger = logging.get_logger(__name__) + +_TPathLike = TypeVar("_TPathLike", str, Path, Literal[None]) + +CCMeasures = Literal[ + "areas", + "thickness", + "curvature", + "midline_length", + "circularity", + "cc_index", + "total_area", + "total_perimeter", + "thickness_profile", + "curvature_subsegments", + "curvature_body", +] + + +class ArgumentDefaultsHelpFormatter(HelpFormatter): + """Help message formatter which adds default values to argument help.""" + + def _get_help_string(self, action): + """ + Add the default value to the option help message. + """ + help = action.help + if help is None: + help = '' + + if "%(default)" not in help and not getattr(action, "required", False): + if action.default is not argparse.SUPPRESS and not getattr(action.default, "DO_NOT_PRINT_DEFAULT", False): + defaulting_nargs = [argparse.OPTIONAL, argparse.ZERO_OR_MORE] + if action.option_strings or action.nargs in defaulting_nargs: + help += " (not used by default)" if action.default is None else " (default: %(default)s)" + return help + + +class _FixFloatFormattingList(list): + def __init__(self, items: Iterable, item_format_spec: str): + self._format_spec = item_format_spec + super().__init__(items) + + def __str__(self): + return "[" + ", ".join(map(lambda x: format(x, self._format_spec), self)) + "]" + + +def _do_not_print(value): + class _DoNotPrintGeneric(type(value)): + DO_NOT_PRINT_DEFAULT = True + + return _DoNotPrintGeneric(value) + + +def make_parser() -> argparse.ArgumentParser: + """Create the argument parse object for the pipeline.""" + from FastSurferCNN.utils.parser_defaults import add_arguments + + parser = argparse.ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) + + parser.add_argument( + "-v", + "--verbose", + action="count", + default=_do_not_print(0), + help="Enable verbose (pass twice for debug-output).", + ) + # Specify subject directory + subject ID, OR specify individual MRI and segmentation files + output paths + add_arguments(parser, ["sd", "sid", "conformed_name", "aseg_name", "device"]) + + def _set_help_sid(action): + action.help = "The subject id to use." + modify_argument(parser, "--sid", _set_help_sid) + + parser.add_argument( + "--num_thickness_points", + type=int, + default=100, + help="Number of points for thickness estimation." + ) + parser.add_argument( + "--subdivisions", + type=float, + nargs='*', + metavar="FRAC", + default=_FixFloatFormattingList([1 / 6, 1 / 2, 2 / 3, 3 / 4], ".3f"), + help="List of subdivision fractions for the corpus callosum subsegmentation." + "The method allows for an arbitrary number of fractions." + "By default it uses following Hofer-Frahms convention." + ) + parser.add_argument( + "--subdivision_method", + default=_do_not_print("shape"), + help="Method for contour subdivision. Options:
" + "- shape (default): Intercallosal subdivision perpendicular to intercallosal line,
" + "- vertical: orthogonal to the most anterior and posterior points in the AC/PC standardized CC contour, " + "
" + "- angular: subdivision based on equally spaced angles, as proposed by Hampel and colleagues,
" + "- eigenvector: primary direction, same as FreeSurfers mri_cc.", + choices=["shape", "vertical", "angular", "eigenvector"], + ) + parser.add_argument( + "--contour_smoothing", + type=int, + default=5, + help="Gaussian sigma for smoothing during contour detection. Higher values mean a smoother CC outline, at the " + "cost of precision.", + ) + def _slice_selection(a: str) -> SliceSelection: + if (b := a.lower()) in ("middle", "all"): + return b + return int(a) + parser.add_argument( + "--slice_selection", + type=_slice_selection, + default=_do_not_print("all"), + help="Which slices to process. Options: 'middle', 'all' (default), or a specific slice number.", + ) + + ######## OUTPUT PATHS ######### + # 4. Options for advanced, technical parameters + advanced = parser.add_argument_group( + title="Advanced options", + description="Custom output paths, useful if no standard case directory is used. Relative paths are always " + "relative to the subject_dir defined via --sd and --sid!", + ) + add_arguments(advanced, ["threads"]) + advanced.add_argument( + "--segmentation", "--seg", + type=path_or_none, + help="Output path for corpus callosum and fornix segmentation output.", + default=Path(DEFAULT_OUTPUT_PATHS["segmentation"]), + ) + advanced.add_argument( + "--segmentation_in_orig", + type=path_or_none, + help="Output path for corpus callosum and fornix segmentation output in the input MRI space.", + default=DEFAULT_OUTPUT_PATHS["segmentation_in_orig"], + ) + advanced.add_argument( + "--cc_measures", + type=path_or_none, + help="Output path for surface-based corpus callosum measures describing shape and volume for each image slice.", + default=Path(DEFAULT_OUTPUT_PATHS["cc_measures"]), + ) + advanced.add_argument( + "--cc_mid_measures", + type=path_or_none, + help="Output path for surface-based corpus callosum measures of the midslice describing CC shape and volume.", + default=DEFAULT_OUTPUT_PATHS["cc_markers"], + ) + advanced.add_argument( + "--upright_lta", + type=path_or_none, + help="Output path for upright LTA transform. This makes sure the midplane is at 128 in LR direction, " + "but no nodding correction is applied.", + default=DEFAULT_OUTPUT_PATHS["upright_lta"], + ) + advanced.add_argument( + "--upright_volume", + type=path_or_none, + help="Output path for upright volume (input image with cc_up.lta applied).", + default=None, + ) + advanced.add_argument( + "--orient_volume_lta", + type=path_or_none, + help="Output path for orientation volume LTA transform. This makes sure the midplane is the volume center, " + "the anterior and posterior commisures are on the coordinate line, and the posterior commissure is " + "at the origin - standardizing the head position.", + default=DEFAULT_OUTPUT_PATHS["orient_volume_lta"], + ) + advanced.add_argument( + "--qc_image", + type=path_or_none, + help="Output path for QC visualization image.", + default=DEFAULT_OUTPUT_PATHS["qc_image"], + ) + advanced.add_argument( + "--save_template_dir", + type=path_or_none, + help="Directory path where to save contours.txt and thickness_values.txt files. These files can be used to " + "visualize the CC shape and volume with the cc_visualization.py script.", + default=None, + ) + advanced.add_argument( + "--thickness_image", + type=path_or_none, + help="Output path for thickness image.", + default=DEFAULT_OUTPUT_PATHS["thickness_image"], + ) + advanced.add_argument( + "--surf", + dest="cc_surf", + type=path_or_none, + help="Output path for surf file for visualization in freeview, use --save_template_dir and contours.txt to " + "obtain source CC contours.", + default=DEFAULT_OUTPUT_PATHS["cc_surf"], + ) + advanced.add_argument( + "--thickness_overlay", + type=path_or_none, + help="Output path for corpus callosum thickness overlay file for visualization in freeview, use " + "--save_template_dir and thickness_values.txt to obtain source CC thickness values.", + default=DEFAULT_OUTPUT_PATHS["cc_thickness_overlay"], + ) + advanced.add_argument( + "--cc_interactive_html", "--cc_html", + dest="cc_html", + type=path_or_none, + help="Output path to the corpus callosum interactive 3D visualization HTML file.", + default=DEFAULT_OUTPUT_PATHS["cc_html"], + ) + advanced.add_argument( + "--cc_surf_vtk", + type=path_or_none, + help=f"Output path for vtk file, showing the CC 3D mesh for visualization, use --save_template_dir and " + f"contours.txt to obtain source CC contours. Example: {DEFAULT_OUTPUT_PATHS['cc_surf_vtk']}.", + default=None, + ) + advanced.add_argument( + "--softlabels_cc", + type=path_or_none, + help=f"Output path for corpus callosum softlabels, which contains the soft labels of each voxel. " + f"Example: {DEFAULT_OUTPUT_PATHS['softlabels_cc']}.", + default=None, + ) + advanced.add_argument( + "--softlabels_fn", + type=path_or_none, + help=f"Output path for fornix softlabels, which contains the soft labels of each voxel. " + f"Example: {DEFAULT_OUTPUT_PATHS['softlabels_fn']}.", + default=None, + ) + advanced.add_argument( + "--softlabels_background", + type=path_or_none, + help=f"Output path for background softlabels, which contains the probability of each voxel. " + f"Example: {DEFAULT_OUTPUT_PATHS['softlabels_background']}.", + default=None, + ) + ############ END OF OUTPUT PATHS ############ + return parser + + +def options_parse() -> argparse.Namespace: + """Parse command line arguments for the pipeline.""" + parser = make_parser() + args = parser.parse_args() + + # Reconstruct subject_dir from sd and sid (but sd might be stored as out_dir by parser_defaults) + sd_value = getattr(args, 'out_dir', None) + if sd_value and hasattr(args, 'sid') and args.sid: + args.subject_dir = Path(sd_value) / args.sid + else: + args.subject_dir = None + + # Validation logic: must use either directory approach (--sd + --sid) OR file approach (--conf_name + --aseg_name) + if sd_value: + # Using directory approach - make sure sid was also provided + if not (hasattr(args, 'sid') and args.sid): + parser.error("When using --sd, you must also provide --sid.") + elif hasattr(args, 'sid') and args.sid: + # If sid is provided without sd, that's an error + if not sd_value: + parser.error("When using --sid, you must also provide --sd.") + elif hasattr(args, 'conf_name') and args.conf_name: + # Using file approach - make sure aseg_name was also provided + if not (hasattr(args, 'aseg_name') and args.aseg_name): + parser.error("When using --conf_name, you must also provide --aseg_name.") + elif hasattr(args, 'aseg_name') and args.aseg_name: + # If aseg_name is provided without conf_name, that's an error + if not (hasattr(args, 'conf_name') and args.conf_name): + parser.error("When using --aseg_name, you must also provide --conf_name.") + else: + parser.error("You must specify either --sd and --sid OR both --conf_name and --aseg_name.") + + # If subject_dir is provided, set default paths for missing arguments + if args.subject_dir: + # Create standard FreeSurfer subdirectories + if not args.conf_name: + args.conf_name = args.subject_dir / DEFAULT_INPUT_PATHS["conf_name"] + + if not args.aseg_name: + args.aseg_name = args.subject_dir / DEFAULT_INPUT_PATHS["aseg_name"] + else: + print("WARNING: Not providing subject_dir leads to discarding of files with relative paths!") + args.subject_dir = None + for arg, path in (("--aseg_name", args.aseg_name), ("--conformed_name", args.conf_name)): + if path is None or not Path(path).is_absolute(): + parser.error( + f"When not passing --sd , arguments of --aseg_name and --conformed_name must be " + f"absolute! But the argument passed to {arg} was {path}, i.e. not absolute." + ) + + all_paths = ("segmentation", "segmentation_in_orig", "cc_measures", "upright_lta", "orient_volume_lta", + "cc_surf", "softlabels_cc", "softlabels_fn", "softlabels_background", "cc_mid_measures", + "thickness_overlay", "qc_image", "thickness_image", "cc_html") + + warnings_paths = [] + # Create parent directories for all output paths + for path_name in all_paths: + path: Path | None = getattr(args, path_name, None) + if isinstance(path, Path) and not args.subject_dir and not path.is_absolute(): + # set path to none in arguments + warnings_paths.append(path_name) + setattr(args, path_name, None) + if warnings_paths: + _warnings_paths = "' '".join(warnings_paths) + print(f"WARNING: Not writing '{_warnings_paths}', because --sd and --sid are not specified and " + f"its paths are relative.") + return args + + +def register_centroids_to_fsavg(aseg_nib: nibabelImage) \ + -> tuple[AffineMatrix4x4, AffineMatrix4x4, AffineMatrix4x4, MGHHeaderDict]: + """Perform centroid-based registration between subject and fsaverage space. + + Computes a rigid transformation between the subject's segmentation and fsaverage space + by aligning centroids of corresponding anatomical structures. + + Parameters + ---------- + aseg_nib : nibabel.analyze.SpatialImage + Subject's segmentation image. + + Returns + ------- + aseg2fsaverage_vox2vox : AffineMatrix4x4 + Transformation matrix from original to fsaverage voxel space. + aseg2fsaverage_ras2ras : AffineMatrix4x4 + Transformation matrix from original to fsaverage RAS space. + fsaverage_hires_vox2ras : AffineMatrix4x4 + High-resolution fsaverage affine matrix. + fsaverage_header : MGHHeaderDict + FSAverage header fields for LTA writing. + + Notes + ----- + The function uses pre-computed fsaverage centroids and data from static files + to perform the registration. It matches corresponding anatomical structures + between the subject's segmentation and fsaverage space. + """ + logger.info("Starting centroid registration") + + # Load pre-computed fsaverage centroids and data from static files + fsaverage_data_future = thread_executor().submit(load_fsaverage_data, FSAVERAGE_DATA_PATH) + ras_centroids_dst = load_fsaverage_centroids(FSAVERAGE_CENTROIDS_PATH) + + ras_centroids_mov = calc_ras_centroids_from_seg(aseg_nib, label_ids=list(ras_centroids_dst.keys())) + + # get the set of joint labels + joint_centroid_labels = [lbl for lbl, v in ras_centroids_mov.items() if v is not None] + + ras_centroids_mov = np.array([ras_centroids_mov[lbl] for lbl in joint_centroid_labels]).T + ras_centroids_dst = np.array([ras_centroids_dst[lbl] for lbl in joint_centroid_labels]).T + + aseg2fsaverage_ras2ras: AffineMatrix4x4 = find_rigid(p_mov=ras_centroids_mov.T, p_dst=ras_centroids_dst.T) + + # make affine that increases resolution to orig resolution + aseg_zooms_ras = np.asarray(nib.as_closest_canonical(aseg_nib).header.get_zooms()[:3]) + resolution_trans: AffineMatrix4x4 = np.diagflat(np.append(aseg_zooms_ras[[0, 2, 1]], [1])).astype(float) + + fsaverage_vox2ras, fsavg_header = fsaverage_data_future.result() + fsavg_header["delta"] = aseg_zooms_ras[[0, 2, 1]] # vox sizes in lia + # fsavg_hires_vox2ras translation should be 128 always (independent of resolution) + fsavg_hires_vox2ras: AffineMatrix4x4 = np.concatenate( + [(resolution_trans @ fsaverage_vox2ras)[:, :3], fsaverage_vox2ras[:, 3:4]], + axis=1, + ) + fsavg_header["dims"] = np.ceil(fsavg_header["dims"] @ np.linalg.inv(resolution_trans[:3, :3])).astype(int).tolist() + + # Correct fsavg_header["Pxyz_c"] by (vox_size - 1) / 2 in all three directions, because Pxyz_c is not actually in + # the center of the image, but in the center of the voxel in increasing voxel index direction, i.e. index 128 for a + # 256 image (where the center would be at 127.5). + fsavg_header["Pxyz_c"] += (aseg_zooms_ras - 1) / 2 @ fsavg_header["Mdc"] + + aseg2fsavg_vox2vox: AffineMatrix4x4 = np.linalg.inv(fsavg_hires_vox2ras) @ aseg2fsaverage_ras2ras @ aseg_nib.affine + logger.info("Centroid registration successful!") + return aseg2fsavg_vox2vox, aseg2fsaverage_ras2ras, fsavg_hires_vox2ras, fsavg_header + + +def localize_ac_pc( + orig_data: Image3d, + aseg_nib: nibabelImage, + orig2midslice_vox2vox: AffineMatrix4x4, + model_localization: DenseNet, + resample_shape: Shape3d, +) -> tuple[Vector2d, Vector2d]: + """Localize anterior and posterior commissure points in the brain. + + Uses a trained model to detect AC and PC points in mid-sagittal slices, + using the third ventricle as an anatomical reference. + + Parameters + ---------- + orig_data : np.ndarray + Array of intensity data. + aseg_nib : nibabelImage + Subject's segmentation image in native subject space. + orig2midslice_vox2vox : np.ndarray + Transformation matrix from subject/native space to fsaverage space (in lia). + model_localization : DenseNet + Trained model for AC-PC detection. + resample_shape : 3-tuple of ints + Number of slices to process. + + Returns + ------- + ac_coords : np.ndarray + AC voxel coordinates with shape (2,) containing its [y,x] positions. + pc_coords : np.ndarray + PC voxel coordinates with shape (2,) containing its [y,x] positions. + """ + num_slices_to_analyze = resample_shape[0] + resample_shape = (num_slices_to_analyze + 2,) + resample_shape[1:] # 2 for context slices + _midslices_fut = thread_executor().submit( + affine_transform, + orig_data, + np.linalg.inv(orig2midslice_vox2vox), # inverse is required for affine_transform + output_shape=resample_shape, + order=2, # unclear, why this is not order=3 + mode="constant", + cval=0, + prefilter=True, # unclear, why we are using a smoothing filter here + ) + + # get center of third ventricle from aseg and map to fsaverage space (voxel coordinates) + third_ventricle_mask = np.asarray(aseg_nib.dataobj) == THIRD_VENTRICLE_LABEL + third_ventricle_center = np.argwhere(third_ventricle_mask).mean(axis=0) + third_ventricle_center_vox = apply_transform_to_pt(third_ventricle_center, orig2midslice_vox2vox, inv=False) + + # get 5 mm of slices with 3 slices per inference (cropping num_slices_to_analyze + 2 slices around the center) + ac_coords, pc_coords = localization_inference.run_inference_on_slice( + model_localization, _midslices_fut.result(), third_ventricle_center_vox[1:], + ) + + return ac_coords, pc_coords + + +def segment_cc( + midslices: Image3d, + ac_coords: Vector2d, + pc_coords: Vector2d, + aseg_nib: nibabelImage, + model_segmentation: "torch.nn.Module", +) -> tuple[Mask3d, Image4d]: + """Segment the corpus callosum using a trained model. + + Performs corpus callosum segmentation on mid-sagittal slices using a trained model, with AC-PC points as anatomical + references. Includes post-processing to clean the cc_seg_labels. + + Parameters + ---------- + midslices : np.ndarray + Array of mid-sagittal slices in upright space and LIA-orientation. + ac_coords : np.ndarray + AC voxel coordinates with shape (2,) containing its [y,x] positions. + pc_coords : np.ndarray + PC voxel coordinates with shape (2,) containing its [y,x] positions. + aseg_nib : nibabelImage + Subject's cc_seg_labels image. + model_segmentation : torch.nn.Module + Trained model for CC cc_seg_labels. + + Returns + ------- + cc_seg_labels : np.ndarray + Binary cc_seg_labels of the corpus callosum in upright space and LIA-orientation. + cc_softlabels : np.ndarray + Soft cc_seg_labels probabilities of shape in upright space and LIA-orientation (H, W, D, C=3). + """ + pre_clean_segmentation, inputs, cc_softlabels = segmentation_inference.run_inference_on_slice( + model_segmentation, + midslices, + ac_center=ac_coords, + pc_center=pc_coords, + voxel_size=nib.as_closest_canonical(aseg_nib).header.get_zooms()[2:0:-1], # convert from RAS to LIA + ) + + cc_seg_labels, cc_volume_mask = segmentation_postprocessing.clean_cc_segmentation(pre_clean_segmentation) + + # print a warning if the cc_volume_mask touches the edge of the segmentation + if np.any(cc_volume_mask[:, [0, -1]]) or np.any(cc_volume_mask[:, :, [0, -1]]): + logger.warning("CC volume mask touches the edge of the cc_seg_labels field-of-view, CC might be truncated") + + # get voxels that were removed during cleaning + cleaned_mask = pre_clean_segmentation != cc_seg_labels + cc_softlabels[cleaned_mask, 1] = 0 + cc_softlabels[cleaned_mask, :] /= np.sum(cc_softlabels[cleaned_mask, :], axis=-1, keepdims=True) + 1e-6 + + return cc_seg_labels, cc_softlabels + + +def main( + conf_name: str | Path, + aseg_name: str | Path, + subject_dir: str | Path, + slice_selection: SliceSelection = "middle", + num_thickness_points: int = 100, + subdivisions: list[float] | None = None, + subdivision_method: SubdivisionMethod = "shape", + contour_smoothing: int = 5, + save_template_dir: str | Path | None = None, + device: str | torch.device = "auto", + upright_volume: str | Path | None = None, + segmentation: str | Path | None = None, + cc_measures: str | Path | None = None, + cc_mid_measures: str | Path | None = None, + upright_lta: str | Path | None = None, + orient_volume_lta: str | Path | None = None, + cc_surf: str | Path | None = None, + cc_thickness_overlay: str | Path | None = None, + cc_html: str | Path | None = None, + cc_surf_vtk: str | Path | None = None, + segmentation_in_orig: str | Path | None = None, + qc_image: str | Path | None = None, + thickness_image: str | Path | None = None, + softlabels_cc: str | Path | None = None, + softlabels_fn: str | Path | None = None, + softlabels_background: str | Path | None = None, +) -> None: + """Main pipeline function for corpus callosum analysis. + + This function performs the complete corpus callosum analysis pipeline including + registration, landmark detection, segmentation, and morphometry analysis. + + Parameters + ---------- + conf_name : str or Path + Path to input MRI file. + aseg_name : str or Path + Path to input segmentation file. + subject_dir : str or Path + FastSurfer/FreeSurfer subject directory and directory for output files. + slice_selection : "middle", "all" or int, default="middle" + Which slices to process. + num_thickness_points : int, default=100 + Number of points for thickness estimation. + subdivisions : list[float], optional + List of subdivision fractions for CC subsegmentation. + subdivision_method : any of "shape", "vertical", "angular", "eigenvector", default="shape" + Method for contour subdivision. + contour_smoothing : int, default=5 + Gaussian sigma for smoothing during contour detection. + save_template_dir : str or Path, optional + Directory path where to save contours.txt and thickness_values.txt files. These files can be used to visualize + the CC shape and volume in 3D. Files are only saved, if a valid directory path is passed. + device : str, default="auto" + Device to run inference on ('auto', 'cpu', 'cuda', or 'cuda:X'). + upright_volume : str or Path, optional + Path to save upright volume. + segmentation : str or Path, optional + Path to save segmentation. + cc_measures : str or Path, optional + Path to save post-processing results. + cc_mid_measures : str or Path, optional + Path to save CC markers. + upright_lta : str or Path, optional + Path to save upright LTA transform. + orient_volume_lta : str or Path, optional + Path to save orientation transform. + cc_surf : str or Path, optional + Path to save surface file. + cc_thickness_overlay : str or Path, optional + Path to save overlay file. + cc_html : str or Path, optional + Path to save HTML visualization. + cc_surf_vtk : str or Path, optional + Path to save VTK file. + segmentation_in_orig : str or Path, optional + Path to save segmentation in original space. + qc_image : str or Path, optional + Path to save QC images. + thickness_image : str or Path, optional + Path to save thickness visualization. + softlabels_cc : str or Path, optional + Path to save CC soft labels. + softlabels_fn : str or Path, optional + Path to save fornix soft labels. + softlabels_background : str or Path, optional + Path to save background soft labels. + + Notes + ----- + The function saves multiple outputs to specified paths or default locations in output_dir: + - cc_markers.json: Contains detected landmarks and measurements. + - midplane_slices.mgz: Extracted midplane slices. + - upright_volume.mgz: Volume aligned to standard orientation. + - segmentation.mgz: Corpus callosum segmentation. + - cc_postproc_results.json: Enhanced postprocessing results. + - Various visualization plots and transformation matrices. + + The pipeline consists of the following steps: + 1. Initializes environment and loads models. + 2. Registers input image to fsaverage space. + 3. Detects AC and PC points. + 4. Segments the corpus callosum. + 5. Performs enhanced post-processing analysis. + 6. Saves results and visualizations. + """ + start = perf_counter_ns() + + import sys + + if subdivisions is None: + subdivisions = [1 / 6, 1 / 2, 2 / 3, 3 / 4] + + subject_dir = Path("/dev/null/no-subject-dir" if subject_dir is None else subject_dir) + + logger.info("Starting corpus callosum analysis pipeline") + logger.info(f"Input MRI: {conf_name}") + logger.info(f"Input segmentation: {aseg_name}") + logger.info(f"Output directory: {subject_dir}") + + # Convert all paths to Path objects + sd = SubjectDirectory( + subject_dir.parent, + id=subject_dir.name, + conf_name=conf_name, + aseg_name=aseg_name, + save_template_dir=save_template_dir, + upright_volume=upright_volume, + cc_segmentation=segmentation, + cc_measures=cc_measures, + cc_mid_measures=cc_mid_measures, + upright_lta=upright_lta, + cc_orient_volume_lta=orient_volume_lta, + cc_surf=cc_surf, + cc_thickness_overlay=cc_thickness_overlay, + cc_html=cc_html, + cc_mesh=cc_surf_vtk, + cc_orig_segfile=segmentation_in_orig, + cc_qc_image=qc_image, + cc_thickness_image=thickness_image, + cc_softlabels_cc=softlabels_cc, + cc_softlabels_fn=softlabels_fn, + cc_softlabels_background=softlabels_background, + ) + + # Validate subdivision fractions + if any(i < 0 or i > 1 for i in subdivisions): + logger.error(f"Subdivision fractions must be between 0 and 1, but got: {subdivisions}") + sys.exit(1) + + #### setup variables + io_futures = [] + + # load models + device = find_device(device) + logger.info(f"Using device: {device}") + + logger.info("Loading models") + _model_localization = thread_executor().submit(localization_inference.load_model, device=device) + _model_segmentation = thread_executor().submit(segmentation_inference.load_model, device=device) + + _aseg_fut = thread_executor().submit(nib.load, sd.filename_by_attribute("aseg_name")) + orig = cast(nibabelImage, nib.load(sd.conf_name)) + + # check that the image is conformed, i.e. isotropic 1mm voxels, 256^3 size, LIA orientation + if not is_conform(orig, vox_size=None, img_size=None, orientation=None): + logger.info("Internally conforming orig to soft-LIA.") + orig = conform(orig, vox_size=None, img_size=None, orientation=None) + + # 5 mm around the midplane (guaranteed to be aligned RAS by as_closest_canonical) + vox_size_ras: tuple[float, float, float] = nib.as_closest_canonical(orig).header.get_zooms() + vox_size = vox_size_ras[0], vox_size_ras[2], vox_size_ras[1] # convert from RAS to LIA + slices_to_analyze = int(np.ceil(5 / vox_size[0])) + # slices_to_analyze must be odd + if slices_to_analyze % 2 == 0: + slices_to_analyze += 1 + + logger.info( + f"Segmenting {slices_to_analyze} slices (5 mm width at {vox_size[0]:.3f} mm resolution, " + "center around the mid-sagittal plane)" + ) + + aseg_img = cast(nibabelImage, _aseg_fut.result()) + + if not np.allclose(aseg_img.affine, orig.affine): + logger.error("Input MRI and segmentation are not aligned! Please check your input files.") + sys.exit(1) + + logger.info("Performing centroid registration to fsaverage space") + orig2fsavg_vox2vox, orig2fsavg_ras2ras, fsavg_vox2ras, _fsavg_header_dict = register_centroids_to_fsavg(aseg_img) + fsavg_header = init_mgh_header(orig.header, _fsavg_header_dict) + + # start saving upright volume, this is the image in fsaverage space but not yet oriented via AC-PC + if sd.has_attribute("upright_volume"): + # upright == fsaverage-aligned + io_futures.append( + thread_executor().submit( + apply_transform_to_volume, + orig, + orig2fsavg_vox2vox, + save_vox2ras=fsavg_vox2ras, + output_path=sd.filename_by_attribute("upright_volume"), + output_size=fsavg_header["dims"][:3], + ) + ) + + # calculate affine for segmentation volume + fsavg2midslice_vox2vox: AffineMatrix4x4 = offset_affine([-FSAVERAGE_MIDDLE / vox_size[0], 0, 0]) + orig2midslice_vox2vox = fsavg2midslice_vox2vox @ orig2fsavg_vox2vox + + # calculate vox2vox for input resampling volumes + def _orig2midslab_vox2vox(additional_context: int = 0) -> AffineMatrix4x4: + fsavg2midslab = offset_affine([slices_to_analyze // 2 + additional_context // 2, 0, 0]) + # first, orig->fsaverage, then fsaverage->midslab (all in vox2vox) + return fsavg2midslab @ orig2midslice_vox2vox + + # first, midslice->fsaverage in vox2vox, then vox2ras in fsaverage space + fsavg2midslab_vox2vox = offset_affine([slices_to_analyze // 2, 0, 0]) @ fsavg2midslice_vox2vox + fsaverage_midslab_vox2ras: AffineMatrix4x4 = fsavg_vox2ras @ np.linalg.inv(fsavg2midslab_vox2vox) + + + #### do localization and segmentation inference + logger.info("Starting AC/PC localization") + target_shape: tuple[int, int, int] = (slices_to_analyze, fsavg_header["dims"][1], fsavg_header["dims"][2]) + # predict ac and pc coordinates in upright AS space + ac_coords_vox, pc_coords_vox = localize_ac_pc( + np.asarray(orig.dataobj), + aseg_img, + _orig2midslab_vox2vox(additional_context=2), + _model_localization.result(), + target_shape, + ) + logger.info("Starting corpus callosum segmentation") + num_context = 8 # 8 extra in x-direction for context slices + target_shape: Shape3d = (slices_to_analyze + num_context, fsavg_header["dims"][1], fsavg_header["dims"][2]) + midslices: Image3d = affine_transform( + np.asarray(orig.dataobj), + np.linalg.inv(_orig2midslab_vox2vox(additional_context=num_context)), # inverse is required for affine_transform + output_shape=target_shape, + order=2, # @ClePol unclear, why this is not order=3 + mode="constant", + cval=0, + prefilter=True, # unclear, why we are using a smoothing filter here + ) + cc_fn_seg_labels, cc_fn_softlabels = segment_cc( + midslices, + ac_coords_vox, + pc_coords_vox, + aseg_img, + _model_segmentation.result(), + ) + + # save segmentation softlabels + for i, (attr, name) in enumerate((("background",) * 2, ("cc", "Corpus Callosum"), ("fn", "Fornix"))): + if sd.has_attribute(f"cc_softlabels_{attr}"): + logger.info(f"Saving {name} softlabels to {sd.filename_by_attribute(f'cc_softlabels_{attr}')}") + io_futures.append(thread_executor().submit( + nib.save, + nib.MGHImage(cc_fn_softlabels[..., i], fsaverage_midslab_vox2ras, orig.header), + sd.filename_by_attribute(f"cc_softlabels_{attr}"), + )) + + # Create a temporary segmentation image with proper affine for enhanced postprocessing + # Process slices based on selection mode + + logger.info(f"Processing slices with selection mode: {slice_selection}") + slice_results, slice_io_futures, cc_contours, cc_mesh = recon_cc_surf_measures_multi( + segmentation=cc_fn_seg_labels, + slice_selection=slice_selection, + upright_header=fsavg_header, + fsavg2midslab_vox2vox=fsavg2midslab_vox2vox, + fsavg_vox2ras=fsavg_vox2ras, + orig2fsavg_vox2vox=orig2fsavg_vox2vox, + midslices=midslices, + ac_coords_vox=ac_coords_vox, + pc_coords_vox=pc_coords_vox, + num_thickness_points=num_thickness_points, + subdivisions=subdivisions, + subdivision_method=cast(SubdivisionMethod, subdivision_method), + contour_smoothing=contour_smoothing, + subject_dir=sd, + ) + io_futures.extend(slice_io_futures) + + outer_contours = [slice_result["split_contours"][0] for slice_result in slice_results] + + if len(outer_contours) > 1 and not check_area_changes(outer_contours): + logger.warning( + "Large area changes detected between consecutive slices, this is likely due to a segmentation error." + ) + + # Get middle slice result + middle_slice_result: CCMeasuresDict = slice_results[len(slice_results) // 2] + + + # save segmentation labels, this + if sd.has_attribute("cc_segmentation"): + sd.filename_by_attribute("cc_segmentation").parent.mkdir(exist_ok=True, parents=True) + io_futures.append(thread_executor().submit( + nib.save, + nib.MGHImage(cc_fn_seg_labels, fsaverage_midslab_vox2ras, orig.header), + sd.filename_by_attribute("cc_segmentation"), + )) + # map soft labels to original space (in parallel because this takes a while, and we only do it to save the labels) + if sd.has_attribute("cc_orig_segfile"): + if len(middle_slice_result["split_contours"]) <= 5: + cc_subseg_midslice = make_subdivision_mask( + (cc_fn_seg_labels.shape[1], cc_fn_seg_labels.shape[2]), + middle_slice_result["split_contours"], + vox2ras=fsavg_vox2ras @ np.linalg.inv(fsavg2midslice_vox2vox), + ) + else: + logger.warning("Too many subsegments for lookup table, skipping sub-division of output segmentation.") + cc_subseg_midslice = None + # if num_threads is not large enough (>1), this might be blocking ; serial_executor runs the function in submit + executor = thread_executor() if get_num_threads() > 2 else serial_executor() + io_futures.append(executor.submit( + map_softlabels_to_orig, + cc_fn_softlabels=cc_fn_softlabels, + orig=orig, + orig_space_segmentation_path=sd.filename_by_attribute("cc_orig_segfile"), + orig2slab_vox2vox=_orig2midslab_vox2vox(), + cc_subseg_midslice=cc_subseg_midslice, + orig2midslice_vox2vox=orig2midslice_vox2vox, + )) + + metrics: tuple[CCMeasures] = get_args(CCMeasures) + + # Record key metrics for middle slice + output_metrics_middle_slice = {metric: middle_slice_result[metric] for metric in metrics} + + # Create enhanced output dictionary with all slice results + per_slice_output_dict = { + "slices": [convert_numpy_to_json_serializable({metric: result[metric] for metric in metrics}) + for result in slice_results], + } + + ########## Save outputs ########## + additional_metrics = {} + if len(outer_contours) > 1: + cc_volume_voxel = segmentation_postprocessing.get_cc_volume_voxel( + desired_width_mm=5, + cc_mask=np.equal(cc_fn_seg_labels, CC_LABEL), + voxel_size=vox_size, # in LIA order + ) + logger.info(f"CC volume voxel: {cc_volume_voxel}") + cc_volume_contour = calculate_cc_volume_contour(cc_contours, width=5.0) + logger.info(f"CC volume contour: {cc_volume_contour}") + + additional_metrics["cc_5mm_volume"] = cc_volume_voxel + additional_metrics["cc_5mm_volume_pv_corrected"] = cc_volume_contour + + # get ac and pc in all spaces + ac_coords_3d = np.hstack((FSAVERAGE_MIDDLE, ac_coords_vox)) + pc_coords_3d = np.hstack((FSAVERAGE_MIDDLE, pc_coords_vox)) + standardized2orig_vox2vox, ac_coords_standardized, pc_coords_standardized, ac_coords_orig, pc_coords_orig = ( + calc_mapping_to_standard_space(orig, ac_coords_3d, pc_coords_3d, orig2fsavg_vox2vox) + ) + + # write output dict as csv + additional_metrics["ac_center"] = ac_coords_orig + additional_metrics["pc_center"] = pc_coords_orig + additional_metrics["ac_center_oriented_volume"] = ac_coords_standardized + additional_metrics["pc_center_oriented_volume"] = pc_coords_standardized + additional_metrics["ac_center_upright"] = ac_coords_3d + additional_metrics["pc_center_upright"] = pc_coords_3d + additional_metrics["slices_in_segmentation"] = slices_to_analyze + additional_metrics["voxel_size"] = np.asarray(orig.header.get_zooms(), dtype=float).tolist() + additional_metrics["num_thickness_points"] = num_thickness_points + additional_metrics["subdivision_method"] = subdivision_method + additional_metrics["subdivision_ratios"] = subdivisions + additional_metrics["contour_smoothing"] = contour_smoothing + additional_metrics["slice_selection"] = slice_selection + + # QC checks + if len(outer_contours) > 1: + max_vol = max(cc_volume_voxel, cc_volume_contour) + if max_vol > 0 and abs(cc_volume_voxel - cc_volume_contour) / max_vol > 0.2: + logger.warning( + f"QC flag: CC volume estimates differ by more than 20% " + f"(voxel: {cc_volume_voxel:.2f}, contour: {cc_volume_contour:.2f})", + "this can happen if contour creation failed for some slices" + ) + + cc_index = output_metrics_middle_slice.get("cc_index") + if cc_index is not None and cc_index > 2: + logger.warning( + f"QC flag: CC index is high ({cc_index:.2f} > 2), segmentation or contour creation may be incorrect" + ) + + midline_length = output_metrics_middle_slice.get("midline_length") + if midline_length is not None and midline_length < 30: + logger.warning( + f"QC flag: CC midline length is short ({midline_length:.2f}mm < 30mm), endpoints may be " + "incorrectly detected or contour creation may have failed" + ) + + if sd.has_attribute("cc_mid_measures"): + sd.filename_by_attribute('cc_mid_measures').parent.mkdir(exist_ok=True, parents=True) + io_futures.append(thread_executor().submit( + save_cc_measures_json, + sd.filename_by_attribute('cc_mid_measures'), + output_metrics_middle_slice | additional_metrics, + )) + + if sd.has_attribute("cc_measures"): + sd.filename_by_attribute("cc_measures").parent.mkdir(exist_ok=True, parents=True) + io_futures.append(thread_executor().submit( + save_cc_measures_json, + sd.filename_by_attribute("cc_measures"), + per_slice_output_dict | additional_metrics, + )) + + # save lta to fsaverage space + + if sd.has_attribute("upright_lta"): + sd.filename_by_attribute("upright_lta").parent.mkdir(exist_ok=True, parents=True) + logger.info(f"Saving LTA to fsaverage space: {sd.filename_by_attribute('upright_lta')}") + io_futures.append(thread_executor().submit( + write_lta, + sd.filename_by_attribute("upright_lta"), + orig2fsavg_ras2ras, + sd.filename_by_attribute("aseg_name"), + aseg_img.header, + "fsaverage", + fsavg_header, + )) + + if sd.has_attribute("cc_orient_volume_lta"): + sd.filename_by_attribute("cc_orient_volume_lta").parent.mkdir(exist_ok=True, parents=True) + # save lta to standardized space (fsaverage + nodding + ac to center) + orig2standardized_ras2ras = orig.affine @ np.linalg.inv(standardized2orig_vox2vox) @ np.linalg.inv(orig.affine) + logger.info(f"Saving LTA to standardized space: {sd.filename_by_attribute('cc_orient_volume_lta')}") + io_futures.append(thread_executor().submit( + write_lta, + sd.filename_by_attribute("cc_orient_volume_lta"), + orig2standardized_ras2ras, + sd.conf_name, + orig.header, + sd.conf_name, + orig.header, + )) + + # this waits for all io to finish + for fut in io_futures: + e = fut.exception() + if e and isinstance(e, Exception): + logger.exception(e) + shutdown_executors() + + duration = (perf_counter_ns() - start) / 1e9 + logger.info(f"CorpusCallosum analysis pipeline completed successfully in {duration:.2f} seconds.") + + +def init_mgh_header(header: nibabelHeader, header_dict: MGHHeaderDict) -> MGHHeader: + """ + Generates a MGHHeader object from a header and a header dictionary. + + Parameters + ---------- + header : nibabelHeader + The header object used to initialize the generated header. + header_dict : MGHHeaderDict + A dictionary of values to overwrite in the generated header. + + Returns + ------- + MGHHeader + The header updated with values in header_dict. + """ + new_header: MGHHeader = MGHHeader.from_header(header) + if "dims" in header_dict: + new_header["dims"] = np.append(header_dict["dims"], [1]) + for key in ("delta", "Pxyz_c", "Mdc"): + if key in header_dict: + new_header[key] = header_dict[key] + return new_header + + +def save_cc_measures_json(cc_mid_measure_file: Path, metrics: dict[str, object]): + """Save JSON metrics file.""" + # Convert numpy arrays to lists for JSON serialization + logger.info(f"Saving CC markers to {cc_mid_measure_file}") + cc_mid_measure_file.parent.mkdir(exist_ok=True, parents=True) + with open(cc_mid_measure_file, "w") as f: + json.dump(convert_numpy_to_json_serializable(metrics), f, indent=4) + + +if __name__ == "__main__": + options = options_parse() + + # Set up logging if verbose mode is enabled + logging.setup_logging(None, options.verbose) # Log to stdout only + + main( + conf_name=options.conf_name, + aseg_name=options.aseg_name, + subject_dir=options.subject_dir, + slice_selection=options.slice_selection, + num_thickness_points=options.num_thickness_points, + subdivisions=list(options.subdivisions), + subdivision_method=str(options.subdivision_method), + contour_smoothing=options.contour_smoothing, + save_template_dir=options.save_template_dir, + device=options.device, + upright_volume=options.upright_volume, + segmentation=options.segmentation, + cc_measures=options.cc_measures, + cc_mid_measures=options.cc_mid_measures, + upright_lta=options.upright_lta, + orient_volume_lta=options.orient_volume_lta, + cc_surf=options.cc_surf, + cc_thickness_overlay=options.thickness_overlay, + cc_html=options.cc_html, + cc_surf_vtk=options.cc_surf_vtk, + segmentation_in_orig=options.segmentation_in_orig, + qc_image=options.qc_image, + thickness_image=options.thickness_image, + softlabels_cc=options.softlabels_cc, + softlabels_fn=options.softlabels_fn, + softlabels_background=options.softlabels_background, + ) diff --git a/CorpusCallosum/localization/__init__.py b/CorpusCallosum/localization/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/CorpusCallosum/localization/inference.py b/CorpusCallosum/localization/inference.py new file mode 100644 index 000000000..1837864a7 --- /dev/null +++ b/CorpusCallosum/localization/inference.py @@ -0,0 +1,254 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path +from typing import Literal, cast + +import numpy as np +import torch +from monai import transforms +from monai.networks.nets import DenseNet + +from CorpusCallosum.transforms.localization import CropAroundACPCFixedSize +from CorpusCallosum.utils.checkpoint import YAML_DEFAULT as CC_YAML +from CorpusCallosum.utils.types import Points2dType +from FastSurferCNN.download_checkpoints import load_checkpoint_config_defaults +from FastSurferCNN.download_checkpoints import main as download_checkpoints +from FastSurferCNN.utils import Image3d, Vector2d, Vector3d +from FastSurferCNN.utils.parser_defaults import FASTSURFER_ROOT + +PATCH_SIZE = (64, 64) + + +def load_model(device: torch.device) -> DenseNet: + """Load trained numerical localization model from checkpoint. + + Parameters + ---------- + device : torch.device + Device to load model to. + + Returns + ------- + DenseNet + Loaded and initialized model in evaluation mode. + """ + + # Initialize model architecture (must match training) + model = DenseNet( # densenet201 + spatial_dims=2, + in_channels=3, + out_channels=4, + init_features=64, + growth_rate=32, + block_config=(6, 12, 48, 32), + bn_size=4, + act=("relu", {"inplace": True}), + norm=("batch", {"affine": True}), + dropout_prob=0.2 + ) + + download_checkpoints(cc=True) + cc_config = load_checkpoint_config_defaults( + "checkpoint", + filename=CC_YAML, + ) + checkpoint_path = FASTSURFER_ROOT / cc_config['localization'] + + # Load state dict + if isinstance(checkpoint_path, str) or isinstance(checkpoint_path, Path): + state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True) + if isinstance(state_dict, dict) and 'model_state_dict' in state_dict: + state_dict = state_dict['model_state_dict'] + else: + state_dict = checkpoint_path + + model.load_state_dict(state_dict) + model = model.to(device) + model.eval() + return model + + +def get_transforms() -> transforms.Compose: + """Get preprocessing transforms for inference. + + Returns + ------- + transforms.Compose + Composed transform pipeline including: + - Intensity scaling to [0,1] + - Fixed size cropping around AC-PC points + """ + tr = [ + transforms.ScaleIntensityd(keys=['image'], minv=0, maxv=1), + CropAroundACPCFixedSize(keys=['image'], fixed_size=PATCH_SIZE, random_translate=0), + ] + return transforms.Compose(tr) + + +def preprocess_volume( + image_volume: np.ndarray, + center_pt: Vector3d, + transform: transforms.Transform | None = None +) -> dict[str, torch.Tensor | tuple[int, ...]]: + """Preprocess a volume for inference. + + Parameters + ---------- + image_volume : np.ndarray + Input image volume of shape (W, W, D) in RAS. + center_pt : np.ndarray + Center point coordinates for cropping on the slice with shape (3,). + transform : transforms.Transform or None, optional + Custom transform pipeline, by default None. + If None, uses default transforms from get_transforms(). + + Returns + ------- + dict[str, torch.Tensor | tuple[int, ...]] + Dictionary containing preprocessed image tensor. + """ + if transform is None: + transform = get_transforms() + + # During training we used AC/PC coordinates, but during inference we approximate this by the center of the third + # ventricle. Therefore we put in the third ventricle center as dummy AC/PC coordinates for cropping the image. + sample = {"image": image_volume[None], "AC_center": center_pt[1:][None], "PC_center": center_pt[1:][None]} + + # Apply transforms + transformed = transform(sample) + + # Add batch dimension if needed + if torch.is_tensor(transformed["image"]): + if transformed["image"].ndim == 3: + transformed["image"] = transformed["image"].unsqueeze(0) + + return transformed + +def predict( + model: torch.nn.Module, + image_volume: Image3d, + patch_center: np.ndarray, + device: torch.device | None = None, + transform: transforms.Transform | None = None + ) -> tuple[Points2dType, Points2dType, tuple[int, int]]: + """ + Run inference on an image volume + + Parameters + ---------- + model : DenseNet + Trained model for inference. + image_volume : np.ndarray + Input volume as numpy array. + patch_center : np.ndarray + Initial center point estimate for cropping. + device : torch.device, optional + Device to run inference on, by default None. + transform : transforms.Transform, optional + Custom transform pipeline, defaults to preconfigured transforms of `get_transforms`. + + Returns + ------- + pc_ccord : np.ndarray + Predicted PC coordinates. + ac_coord : np.ndarray + Predicted AC coordinates. + crop_offsets : pair of ints + Crop offsets (left, top). + """ + if device is None: + device = next(model.parameters()).device + + # prepend zero to third_ventricle_center + patch_center_3d = np.concatenate([np.zeros(1), patch_center]) + + # Preprocess + t_dict = preprocess_volume(image_volume, patch_center_3d, transform) + + transformed_original = cast(torch.Tensor, t_dict["image"]) + inputs = transformed_original.to(device) + + inputs = inputs.transpose(0, 1) + inputs = inputs.unfold(0, 3, 1).transpose(1, -1)[..., 0] + + # Run inference + with torch.no_grad(): + outputs = model(inputs) * torch.as_tensor([PATCH_SIZE + PATCH_SIZE], device=device) + + crop_left, crop_top = cast(tuple[int, int], t_dict["crop_left"]), cast(tuple[int, int], t_dict["crop_top"]) + t_crops = [(crop_left + crop_top) * 2] + outs: np.ndarray[tuple[int, Literal[4]], np.dtype[np.float_]] + outs = outputs.cpu().numpy() + np.asarray(t_crops, dtype=float) + crop_offsets: tuple[int, int] = (crop_left[0], crop_top[0]) + return outs[:, :2], outs[:, 2:], crop_offsets + + +def run_inference_on_slice( + model: DenseNet, + image_slab: Image3d, + center_pt: Vector2d, + num_iterations: int = 2, + debug_output: str | None = None, +) -> tuple[Vector2d, Vector2d]: + """Run inference on a single slice to detect AC and PC points. + + Parameters + ---------- + model : torch.nn.Module + Trained model for AC-PC detection. + image_slab : np.ndarray + 3D image mid-slices to run inference on in RAS. + center_pt : np.ndarray + Initial center point estimate for cropping. + num_iterations : int, default=2 + Number of refinement iterations to run. + debug_output : str, optional + Path to save debug visualization. + + Returns + ------- + ac_coords : np.ndarray + Detected AC voxel coordinates with shape (2,) containing its [y,x] positions. + pc_coords : np.ndarray + Detected PC voxel coordinates with shape (2,) containing its [y,x] positions. + """ + + if num_iterations < 1: + raise ValueError("localization inference with less than 1 iteration is invalid!") + + pc_coords, ac_coords = center_pt[None], center_pt[None] + crop_left, crop_top = 0, 0 + # Run inference + for _ in range(num_iterations): + pc_coords, ac_coords, (crop_left, crop_top) = predict(model, image_slab, center_pt) + center_pt = np.mean(np.stack([ac_coords, pc_coords], axis=0), axis=(0, 1)) + # average ac and pc coords across sagittal slices + _pc_coords = np.mean(pc_coords, axis=0) + _ac_coords = np.mean(ac_coords, axis=0) + + if debug_output is not None: + import matplotlib.pyplot as plt + from matplotlib.patches import Rectangle + fig, ax = plt.subplots(1, 1, figsize=(10, 8)) + ax.imshow(image_slab[image_slab.shape[0] // 2, :, :], cmap='gray') + # Plot points on all views + ax.scatter(pc_coords[:, 1], pc_coords[:, 0], c='r', marker='x', label='PC') + ax.scatter(ac_coords[:, 1], ac_coords[:, 0], c='b', marker='x', label='AC') + # make a box where the crop is + ax.add_patch(Rectangle((crop_top, crop_left), 64, 64, fill=False, color='r', linewidth=2)) + plt.savefig(debug_output, bbox_inches='tight') + plt.close() + + return _ac_coords, _pc_coords diff --git a/CorpusCallosum/paint_cc_into_pred.py b/CorpusCallosum/paint_cc_into_pred.py new file mode 100644 index 000000000..39d0e61a0 --- /dev/null +++ b/CorpusCallosum/paint_cc_into_pred.py @@ -0,0 +1,477 @@ +# Copyright 2019 Image Analysis Lab, German Center for Neurodegenerative Diseases (DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# IMPORTS + +import argparse +import sys +from functools import partial +from pathlib import Path +from typing import TypeVar, cast + +import nibabel as nib +import numpy as np +from numpy import typing as npt +from scipy import ndimage + +import FastSurferCNN.utils.logging as logging +from CorpusCallosum.data.constants import FORNIX_LABEL, SUBSEGMENT_LABELS +from FastSurferCNN.data_loader.conform import is_conform +from FastSurferCNN.reduce_to_aseg import reduce_to_aseg_and_save +from FastSurferCNN.utils.arg_types import path_or_none +from FastSurferCNN.utils.brainvolstats import mask_in_array +from FastSurferCNN.utils.parallel import thread_executor + +_T = TypeVar("_T", bound=np.number) + +logger = logging.get_logger(__name__) + +HELPTEXT = """ +Script to add corpus callosum segmentation (CC, FreeSurfer IDs 251-255) to +deep-learning prediction (e.g. aparc.DKTatlas+aseg.deep.mgz). + + +USAGE: +paint_cc_into_pred -in_cc -in_pred -out + + +Dependencies: + Python 3.8+ + + Nibabel to read and write FreeSurfer data + http://nipy.org/nibabel/ + +Original Author: Leonie Henschel +Date: Jul-10-2020 + +""" + + +def argument_parse(): + """Create a command line interface and return command line options. + """ + parser = make_parser() + + args = parser.parse_args() + + if args.input_cc is None or args.input_pred is None or args.output is None: + sys.exit("ERROR: Please specify input and output segmentations") + + return args + + +def make_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(usage=HELPTEXT) + parser.add_argument( + "--input_cc", + "-in_cc", + dest="input_cc", + type=Path, + required=True, + help="path to input segmentation with Corpus Callosum (IDs 251-255 in FreeSurfer space)", + ) + parser.add_argument( + "--input_pred", + "-in_pred", + dest="input_pred", + type=Path, + required=True, + help="path to input segmentation Corpus Callosum should be added to.", + ) + parser.add_argument( + "--output", + "-out", + dest="output", + type=Path, + required=True, + help="path to output (input segmentation + added CC)", + ) + parser.add_argument( + "--reduce_to_aseg", + "-aseg", + dest="aseg", + type=path_or_none, + required=False, + help="optionally also reduce the resulting segmentation to aseg and save separately.", + default=None, + ) + return parser + + +def paint_in_cc(pred: npt.NDArray[np.int_], + aseg_cc: npt.NDArray[np.int_]) -> npt.NDArray[np.int_]: + """Paint corpus callosum segmentation into aseg+dkt segmentation map. + + Parameters + ---------- + pred : npt.NDArray[np.int_] + Deep-learning segmentation map. + aseg_cc : npt.NDArray[np.int_] + Aseg segmentation with CC. + + Returns + ------- + npt.NDArray[np.int_] + Segmentation map with added CC. + + Notes + ----- + This function modifies the original array and does not create a copy. + The CC labels (251-255) from aseg_cc are copied into pred. + """ + cc_mask = mask_in_array(aseg_cc, SUBSEGMENT_LABELS) + + # Count what's being replaced + replaced_labels = pred[cc_mask] + num_wm_replaced = np.sum((replaced_labels == 2) | (replaced_labels == 41)) + num_other_replaced = np.sum((replaced_labels != 0) & (replaced_labels != 2) & (replaced_labels != 41)) + num_background_replaced = np.sum(replaced_labels == 0) + + logger.info(f"Painting CC: {np.sum(cc_mask)} voxels (replacing {num_wm_replaced} WM, " + f"{num_background_replaced} background, {num_other_replaced} other)") + + pred[cc_mask] = aseg_cc[cc_mask] + return pred + +def _fill_gaps_in_direction( + corrected_pred: npt.NDArray[np.int_], + potential_fill: npt.NDArray[np.bool_], + source_binary: npt.NDArray[np.bool_], + target_binary: npt.NDArray[np.bool_], + x_slice: int, + direction: str, + max_gap_voxels: int, + fillable_labels: set[int] +) -> int: + """Fill gaps between source and target masks in a specific direction. + + Parameters + ---------- + corrected_pred : npt.NDArray[np.int_] + The segmentation array to modify in place. + potential_fill : npt.NDArray[np.bool_] + 2D mask of potential fill regions for this slice. + source_binary : npt.NDArray[np.bool_] + 2D binary mask of source structure (e.g., CC). + target_binary : npt.NDArray[np.bool_] + 2D binary mask of target structure (e.g., ventricle). + x_slice : int + The x-coordinate of the current slice. + direction : str + Either 'inferior-superior' (iterate over z) or 'anterior-posterior' (iterate over y). + max_gap_voxels : int + Maximum gap size in voxels for this direction. + fillable_labels : set[int] + Set of label values that can be replaced (e.g., {0, 2, 41} for background and WM). + + Returns + ------- + int + Number of voxels filled. + """ + voxels_filled = 0 + + if direction == 'inferior-superior': + # Iterate over z dimension + for z in range(potential_fill.shape[1]): + potential_fill_line = potential_fill[:, z] + labeled_gaps, num_gaps = ndimage.label(potential_fill_line) + source_line = source_binary[:, z] + target_line = target_binary[:, z] + + for gap_label in range(1, num_gaps + 1): + gap_mask = labeled_gaps == gap_label + + # Check that both source and target are connected to the gap + dilated_gap_mask = ndimage.binary_dilation(gap_mask, iterations=1) + if not np.any(source_line & dilated_gap_mask): + continue + if not np.any(target_line & dilated_gap_mask): + continue + + # Get the target label from adjacent target voxels + target_label_location = np.where(target_line & dilated_gap_mask)[0] + if len(target_label_location) == 0: + continue + target_label = corrected_pred[x_slice, target_label_location[0], z] + + # Check gap size + if np.sum(gap_mask) > max_gap_voxels: + continue + + # Fill voxels that have fillable labels + current_labels = corrected_pred[x_slice, :, z] + fill_mask = gap_mask & np.isin(current_labels, list(fillable_labels)) + voxels_filled += np.sum(fill_mask) + corrected_pred[x_slice, :, z][fill_mask] = target_label + + elif direction == 'anterior-posterior': + # Iterate over y dimension + for y in range(potential_fill.shape[0]): + potential_fill_line = potential_fill[y, :] + labeled_gaps, num_gaps = ndimage.label(potential_fill_line) + source_line = source_binary[y, :] + target_line = target_binary[y, :] + + for gap_label in range(1, num_gaps + 1): + gap_mask = labeled_gaps == gap_label + + # Check that both source and target are connected to the gap + dilated_gap_mask = ndimage.binary_dilation(gap_mask, iterations=1) + if not np.any(source_line & dilated_gap_mask): + continue + if not np.any(target_line & dilated_gap_mask): + continue + + # Get the target label from adjacent target voxels + target_label_location = np.where(target_line & dilated_gap_mask)[0] + if len(target_label_location) == 0: + continue + target_label = corrected_pred[x_slice, y, target_label_location[0]] + + # Check gap size + if np.sum(gap_mask) > max_gap_voxels: + continue + + # Fill voxels that have fillable labels + current_labels = corrected_pred[x_slice, y, :] + fill_mask = gap_mask & np.isin(current_labels, list(fillable_labels)) + voxels_filled += np.sum(fill_mask) + corrected_pred[x_slice, y, :][fill_mask] = target_label + + return voxels_filled + + +def _fill_gaps_between_structures( + corrected_pred: npt.NDArray[np.int_], + source_mask: npt.NDArray[np.bool_], + target_mask: npt.NDArray[np.bool_], + voxel_size: tuple[float, float, float], + close_gap_size_mm: float, + fillable_labels: set[int], + description: str +) -> int: + """Fill small gaps between two structures. + + Parameters + ---------- + corrected_pred : npt.NDArray[np.int_] + The segmentation array to modify in place. + source_mask : npt.NDArray[np.bool_] + 3D binary mask of source structure (e.g., CC). + target_mask : npt.NDArray[np.bool_] + 3D binary mask of target structure (e.g., ventricle or background). + voxel_size : tuple[float, float, float] + Voxel size in mm. + close_gap_size_mm : float + Maximum gap size in mm. + fillable_labels : set[int] + Set of label values that can be replaced. + description : str + Description for logging. + + Returns + ------- + int + Number of voxels filled. + """ + # Convert mm gap size to voxels + max_gap_vox_anterior_posterior = int(np.ceil(close_gap_size_mm / voxel_size[1])) + max_gap_vox_inferior_superior = int(np.ceil(close_gap_size_mm / voxel_size[2])) + max_gap_vox_max = max(max_gap_vox_anterior_posterior, max_gap_vox_inferior_superior) + + voxels_filled = 0 + + # Process each slice independently + for x in range(corrected_pred.shape[0]): + source_slice = source_mask[x] + target_slice = target_mask[x] + + # Skip slices without both structures + if not (source_slice.any() and target_slice.any()): + continue + + # Create binary masks for this slice + source_binary = source_slice.astype(bool) + target_binary = target_slice.astype(bool) + + # Dilate both masks to find potential connection points + source_dilated = ndimage.binary_dilation(source_binary, iterations=max_gap_vox_max) + target_dilated = ndimage.binary_dilation(target_binary, iterations=max_gap_vox_max) + + # Find voxels that are adjacent to both structures but not part of either + potential_fill = (source_dilated & target_dilated) & ~(source_binary | target_binary) + + # Fill gaps in inferior-superior direction + voxels_filled += _fill_gaps_in_direction( + corrected_pred, potential_fill, source_binary, target_binary, + x, 'inferior-superior', max_gap_vox_inferior_superior, fillable_labels + ) + + # Fill gaps in anterior-posterior direction + voxels_filled += _fill_gaps_in_direction( + corrected_pred, potential_fill, source_binary, target_binary, + x, 'anterior-posterior', max_gap_vox_anterior_posterior, fillable_labels + ) + + if voxels_filled > 0: + logger.info(f"Filled {voxels_filled} voxels {description}") + + return voxels_filled + + +def correct_wm_ventricles( + aseg_cc: npt.NDArray[np.int_], + fornix_mask: npt.NDArray[np.bool_], + voxel_size: tuple[float, float, float], + close_gap_size_mm: float = 3.0 +) -> npt.NDArray[np.int_]: + """Fill small gaps between corpus callosum, ventricles, and background. + + This function performs two gap-filling operations: + 1. Fills WM and background gaps between CC and ventricles with ventricle labels + 2. Fills WM gaps between CC and background with background label + + Note: Fornix and non-CC-connected WM component removal are intentionally not implemented + in this function as they have been removed from the processing pipeline. + + Parameters + ---------- + aseg_cc : npt.NDArray[np.int_] + Aseg segmentation with CC already painted in. + fornix_mask : npt.NDArray[np.bool_] + Mask of the fornix. Not currently used (kept for interface compatibility). + voxel_size : tuple[float, float, float] + Voxel size of the aseg image in mm. + close_gap_size_mm : float, default=3.0 + Maximum size of the gap to fill in millimeters. + + Returns + ------- + npt.NDArray[np.int_] + Corrected segmentation map with filled gaps. + """ + # Create a copy to avoid modifying the original + corrected_pred = aseg_cc.copy() + + # Get CC mask (labels 251-255) + cc_mask = mask_in_array(aseg_cc, SUBSEGMENT_LABELS) + + # Get ventricle masks (left=4, right=43) + ventricle_mask = (aseg_cc == 4) | (aseg_cc == 43) + + # Get background mask + background_mask = aseg_cc == 0 + + # 1. Fill gaps between CC and ventricles (replace WM and background with ventricle labels) + _fill_gaps_between_structures( + corrected_pred, cc_mask, ventricle_mask, voxel_size, close_gap_size_mm, + fillable_labels={0, 2, 41}, # background and WM + description="between CC and ventricles (WM/background → ventricle)" + ) + + # 2. Fill WM gaps between CC and background (replace WM with background) + _fill_gaps_between_structures( + corrected_pred, cc_mask, background_mask, voxel_size, close_gap_size_mm, + fillable_labels={2, 41}, # only WM + description="between CC and background (WM → background)" + ) + + return corrected_pred + + +if __name__ == "__main__": + from FastSurferCNN.utils import nibabelImage + + # Command Line options are error checking done here + options = argument_parse() + + logging.setup_logging() + + logger.info(f"Reading inputs: {options.input_cc} {options.input_pred}...") + cc_seg_image = cast(nibabelImage, nib.load(options.input_cc)) + cc_seg_data = np.asanyarray(cc_seg_image.dataobj) + aseg_image = cast(nibabelImage, nib.load(options.input_pred)) + aseg_data = np.asanyarray(aseg_image.dataobj) + + def _is_conform(img, dtype, verbose): + return is_conform(img, vox_size=None, img_size=None, verbose=verbose, dtype=dtype) + + conform_args = (cc_seg_image, aseg_image), (np.uint8, np.integer) + conform_checks = list(thread_executor().map(partial(_is_conform, verbose=False), *conform_args)) + + if not all(conform_checks): + names = [] + dtypes = [] + for conform_check, img, dtype, name in zip(conform_checks, *conform_args, ("CC", "Prediction"), strict=True): + if not conform_check: + _is_conform(img, dtype, verbose=True) + names.append(name) + dtypes.append(dtype.name if hasattr(dtype, "name") else str(dtype)) + sys.exit( + f"Error: {' and '.join(names)} input image is not conformed (LIA orientation, {'/'.join(dtypes)} dtype). " + "Please conform the image(s) using the conform.py script." + ) + if not np.allclose(cc_seg_image.affine, aseg_image.affine): + sys.exit("Error: The affine matrices of the aseg and the corpus callosum images are not the same.") + + # Count initial labels before any modifications + initial_cc = np.sum(mask_in_array(aseg_data, SUBSEGMENT_LABELS)) + initial_fornix = np.sum(aseg_data == FORNIX_LABEL) + initial_wm = np.sum((aseg_data == 2) | (aseg_data == 41)) + initial_ventricles = np.sum((aseg_data == 4) | (aseg_data == 43)) + + # Paint CC into prediction (modifies aseg_data in place) + paint_in_cc(aseg_data, cc_seg_data) + + # Apply ventricle gap filling corrections + fornix_mask = cc_seg_data == FORNIX_LABEL + voxel_size = tuple(aseg_image.header.get_zooms()) + pred_corrected = correct_wm_ventricles(aseg_data, fornix_mask, voxel_size) + + logger.info(f"Writing segmentation with corpus callosum to: {options.output}") + pred_with_cc_fin = nib.MGHImage(pred_corrected, aseg_image.affine, aseg_image.header) + io_fut = thread_executor().submit(pred_with_cc_fin.to_filename, options.output) + + if options.aseg is not None: + rta_fut = thread_executor().submit( + reduce_to_aseg_and_save, + pred_corrected, + aseg_image.affine, + aseg_image.header, + options.aseg, + ) + else: + rta_fut = None + + # Count final labels + final_cc = np.sum(mask_in_array(pred_corrected, SUBSEGMENT_LABELS)) + final_fornix = np.sum(pred_corrected == FORNIX_LABEL) + final_wm = np.sum((pred_corrected == 2) | (pred_corrected == 41)) + final_ventricles = np.sum((pred_corrected == 4) | (pred_corrected == 43)) + + wm_change = final_wm - initial_wm + vent_change = final_ventricles - initial_ventricles + cc_change = final_cc - initial_cc + + logger.info(f"Changes: Corpus Callosum {'+' if cc_change >= 0 else ''}{cc_change}, " + f"White Matter {'+' if wm_change >= 0 else ''}{wm_change}, " + f"Ventricles {'+' if vent_change >= 0 else ''}{vent_change}") + + # Wait for all IO operations to complete + io_fut.result() + if rta_fut is not None: + rta_fut.result() + diff --git a/CorpusCallosum/segmentation/__init__.py b/CorpusCallosum/segmentation/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/CorpusCallosum/segmentation/inference.py b/CorpusCallosum/segmentation/inference.py new file mode 100644 index 000000000..9704b3b4b --- /dev/null +++ b/CorpusCallosum/segmentation/inference.py @@ -0,0 +1,310 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections.abc import Iterator +from pathlib import Path +from typing import cast, overload + +import nibabel as nib +import numpy as np +import torch +from monai import transforms +from numpy import typing as npt + +from CorpusCallosum.data import constants +from CorpusCallosum.transforms.segmentation import CropAroundACPC +from CorpusCallosum.utils.checkpoint import YAML_DEFAULT as CC_YAML +from FastSurferCNN.download_checkpoints import load_checkpoint_config_defaults +from FastSurferCNN.download_checkpoints import main as download_checkpoints +from FastSurferCNN.models.networks import FastSurferVINN +from FastSurferCNN.utils import Image3d, Image4d, Shape2d, Shape3d, Shape4d, Vector2d, nibabelImage +from FastSurferCNN.utils.parallel import thread_executor + + +def load_model(device: torch.device | None = None) -> FastSurferVINN: + """Load trained model from checkpoint. + + Parameters + ---------- + device : torch.device or None, optional + Device to load model to, by default None. + If None, uses CUDA if available, else CPU. + + Returns + ------- + FastSurferVINN + Loaded and initialized model in evaluation mode. + """ + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + params = { + "num_classes": 3, + "num_filters": 71, + "num_filters_interpol": 32, + "num_channels": 9, + "kernel_h": 3, + "kernel_w": 3, + "kernel_c": 1, + "stride_conv": 1, + "stride_pool": 2, + "pool": 2, + "height": 128, + "width": 128, + "base_res": 1.0, + "interpolation_mode": "bilinear", + "crop_position": "top_left", + "out_tensor_width": 320, + "out_tensor_height": 320, + } + model = FastSurferVINN(params) + + download_checkpoints(cc=True) + cc_config: dict[str, Path] = load_checkpoint_config_defaults( + "checkpoint", + filename=CC_YAML, + ) + checkpoint_path = constants.FASTSURFER_ROOT / cc_config['segmentation'] + + weights = torch.load(checkpoint_path, weights_only=True, map_location=device) + model.load_state_dict(weights) + model.eval() + model.to(device) + return model + + +def run_inference( + model: "torch.nn.Module", + image_slice: Image3d, + ac_center: Vector2d, + pc_center: Vector2d, + voxel_size: tuple[float, float], + device: torch.device | None = None, + transform: transforms.Transform | None = None +) -> tuple[np.ndarray[Shape4d, np.dtype[np.int_]], Image4d, Image4d]: + """Run inference on a single image slice. + + Parameters + ---------- + model : torch.nn.Module + Trained model. + image_slice : np.ndarray + LIA-oriented input image as numpy array of shape (L, I, A). + ac_center : np.ndarray + Anterior commissure coordinates. + pc_center : np.ndarray + Posterior commissure coordinates. + voxel_size : a pair of floats + Voxel size of inferior/superior and anterior/posterior direction in mm. + device : torch.device, optional + Device to run inference on. If None, uses the device of the model. + transform : transforms.Transform, optional + Custom transform pipeline. + + Returns + ------- + seg_labels : npt.NDArray[int] + The segmentation result. + inputs : npt.NDArray[float] + The inputs to the model. + soft_labels : npt.NDArray[float] + The softlabel output. + """ + if device is None: + device = next(model.parameters()).device + + crop_around_acpc = CropAroundACPC(keys=['image'], padding_mm=35, random_translate=0) + to_discrete = transforms.AsDiscrete(argmax=True, to_onehot=3) + + # Preprocess slice + _inputs = torch.from_numpy(image_slice[:,None]) #,:256,:256]) # artifact from training script + sample = {'image': _inputs, 'AC_center': ac_center, 'PC_center': pc_center, 'res': np.asarray(voxel_size)} + sample_cropped = crop_around_acpc(sample) + _inputs, to_pad = sample_cropped['image'], sample_cropped['to_pad'] + _inputs = transforms.utils.rescale_array(_inputs, 0, 1, dtype=np.float32).to(device) + + # split into slices with 9 channels each + # Generate views with sliding window of 9 slices + batch_size, channels, height, width = _inputs.shape + _inputs = _inputs.unfold(0, 9, 1).swapdims(-1, 1).reshape(-1, 9*channels, height, width) + + # Post-process outputs + with torch.no_grad(): + scale_factors = torch.ones((_inputs.shape[0], 2), device=device) / torch.asarray([voxel_size], device=device) + + _logits = model(_inputs, scale_factor=scale_factors) + _softlabels = transforms.Activations(softmax=True, dim=1)(_logits) + + softlabels = _softlabels.cpu().numpy() + _labels = torch.stack([to_discrete(i) for i in _softlabels]) + + # Pad back to original size, to_pad is a tuple[int, int, int, int] + pad_tuples = ((0, 0),) * 2 + (to_pad[:2], to_pad[2:]) + labels = np.pad(_labels.cpu().numpy(), pad_tuples, mode='constant', constant_values=0) + softlabels = np.pad(softlabels, pad_tuples, mode='constant', constant_values=0) + + return tuple(x.transpose(0, 2, 3, 1) for x in (labels, _inputs.cpu().numpy(), softlabels)) + + +def load_validation_data( + path: str | Path, +) -> tuple[npt.NDArray[str], npt.NDArray[float], npt.NDArray[float], Iterator[int], npt.NDArray[str], list[str]]: + """Load validation data from CSV file and compute label widths. + + Reads a CSV file containing image paths, label paths, and AC/PC coordinates, + then computes the width (number of slices with non-zero labels) for each label file. + + Parameters + ---------- + path : str or Path + Path to the CSV file containing validation data. The CSV should have columns: + image, label, AC_center_x, AC_center_y, AC_center_z, + PC_center_x, PC_center_y, PC_center_z. + + Returns + ------- + images : npt.NDArray[str] + Array of image file paths. + ac_centers : npt.NDArray[float] + Array of anterior commissure coordinates (x, y, z). + pc_centers : npt.NDArray[float] + Array of posterior commissure coordinates (x, y, z). + label_widths : Iterator[int] + Iterator yielding the number of slices with non-zero labels for each label file. + labels : npt.NDArray[str] + Array of label file paths. + subj_ids : list[str] + List of subject IDs (from CSV index). + """ + import pandas as pd + + data = pd.read_csv(path, index_col=0, header=None) + data.columns = ["image", "label", "AC_center_x", "AC_center_y", "AC_center_z", + "PC_center_x", "PC_center_y", "PC_center_z"] + + ac_centers = data[["AC_center_x", "AC_center_y", "AC_center_z"]].values + pc_centers = data[["PC_center_x", "PC_center_y", "PC_center_z"]].values + images = data["image"].values + labels = data["label"].values + subj_ids = data.index.values.tolist() + + def _load(label_path: str | Path) -> int: + """Compute the width of non-zero slices in a label image. + + Parameters + ---------- + label_path : str or Path + Path to the label image file + + Returns + ------- + int + Number of slices containing non-zero labels, or total slices if <= 100 + """ + label_img = cast(nibabelImage, nib.load(label_path)) + + if label_img.shape[0] > 100: + # check which slices have non-zero values + label_data = np.asarray(label_img.dataobj) + non_zero_slices = np.any(label_data > 0, axis=(1,2)) + first_nonzero = np.argmax(non_zero_slices) + last_nonzero = len(non_zero_slices) - np.argmax(non_zero_slices[::-1]) + return last_nonzero - first_nonzero + else: + return label_img.shape[0] + + label_widths = thread_executor().map(_load, data["label"]) + + return images, ac_centers, pc_centers, label_widths, labels, subj_ids + +@overload +def one_hot_to_label(one_hot: Image4d, label_ids: list[int] | None = None) \ + -> np.ndarray[Shape3d, np.dtype[np.int_]]: ... + +@overload +def one_hot_to_label(one_hot: Image3d, label_ids: list[int] | None = None) \ + -> np.ndarray[Shape2d, np.dtype[np.int_]]: ... + +def one_hot_to_label( + one_hot: np.ndarray[tuple[int, ...], np.dtype[np.bool_]], + label_ids: list[int] | None = None, +) -> np.ndarray[tuple[int, ...], np.dtype[np.int_]]: + """Convert one-hot encoded segmentation to label map. + + Converts a one-hot encoded segmentation array to discrete labels by taking + the argmax along the last axis and optionally mapping to specific label values. + + Parameters + ---------- + one_hot : np.ndarray of floats + One-hot encoded segmentation array of shape (..., num_classes). + label_ids : array_like of ints, optional + List of label IDs to map classes to. If None, defaults to [0, FORNIX_LABEL, CC_LABEL]. + The index in this list corresponds to the class index from argmax. + + Returns + ------- + npt.NDArray[int] + Label map with discrete integer labels. + """ + if label_ids is None: + from CorpusCallosum.data.constants import CC_LABEL, FORNIX_LABEL + label_ids = [0, CC_LABEL, FORNIX_LABEL] + + label = np.argmax(one_hot, axis=3) + if label_ids is not None: + label = np.asarray(label_ids)[label] + + return label + + +def run_inference_on_slice( + model: "torch.nn.Module", + test_slab: Image3d, + ac_center: Vector2d, + pc_center: Vector2d, + voxel_size: tuple[float, float], +) -> tuple[np.ndarray[Shape3d, np.dtype[np.int_]], Image4d, Image4d]: + """Run inference on a single slice. + + Parameters + ---------- + model : torch.nn.Module + Trained model for inference. + test_slab : np.ndarray + Input image slice. + ac_center : npt.NDArray[float] + Anterior commissure coordinates (Inferior and Anterior values). + pc_center : npt.NDArray[float] + Posterior commissure coordinates (Inferior and Posterior values). + voxel_size : a pair of floats + Voxel sizes in superior/inferior and anterior/posterior direction in mm. + + Returns + ------- + results: np.ndarray + Label map after one-hot conversion. + inputs: np.ndarray + Preprocessed input image. + outputs_soft: npt.NDArray[float] + Softlabel outputs (non-discrete). + + """ + # add zero in front of AC_center and PC_center + ac_center = np.concatenate([np.zeros(1), ac_center]) + pc_center = np.concatenate([np.zeros(1), pc_center]) + + _results, inputs, outputs_soft = run_inference(model, test_slab, ac_center, pc_center, voxel_size) + results = one_hot_to_label(_results) + + return results, inputs, outputs_soft diff --git a/CorpusCallosum/segmentation/segmentation_postprocessing.py b/CorpusCallosum/segmentation/segmentation_postprocessing.py new file mode 100644 index 000000000..a0b6e2729 --- /dev/null +++ b/CorpusCallosum/segmentation/segmentation_postprocessing.py @@ -0,0 +1,459 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TypeVar + +import numpy as np +from scipy import ndimage +from scipy.spatial.distance import cdist +from skimage.measure import label +from torchgen.model import ScalarType + +import FastSurferCNN.utils.logging as logging +from CorpusCallosum.data.constants import CC_LABEL +from CorpusCallosum.utils.types import Points3dType +from FastSurferCNN.utils import Mask3d, Shape3d, ShapeType, Vector3d + +logger = logging.get_logger(__name__) + +ArrayType = TypeVar('ArrayType', bound=np.ndarray) + + +def find_component_boundaries(labels_arr: np.ndarray[ShapeType, np.dtype[ScalarType]], component_id: int) \ + -> np.ndarray[ShapeType, np.dtype[np.integer]]: + """Find boundary voxels of a connected component. + + Parameters + ---------- + labels_arr : np.ndarray + Labeled array from connected components analysis. + component_id : int + ID of the component to find boundaries for. + + Returns + ------- + np.ndarray + Array of shape (N, 3) containing boundary coordinates. + + Notes + ----- + Uses 6-connectivity (face neighbors only) to determine boundaries. + Boundary voxels are those that are part of the component but have + at least one non-component neighbor. + """ + component_mask = labels_arr == component_id + + # Create a structuring element for 6-connectivity (face neighbors only) + struct = ndimage.generate_binary_structure(3, 1) + + # Erode the component to find internal voxels + eroded = ndimage.binary_erosion(component_mask, structure=struct) + + # Boundary is the difference between original and eroded + boundary = component_mask & ~eroded + + return np.array(np.where(boundary)).T + + +def find_minimal_connection_path( + boundary_coords1: Points3dType, + boundary_coords2: Points3dType, + max_distance: float = 3.0 +) -> tuple[Vector3d, Vector3d] | None: + """Find the minimal connection path between two component boundaries. + + Parameters + ---------- + boundary_coords1 : np.ndarray + Boundary coordinates of first component, shape (N1, 3). + boundary_coords2 : np.ndarray + Boundary coordinates of second component, shape (N2, 3). + max_distance : float, default=3.0 + Maximum distance to consider for connection, by default 3.0. + + Returns + ------- + tuple[np.ndarray, np.ndarray] or None + If a valid connection is found: + + - point1 : Coordinates on first boundary + - point2 : Coordinates on second boundary + + None if no connection within max_distance is found. + + Notes + ----- + Uses Euclidean distance to find the closest pair of points + between the two boundaries. + """ + if len(boundary_coords1) == 0 or len(boundary_coords2) == 0: + return None + + # Calculate pairwise distances between all boundary points + distances = cdist(boundary_coords1, boundary_coords2, metric='euclidean') + + # Find the minimum distance and corresponding points + min_idx = np.unravel_index(np.argmin(distances), distances.shape) + min_distance = distances[min_idx] + + if min_distance <= max_distance: + return boundary_coords1[min_idx[0]], boundary_coords2[min_idx[1]] + + return None + + +def create_connection_line(point1: Vector3d, point2: Vector3d) -> list[tuple[int, int, int]]: + """Create a line of voxels connecting two points. + + Uses a simplified 3D line algorithm to create a sequence of voxels + that form a continuous path between the two points. + + Parameters + ---------- + point1 : np.ndarray + Starting point coordinates, shape (3,). + point2 : np.ndarray + Ending point coordinates, shape (3,). + + Returns + ------- + list of int triplets + List of (x, y, z) coordinates forming the connection line. + + Notes + ----- + The line is created by interpolating between the points using + the maximum distance in any dimension as the number of steps. + """ + x1, y1, z1 = map(int, point1) + x2, y2, z2 = map(int, point2) + + line_points: list[tuple[int, int, int]] = [] + + # Calculate the number of steps needed + dx = abs(x2 - x1) + dy = abs(y2 - y1) + dz = abs(z2 - z1) + + steps = max(dx, dy, dz) + + if steps == 0: + return [(x1, y1, z1)] + + # Calculate increments for each dimension + x_inc = (x2 - x1) / steps + y_inc = (y2 - y1) / steps + z_inc = (z2 - z1) / steps + + # Generate points along the line + for i in range(steps + 1): + x = int(round(x1 + i * x_inc)) + y = int(round(y1 + i * y_inc)) + z = int(round(z1 + i * z_inc)) + line_points.append((x, y, z)) + + return line_points + + +def connect_nearby_components(seg_arr: ArrayType, max_connection_distance: float = 3.0, plot: bool = False) \ + -> ArrayType: + """Connect nearby disconnected components that should be connected. + + This function identifies disconnected components in the segmentation and creates + minimal connections between components that are close to each other. + + Parameters + ---------- + seg_arr : np.ndarray + Input binary segmentation array. + max_connection_distance : float, default=3.0 + Maximum distance to connect components. + plot : bool, default=False + Whether to plot the segmentation with connected components. + + Returns + ------- + np.ndarray + Segmentation array with minimal connections added between nearby components. + + Notes + ----- + The function: + 1. Identifies connected components in the input segmentation + 2. Finds boundaries between components + 3. Creates minimal connections between nearby components + 4. Returns the modified segmentation with added connections + """ + + # Create a copy to modify + connected_seg = seg_arr.copy() + + # Find connected components without dilation first + labels_cc = label(seg_arr, connectivity=3, background=0) + + # Get component sizes (excluding background) + bincount = np.bincount(labels_cc.flat) + component_ids = np.where(bincount > 0)[0][1:] # Exclude background (0) + + if len(component_ids) <= 1: + return connected_seg # Only one component, no connections needed + + # Sort components by size (largest first) + component_sizes = [(comp_id, bincount[comp_id]) for comp_id in component_ids] + component_sizes.sort(key=lambda x: x[1], reverse=True) + + # Use the largest component as the reference + main_component_id = component_sizes[0][0] + + logger.info(f"Found {len(component_ids)} disconnected components. " + f"Attempting to connect smaller components to main component (size: {component_sizes[0][1]})") + + connections_made = 0 + + # Try to connect each smaller component to the main component + for comp_id, comp_size in component_sizes[1:]: + if comp_size < 5: # Skip very small components (likely noise) + logger.debug(f"Skipping tiny component {comp_id} with size {comp_size}") + continue + + # Find boundaries of both components + main_boundary = find_component_boundaries(labels_cc, main_component_id) + comp_boundary = find_component_boundaries(labels_cc, comp_id) + + # Find minimal connection path + connection = find_minimal_connection_path(main_boundary, comp_boundary, max_connection_distance) + + if connection is not None: + point1, point2 = connection + distance = np.linalg.norm(point2 - point1) + + logger.debug(f"Connecting component {comp_id} (size: {comp_size}) to main component. " + f"Distance: {distance:.2f} voxels") + + # Create connection line + connection_line: list[tuple[int, int, int]] = create_connection_line(point1, point2) + + # Add connection voxels to the segmentation + # Use the same label as the original segmentation at the connection points + connection_label = seg_arr[point1[0], point1[1], point1[2]] if \ + seg_arr[point1[0], point1[1], point1[2]] != 0 else \ + seg_arr[point2[0], point2[1], point2[2]] + + for x, y, z in connection_line: + if (0 <= x < connected_seg.shape[0] and + 0 <= y < connected_seg.shape[1] and + 0 <= z < connected_seg.shape[2]): + if connected_seg[x, y, z] == 0: # Only fill empty voxels + connected_seg[x, y, z] = connection_label + + connections_made += 1 + else: + logger.debug(f"Component {comp_id} (size: {comp_size}) too far from main component") + + logger.info(f"Created {connections_made} minimal connections between components") + + + # Plot components for debugging + if plot: + import matplotlib + import matplotlib.pyplot as plt + curr_backend = matplotlib.get_backend() + plt.switch_backend("qtagg") + n_components = len(component_sizes) + fig, axes = plt.subplots(1, n_components + 1, figsize=(5*(n_components + 1), 5)) + if n_components == 1: + axes = [axes] + # Plot each component in a different color + for i, (comp_id, comp_size) in enumerate(component_sizes): + component_mask = labels_cc == comp_id + axes[i].imshow(component_mask[component_mask.shape[0]//2], cmap='gray') + axes[i].set_title(f'Component {comp_id}\nSize: {comp_size}') + axes[i].axis('off') + + # Plot the connected segmentation + axes[-1].imshow(connected_seg[connected_seg.shape[0]//2], cmap='gray') + axes[-1].set_title('Connected Segmentation') + axes[-1].axis('off') + plt.tight_layout() + plt.show() + plt.switch_backend(curr_backend) + + return connected_seg + + +def get_cc_volume_voxel( + desired_width_mm: int, + cc_mask: Mask3d, + voxel_size: tuple[float, float, float], +) -> float: + """Calculate the volume of the corpus callosum in cubic millimeters. + + This function calculates the volume of the corpus callosum (CC) in cubic millimeters. + If the CC width is larger than desired_width_mm, the voxels on the edges are calculated as + partial volumes to achieve the desired width. + + Parameters + ---------- + desired_width_mm : int + Desired width of the CC in millimeters. + cc_mask : np.ndarray + Binary mask of the corpus callosum in LIA orientation. + voxel_size : triplet of floats + LIA-oriented Voxel size in millimeters (x, y, z). + + Returns + ------- + float + Volume of the CC in cubic millimeters. + + Raises + ------ + ValueError + If CC width is smaller than desired width + AssertionError + If CC mask doesn't have odd number of voxels in x dimension + + Notes + ----- + The function assumes LIA orientation + """ + # Get the bounding box of the CC mask in x dimension + any_cc = np.any(cc_mask, axis=(1, 2)) + if not np.any(any_cc): + return 0.0 + + first_x = np.argmax(any_cc) + last_x = len(any_cc) - 1 - np.argmax(any_cc[::-1]) + + # Crop mask to its extent in x + cropped_mask = cc_mask[first_x : last_x + 1] + width_vox = cropped_mask.shape[0] + + assert width_vox % 2 == 1, f"CC mask must have odd number of voxels in x dimension, but has {width_vox}" + + # Calculate voxel volume + voxel_volume: float = np.prod(voxel_size, dtype=float) + voxel_width: float = voxel_size[0] + + # we are in LIA, so 0 is L/R resolution + width_mm = width_vox * voxel_width + + if width_mm == desired_width_mm: + return np.sum(cropped_mask) * voxel_volume + elif width_mm > desired_width_mm: + # remainder on the left/right side of the CC mask + desired_width_vox = desired_width_mm / voxel_width + + # The number of full voxels in the center is (width_vox - 2) + # The remaining width must be covered by the two edge voxels. + fraction_of_voxel_at_edge = (desired_width_vox - (width_vox - 2)) / 2 + + left_partial_volume = np.sum(cropped_mask[0]) * voxel_volume * fraction_of_voxel_at_edge + right_partial_volume = np.sum(cropped_mask[-1]) * voxel_volume * fraction_of_voxel_at_edge + center_volume = np.sum(cropped_mask[1:-1]) * voxel_volume + return left_partial_volume + right_partial_volume + center_volume + else: + raise ValueError(f"Width of CC segmentation is smaller than desired width: {width_mm} < {desired_width_mm}") + + + +def extract_largest_connected_component( + seg_arr: Mask3d, + max_connection_distance: float = 3.0, +) -> Mask3d: + """Get the largest connected component from a binary segmentation array. + + Parameters + ---------- + seg_arr : np.ndarray + Input binary segmentation array. + max_connection_distance : float, optional + Maximum distance to connect components, by default 3.0. + + Returns + ------- + np.ndarray + Binary mask of the largest connected component. + + Notes + ----- + The function first attempts to connect nearby disconnected components + that should be connected, then finds the largest connected component. + It uses minimal connections between close components before falling + back to dilation if no connections are made. + """ + # First attempt: try to connect nearby components with minimal connections + connected_seg = connect_nearby_components(seg_arr, max_connection_distance) + + # Check if connections were successful by comparing connectivity + original_labels = label(seg_arr, connectivity=3, background=0) + connected_labels = label(connected_seg, connectivity=3, background=0) + + original_components = len(np.unique(original_labels)) - 1 # Exclude background + connected_components = len(np.unique(connected_labels)) - 1 # Exclude background + + if connected_components < original_components: + logger.info(f"Successfully reduced components from {original_components} to {connected_components} " + "using minimal connections") + mask = connected_seg + + # Get connected components from the processed mask + labels_cc = label(mask, connectivity=3, background=0) + + # Get component counts + bincount = np.bincount(labels_cc.flat) + + # Get background label (assumed to be the largest component) + background = np.argmax(bincount) + bincount[background] = -1 + + # Get largest connected component + largest_cc = np.equal(labels_cc, np.argmax(bincount)) + + return largest_cc + + +def clean_cc_segmentation( + seg_arr: np.ndarray[Shape3d, np.dtype[np.int_]], + max_connection_distance: float = 3.0, +) -> tuple[np.ndarray[Shape3d, np.dtype[np.int_]], Mask3d]: + """Clean corpus callosum segmentation by removing non-connected components. + + Parameters + ---------- + seg_arr : npt.NDArray[int] + Input segmentation array with CC (192) and fornix (250) labels. + max_connection_distance : float, default=3.0 + Maximum distance to connect components. + + Returns + ------- + clean_seg : np.NDArray[int] + Cleaned segmentation array with only the largest connected component of CC and fornix. + mask : npt.NDArray[bool] + Binary mask of the largest connected component. + + """ + from functools import partial + + extract_largest = partial(extract_largest_connected_component, max_connection_distance=max_connection_distance) + + # Remove non-connected components from the CC alone, with minimal connections + mask = np.equal(seg_arr, CC_LABEL) + cc_seg = mask.astype(int) * CC_LABEL + cc_label_cleaned = np.concatenate([extract_largest(seg[None]) * CC_LABEL for seg in cc_seg], axis=0) + + # Add fornix to the CC labels + clean_seg = np.where(mask, cc_label_cleaned, seg_arr) + + return clean_seg, np.greater(cc_label_cleaned, 0) diff --git a/CorpusCallosum/shape/__init__.py b/CorpusCallosum/shape/__init__.py new file mode 100644 index 000000000..4950a2427 --- /dev/null +++ b/CorpusCallosum/shape/__init__.py @@ -0,0 +1,15 @@ + +from CorpusCallosum.shape import endpoint_heuristic, mesh, metrics, postprocessing, subsegment_contour, thickness +from CorpusCallosum.shape.contour import CCContour +from CorpusCallosum.shape.mesh import CCMesh + +__all__ = [ + "CCContour", + "CCMesh", + "endpoint_heuristic", + "mesh", + "metrics", + "postprocessing", + "subsegment_contour", + "thickness", +] diff --git a/CorpusCallosum/shape/contour.py b/CorpusCallosum/shape/contour.py new file mode 100644 index 000000000..7044ef1e5 --- /dev/null +++ b/CorpusCallosum/shape/contour.py @@ -0,0 +1,885 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module provides the ``CCContour`` class for reading, writing, and +manipulating 2D corpus callosum contours together with per-vertex thickness +values. Typical template outputs (from ``fastsurfer_cc.py --save_template``) +emit one set per slice: + +- ``contour_.txt``: CSV with header ``New contour, anterior_endpoint_idx=, posterior_endpoint_idx=

`` followed + by ``x,y`` rows. +- ``thickness_values_.txt``: CSV with header ``thickness`` and one value per contour vertex. +- ``thickness_measurement_points_.txt``: CSV with header ``vertex_idx`` listing the vertices where thickness was + measured. +""" + +import re +from pathlib import Path +from typing import TypeVar + +import lapy +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import scipy.interpolate +from scipy.ndimage import gaussian_filter1d + +import FastSurferCNN.utils.logging as logging +from CorpusCallosum.shape.endpoint_heuristic import find_cc_endpoints, smooth_contour +from CorpusCallosum.shape.thickness import cc_thickness +from CorpusCallosum.utils.types import Points2dType +from FastSurferCNN.utils import AffineMatrix4x4, Mask2d, Vector2d + +logger = logging.get_logger(__name__) + +Self = TypeVar("Self", bound="CCContour") + + +# FIXME: Maybe CCContur should inherit from Polygon at a later date? +class CCContour: + """A class for representing and manipulating corpus callosum (CC) contours. + + This class provides functionality for manipulating and analyzing corpus callosum contours. + + Attributes + ---------- + points : np.ndarray + Array of shape (N, 2) containing 2D contour points. + thickness_values : np.ndarray + Array of shape (N,) for thickness measurements for each contour point. + endpoint_idxs : tuple[int, int] + Tuple containing start and end indices for the contour. + + Examples + -------- + >>> from CorpusCallosum.shape.contour import CCContour + >>> + >>> contour = CCContour(contour_points, thickness_values, + >>> endpoint_idxs=(anterior_idx, posterior_idx), + >>> z_position=0.0) + >>> contour.fill_thickness_values() # interpolate missing values + >>> contour.smooth_contour(window_size=5) + >>> contour.save_contour("contour_0.txt") + >>> contour.save_thickness_values("thickness_values_0.txt") + """ + + def __init__( + self, + points: Points2dType, + thickness_values: np.ndarray[tuple[int], np.dtype[np.float_]] | None, + endpoint_idxs: tuple[int, int] | None = None, + z_position: float = 0.0 + ): + """Initialize a CCContour object. + + Parameters + ---------- + points : np.ndarray + Array of shape (N, 2) containing 2D contour points. + thickness_values : np.ndarray, optional + Array of thickness measurements for each contour point. + endpoint_idxs : tuple[int, int], optional + Tuple containing start and end indices for the contour. + z_position : float, default=0.0 + The distance of the slice from midslice. + """ + self.points = points + if self.points.ndim != 2 or self.points.shape[1] != 2: + raise ValueError(f"Contour must be a (N, 2) array, but is {self.points.shape}") + self.thickness_values = thickness_values + if thickness_values is not None: + if self.points.shape[0] != len(thickness_values): + raise ValueError( + f"Number of contour points ({self.points.shape[0]}) does not match number of thickness values " + f"({len(thickness_values)})", + ) + # write vertex indices where thickness values are not nan + self.original_thickness_vertices = np.where(~np.isnan(thickness_values))[0] + else: + self.original_thickness_vertices = None + self.z_position = z_position + + if endpoint_idxs is None: + self.endpoint_idxs = (0, len(points) // 2) + else: + self.endpoint_idxs = endpoint_idxs + + def __len__(self) -> int: + """Return the number of points on the contour.""" + return len(self.points) + + @property + def area(self) -> float: + """Calculate the area of the contour using the shoelace formula. + + Returns + ------- + float + The area of the contour. + """ + if len(self.points) < 3: + return 0.0 + x = self.points[:, 0] + y = self.points[:, 1] + return 0.5 * np.abs(np.sum(x * np.roll(y, -1) - np.roll(x, -1) * y)) + + def smooth_contour(self, window_size: int = 5) -> None: + """Smooth a contour using a moving average filter. + + Parameters + ---------- + window_size : int, default=5 + Size of the smoothing window. + + Notes + ----- + Uses smooth_contour from cc_endpoint_heuristic module. + """ + self.points = np.array(smooth_contour(*self.points.T, window_size=window_size)).T + + def copy(self) -> "CCContour": + """Copy the contour. + """ + return CCContour( + self.points.copy(), + self.thickness_values.copy() if self.thickness_values is not None else None, + self.endpoint_idxs, + self.z_position + ) + + def get_contour_edge_lengths(self) -> np.ndarray: + """Get the lengths of the edges of a contour. + + Returns + ------- + np.ndarray + Array of edge lengths for the contour. + + Notes + ----- + Edge lengths are calculated as Euclidean distances between consecutive points + in the contour, including the edge closing the loop between the last and + first point. + """ + edges = np.roll(self.points, -1, axis=0) - self.points + return np.sqrt(np.sum(edges**2, axis=1)) + + def create_levelpaths( + self, + num_points: int, + inplace: bool = False + ) -> tuple[list[np.ndarray], float, float, np.ndarray, np.ndarray, tuple[int, int], float]: + """Calculate thickness and level paths for the CC contour using Laplace equation. + + Parameters + ---------- + num_points : int + Number of points for thickness estimation. + inplace : bool, default=True + Whether to update the contour points and thickness values in place. + + Returns + ------- + levelpaths : list[np.ndarray] + List of level paths across the CC. + thickness : float + Mean thickness of the CC. + midline_len : float + Length of the CC midline. + midline_equi : np.ndarray + Equidistant points along the midline. + contour_with_thickness : np.ndarray + Contour points with thickness information, shape (N, 3). + endpoint_idxs : tuple[int, int] + Indices of the anterior and posterior endpoints on the updated contour. + curvature : float + Mean curvature of the midline. + """ + + # FIXME: cache all these values in CCContour, and invalidate the cache, when either points or endpoint_idxs get + # changed; alternatively, make points and endpoint_idxs read_only (by creating getter-only properties) + # and have all functions that change points or endpoints return a new CCContour object instead. + + midline_len, thickness, curvature, midline_equi, levelpaths, contour_with_thickness, endpoint_idxs = \ + cc_thickness( + self.points, + self.endpoint_idxs, + n_points=num_points, + ) + + if inplace: + self.points = contour_with_thickness[:, :2] + self.thickness_values = contour_with_thickness[:, 2] + self.original_thickness_vertices = np.where(~np.isnan(self.thickness_values))[0] + self.endpoint_idxs = endpoint_idxs + + return levelpaths, thickness, midline_len, midline_equi, contour_with_thickness, endpoint_idxs, curvature + + def set_thickness_values(self, thickness_values: np.ndarray, use_measurement_points: bool = False) -> None: + """Set the thickness values for the contour. + This is useful to update the thickness values for specific plots. + + Parameters + ---------- + thickness_values : np.ndarray + Array of thickness values for the contour. + use_measurement_points : bool, default=False + Whether to use the measurement points to set the thickness values. + """ + if use_measurement_points: + if self.original_thickness_vertices is None: + if len(thickness_values) != len(self.points): + raise ValueError( + f"Thickness values not initialized and number of points in the contour {len(self.points)} does " + f"not match number of thickness values {len(thickness_values)}.", + ) + self.original_thickness_vertices = np.where(~np.isnan(thickness_values))[0] + self.thickness_values = thickness_values + elif len(thickness_values) == len(self.original_thickness_vertices): + self.thickness_values = np.full(len(self.points), np.nan) + self.thickness_values[self.original_thickness_vertices] = thickness_values + else: + raise ValueError( + "Number of thickness values does not match number of measurement points " + f"{len(self.original_thickness_vertices)}.", + ) + else: + if len(thickness_values) != len(self.points): + raise ValueError( + f"The number of thickness values does not match number of points in the contour " + f"{len(self.points)}.", + ) + self.thickness_values = thickness_values + + def fill_thickness_values(self) -> None: + """Interpolate missing thickness values using weighted averaging. + + Notes + ----- + The function: + 1. Processes each contour with missing thickness values. + 2. For each missing value: + - Finds two closest points with known thickness. + - Calculates distances along contour. + - Computes weighted average based on inverse distance. + 3. Updates thickness values in place. + + The weights are calculated as inverse distances to ensure closer + points have more influence on the interpolated value. + + """ + thickness = self.thickness_values + edge_lengths = self.get_contour_edge_lengths() + + # Find indices of points with known thickness + known_idx = np.where(~np.isnan(thickness))[0] + + if len(known_idx) == 0: + logger.warning("No known thickness values; skipping interpolation") + return + if len(known_idx) == 1: + logger.warning("Only one known thickness value; skipping interpolation") + thickness[np.isnan(thickness)] = thickness[known_idx[0]] + self.thickness_values = thickness + return + + # For each point with unknown thickness + total_length = np.sum(edge_lengths) + cumulative_lengths = np.concatenate(([0], np.cumsum(edge_lengths))) + + for j in range(len(thickness)): + if not np.isnan(thickness[j]): + continue + + # Find two closest points with known thickness + distances = np.zeros(len(known_idx)) + for k, idx in enumerate(known_idx): + # Calculate distance along contour by summing edge lengths + # in both directions and taking the minimum + if idx > j: + dist_forward = cumulative_lengths[idx] - cumulative_lengths[j] + else: + dist_forward = cumulative_lengths[j] - cumulative_lengths[idx] + + distances[k] = min(dist_forward, total_length - dist_forward) + + # Get indices of two closest points + closest_indices = known_idx[np.argsort(distances)[:2]] + closest_distances = np.sort(distances)[:2] + + # Calculate weights based on inverse distance + weights = 1.0 / closest_distances + weights = weights / np.sum(weights) + + # Calculate weighted average thickness + thickness[j] = np.sum(weights * thickness[closest_indices]) + + self.thickness_values = thickness + + def smooth_thickness_values(self, iterations: int = 1) -> None: + """Smooth the thickness values using a Gaussian filter. + + Parameters + ---------- + iterations : int, optional + Number of smoothing iterations, by default 1. + + Notes + ----- + Applies Gaussian smoothing with sigma=5 to thickness values + along the contour. + """ + if self.thickness_values is not None: + # Handle NaN values by interpolating if necessary or just smoothing the non-NaN parts + # Here we assume they might have been filled already by fill_thickness_values + for _ in range(iterations): + self.thickness_values = gaussian_filter1d(self.thickness_values, sigma=5, mode="wrap") + + def plot_contour(self, output_path: str | None = None) -> None: + """Plot a single contour with thickness values. + + Parameters + ---------- + output_path : str + Path where to save the plot. + + Notes + ----- + Creates a 2D visualization with: + - Points colored by thickness values. + - Gray points for missing thickness values. + - Connected contour line. + - Grid, labels, and legend. + """ + if output_path is not None: + self.__make_parent_folder(output_path) + + plt.figure(figsize=(10, 10)) + + # Plot points with colors based on thickness + gray_points = np.isnan(self.thickness_values) + if np.any(gray_points): + plt.scatter(self.points[gray_points, 0], self.points[gray_points, 1], color="gray", s=1) + + if not np.all(gray_points): + not_gray = np.logical_not(gray_points) + # Map thickness to color from red to yellow + norm_thickness = self.thickness_values[not_gray] / np.nanmax(self.thickness_values[not_gray]) + color_values = plt.cm.YlOrRd(norm_thickness) + plt.scatter(self.points[not_gray, 0], self.points[not_gray, 1], c=color_values, s=1) + + # Connect points with lines + plt.plot(self.points[:, 0], self.points[:, 1], "-", color="black", alpha=0.3, label="Contour") + plt.axis("equal") + plt.xlabel("X") + plt.ylabel("Y") + plt.title("CC contour") + plt.legend() + plt.grid(True) + plt.tight_layout() + if output_path is not None: + plt.savefig(output_path, dpi=300) + else: + plt.show() + + def plot_contour_colorfill( + self, + plot_values: np.ndarray, + title: str | None = None, + save_path: str | None = None, + colorbar: bool = True, + mode: str = "p-value", + ) -> matplotlib.figure.Figure: + """Plot a contour with levelset visualization. + + Creates a visualization of a contour with interpolated levelsets, useful for + analyzing the thickness distribution across the corpus callosum. + + Parameters + ---------- + plot_values : np.ndarray + Array of values to plot on CC from anterior to posterior (left to right in the plot). + title : str, optional + Title for the plot. + save_path : str, optional + Path to save the plot. If None, displays interactively. + colorbar : bool, default=True + Whether to show the colorbar. + mode : {"p-value", "icc", "thickness"}, default="p-value" + Mode of the plot. + + Returns + ------- + matplotlib.figure.Figure + The created figure object. + """ + plot_values = plot_values[::-1] # make sure values are plotted left to right (anterior to posterior) + + levelpaths, *_ = self.create_levelpaths(num_points=len(plot_values)-1, inplace=False) + + outside_contour = self.points.T + + # Create a grid of points covering the contour area with higher resolution + x_min, x_max = np.min(outside_contour[0]), np.max(outside_contour[0]) + y_min, y_max = np.min(outside_contour[1]), np.max(outside_contour[1]) + margin = 1 + resolution = 0.05 # Higher resolution for smoother interpolation + x_grid, y_grid = np.meshgrid( + np.arange(x_min - margin, x_max + margin, resolution), np.arange(y_min - margin, y_max + margin, resolution) + ) + + # Create a path from the outside contour + contour_path = matplotlib.path.Path(np.column_stack([outside_contour[0], outside_contour[1]])) + + # Check which points are inside the contour + points = np.column_stack([x_grid.flatten(), y_grid.flatten()]) + mask = contour_path.contains_points(points).reshape(x_grid.shape) + + # Collect all levelpath points and their corresponding values + # Extend each levelpath at both ends to improve extrapolation + all_level_points_x = [] + all_level_points_y = [] + all_level_values = [] + + for i, path in enumerate(levelpaths): + + # add third dimension to path + path = np.column_stack([path, np.zeros(len(path))]) + + if len(path) == 1: + all_level_points_x.append(path[0][0]) + all_level_points_y.append(path[0][1]) + all_level_values.append(plot_values[i]) + continue + + # make levelpath + path = lapy.Polygon(path).resample(1000).points + + # Extend at the beginning: add point in direction opposite to first segment + first_segment = path[1] - path[0] + # standardize length of first segment + first_segment = first_segment / np.linalg.norm(first_segment) * 10 + extension_start = path[0] - first_segment + all_level_points_x.append(extension_start[0]) + all_level_points_y.append(extension_start[1]) + all_level_values.append(plot_values[i]) + + # Add original path points + for point in path: + all_level_points_x.append(point[0]) + all_level_points_y.append(point[1]) + all_level_values.append(plot_values[i]) + + # Extend at the end: add point in direction of last segment + last_segment = path[-1] - path[-2] + # standardize length of last segment + last_segment = last_segment / np.linalg.norm(last_segment) * 10 + extension_end = path[-1] + last_segment + all_level_points_x.append(extension_end[0]) + all_level_points_y.append(extension_end[1]) + all_level_values.append(plot_values[i]) + + # Convert to numpy arrays + all_level_points_x = np.array(all_level_points_x) + all_level_points_y = np.array(all_level_points_y) + all_level_values = np.array(all_level_values) + + # Use griddata to perform smooth interpolation - using 'linear' instead of 'cubic' + # and properly formatting the input points + grid_values = scipy.interpolate.griddata( + (all_level_points_x, all_level_points_y), all_level_values, (x_grid, y_grid), method="linear", fill_value=0, + ) + + # smooth the grid_values + grid_values = scipy.ndimage.gaussian_filter(grid_values, sigma=5, radius=5) + + # Apply the mask to only show values inside the contour + masked_values = np.where(mask, grid_values, np.nan) + + if mode == "p-value": + # Sample colormaps + colors1 = plt.cm.binary([0.4] * 128) + colors2 = plt.cm.hot(np.linspace(0.8, 0.1, 128)) + elif mode == "icc": + colors1 = plt.cm.Blues(np.linspace(0, 1, 128)) + colors2 = plt.cm.binary([0.4] * 128) + elif mode == "thickness": + # Blue to red colormap for thickness values + cmap = plt.cm.coolwarm + else: + raise ValueError(f"Invalid mode '{mode}'") + + # Combine the color samples for p-value and icc modes + if mode != "thickness": + colors = np.vstack((colors2, colors1)) + # Create a new colormap + cmap = matplotlib.colors.LinearSegmentedColormap.from_list("my_colormap", colors) + + # Plot CC contour with levelsets + fig = plt.figure(figsize=(10, 3)) + # Apply a 10-degree rotation to the entire plot + base = plt.gca().transData + transform = matplotlib.transforms.Affine2D().rotate_deg(10) + transform = transform + base + + # Plot the filled contour with interpolated colors + plt.imshow( + masked_values, + extent=(x_min - margin, x_max + margin, y_min - margin, y_max + margin), + origin="lower", + cmap=cmap, + alpha=1, + interpolation="bilinear", + vmin=0 if mode != "thickness" else np.nanmin(plot_values), + vmax=0.10 if mode == "p-value" else (1 if mode == "icc" else np.nanmax(plot_values)), + transform=transform, + ) + + plt.imshow( + masked_values, + extent=(x_min - margin, x_max + margin, y_min - margin, y_max + margin), + origin="lower", + cmap=cmap, + alpha=1, + interpolation="bilinear", + vmin=0 if mode != "thickness" else np.nanmin(plot_values), + vmax=0.10 if mode == "p-value" else (1 if mode == "icc" else np.nanmax(plot_values)), + # norm=LogNorm(vmin=1e-3, vmax=0.1), # Set minimum to avoid log(0) + transform=transform, + ) + + if colorbar: + # Add a colorbar + cbar = plt.colorbar(aspect=15) + if mode == "p-value": + cbar.ax.set_ylim(0.001, 0.054) + cbar.ax.set_yticks([0.0, 0.01, 0.02, 0.03, 0.04, 0.05]) + cbar.set_label("p-value (log scale)") + elif mode == "icc": + cbar.ax.set_ylim(0, 1) + cbar.ax.set_yticks([0, 0.25, 0.5, 0.75, 1]) + cbar.ax.set_label("Intraclass correlation coefficient") + elif mode == "thickness": + # Set limits based on actual thickness values + thickness_min = np.nanmin(plot_values) + thickness_max = np.nanmax(plot_values) + cbar.ax.set_ylim(thickness_min, thickness_max) + cbar.set_label("Thickness (mm)") + + # Plot the outside contour on top for clear boundary + plt.plot(outside_contour[0], outside_contour[1], "k-", linewidth=2, label="CC Contour", transform=transform) + + plt.axis("equal") + plt.title(title, fontsize=14, fontweight="bold") + # plt.legend(loc='best') + plt.gca().invert_xaxis() + plt.axis("off") + if save_path is not None: + self.__make_parent_folder(save_path) + plt.savefig(save_path, dpi=300) + else: + plt.show() + return fig + + @staticmethod + def __make_parent_folder(filename: Path | str) -> None: + """Create the parent folder for a file if it doesn't exist. + + Parameters + ---------- + filename : Path, str + Path to the file whose parent folder should be created. + + Notes + ----- + Creates parent directory with parents=False to avoid creating multiple levels of directories unintentionally. + """ + Path(filename).parent.mkdir(parents=False, exist_ok=True) + + def save_contour(self, output_path: Path | str) -> None: + """Save the contours to a CSV file. + + Parameters + ---------- + output_path : Path, str + Path to save the CSV file. + + Notes + ----- + The function saves contours in CSV format with: + - Header: slice_idx,x,y. + - Special lines indicating new contours with endpoint indices. + - Each point gets its own row with slice index and coordinates. + """ + self.__make_parent_folder(output_path) + logger.info(f"Saving contours to CSV file: {output_path}") + with open(output_path, "w") as f: + + f.write( + f"New contour, anterior_endpoint_idx={self.endpoint_idxs[0]}, " + f"posterior_endpoint_idx={self.endpoint_idxs[1]}\n" + ) + f.write("x,y\n") + for point in self.points: + f.write(f"{point[0]},{point[1]}\n") + + @classmethod + def from_contour_file( + cls: type[Self], + input_path: str | Path, + thickness_values_path: str | Path, + z_position: float = 0.0, + ) -> Self: + """Load contour from a CSV file. + + Parameters + ---------- + input_path : str, Path + Path to the CSV file containing the contours. + thickness_values_path : str, Path + Path to the CSV file containing the thickness_values. + z_position : float, default=0.0 + The distance to the midslice (in fsaverage space). + + Raises + ------ + ValueError + If the file format doesn't match expected structure. + + Notes + ----- + The function: + 1. Reads CSV file with format matching save_contours output. + 2. Processes special lines for endpoint indices. + 3. Reconstructs contours and endpoint indices for each slice. + 4. Converts lists to fixed-size arrays with None padding. + """ + current_points = [] + endpoint_idxs = [] + + with open(input_path) as f: + header = next(f).strip() + # Parse endpoint indices from header + anterior_match = re.search(r'anterior_endpoint_idx=(\d+)', header) + posterior_match = re.search(r'posterior_endpoint_idx=(\d+)', header) + assert anterior_match and posterior_match, "Header does not contain endpoint indices" + + anterior_idx = int(anterior_match.group(1)) + posterior_idx = int(posterior_match.group(1)) + endpoint_idxs = (anterior_idx, posterior_idx) + + # Skip column names + next(f) + + for line in f: + x, y = line.strip().split(",") + current_points.append([float(x), float(y)]) + contour = np.array(current_points) + if thickness_values_path: + thickness_values = cls._load_thickness_values(contour, None, thickness_values_path) + else: + thickness_values = None + return CCContour(contour, thickness_values, endpoint_idxs, z_position=z_position) + + def save_thickness_values(self, output_path: Path | str) -> None: + """Save thickness values to a CSV file. + + Parameters + ---------- + output_path : Path, str + Path to save the CSV file. + + Notes + ----- + The function saves thickness values in CSV format with: + - Header: thickness. + - Each thickness value gets its own row with slice index. + - Skips slices with no thickness values. + """ + self.__make_parent_folder(output_path) + logger.info(f"Saving thickness data to CSV file: {output_path}") + with open(output_path, "w") as f: + f.write("thickness\n") + for value in self.thickness_values: + f.write(f"{value}\n") + + def load_thickness_values( + self, + input_path: str | Path, + ) -> None: + """Load thickness values from a CSV file. + + Parameters + ---------- + input_path : Path, str + Path to the CSV file containing thickness values. + + Raises + ------ + ValueError + If number of thickness values doesn't match measurement points + or if number of slices is inconsistent. + """ + self.thickness_values = self._load_thickness_values(self.points, self.original_thickness_vertices, input_path) + + @classmethod + def _load_thickness_values( + cls, + contour: Points2dType, + original_thickness_vertices: np.ndarray[tuple[int], np.dtype[np.signedinteger]] | None, + input_path: str | Path, + ) -> np.ndarray[tuple[int], np.dtype[np.float_]]: + """See load_thickness_values. + + Ignore shape of thickness values if original_thickness_vertices is None. + + Notes + ----- + The function: + 1. Reads thickness values from CSV file. + 2. Groups values by slice index. + 3. Optionally associates values with specific vertices. + 4. Handles both full contour and profile measurements. + """ + data = np.loadtxt(input_path, delimiter=",", skiprows=1) + if data.ndim == 0: + values = np.array([float(data)]) + elif data.ndim == 1: + values = data.astype(float) + else: + raise ValueError("Thickness values file must contain a single column") + + if len(values) == len(contour): + # Perfect match - use values directly + new_values = values + elif original_thickness_vertices is None: + # No original vertices specified, use values as-is (may differ in length) + new_values = values + elif np.sum(~np.isnan(values)) == len(original_thickness_vertices): + # Values match the number of measurement points, map them to the contour + new_values = np.full(len(contour), np.nan) + new_values[original_thickness_vertices] = values[~np.isnan(values)] + else: + raise ValueError( + f"Number of thickness values {len(values)} does not match number of points in the contour " + f"{len(contour)} and current number of measurement points {len(original_thickness_vertices)} does " + f"not match the number of set thickness values {np.sum(~np.isnan(values))}." + ) + + return new_values + + @classmethod + def from_mask_and_acpc( + cls: type[Self], + cc_mask: Mask2d, + ac_2d: Vector2d, + pc_2d: Vector2d, + slice_vox2ras: AffineMatrix4x4, + contour_smoothing: int = 5 + ) -> Self: + """Extracts the contour of the CC using marching squares, smooth and transform to RAS coordinates. + + Parameters + ---------- + cc_mask : np.ndarray of shape (H, W) and type bool + Binary mask of the corpus callosum. + ac_2d : np.ndarray of shape (2,) and type float + 2D voxel coordinates of the anterior commissure. + pc_2d : np.ndarray of shape (2,) and type float + 2D voxel coordinates of the posterior commissure. + slice_vox2ras : AffineMatrix4x4 + Transformation matrix from slice-voxel space to RAS-coordinates. + contour_smoothing : int, default=5 + Window size for contour smoothing. + + Returns + ------- + contour : CCContour + The contour object. + + Notes + ----- + Expects LIA orientation. + """ + import skimage.measure + from nibabel.affines import apply_affine + + _contour: Points2dType = skimage.measure.find_contours(cc_mask, level=0.5)[0] + + # remove last, duplicate point + _contour = _contour[:-1] + polygon = lapy.Polygon(np.concatenate([np.zeros_like(_contour[:, :1]), _contour], axis=1), closed=True) + polygon.smooth_laplace(n=contour_smoothing, inplace=True) + polygon.resample(700, inplace=True) + contour_ras = apply_affine(slice_vox2ras, polygon.points) + + ac_pc_3d = np.concatenate([np.zeros((2, 1), like=ac_2d), np.stack([ac_2d, pc_2d], axis=0)], axis=1) # (2, 3) + ac_ras, pc_ras = apply_affine(slice_vox2ras, ac_pc_3d) + endpoint_idx = find_cc_endpoints(contour_ras[:, 1:].T, ac_ras[1:], pc_ras[1:]) + + return cls(contour_ras[:, 1:], None, endpoint_idx, z_position=slice_vox2ras[0, 3]) + + +def calculate_volume(contours: list[CCContour], width: float = 5.0) -> float: + """Calculate the volume of the corpus callosum. + + This method calculates the volume of a slab of the CC centered on the midplane. + It multiplies the area of each cross-sectional slice by the width it + represents within the slab. It assumes equally spaced contours centered + around the midplane (z=0). + + Parameters + ---------- + width : float, default=5.0 + The width of the slab centered on the midplane to calculate the volume for (in mm). + + Returns + ------- + float + The volume of the CC in cubic millimeters. + """ + if len(contours) < 2: + return 0.0 + + # Group vertices by their LR coordinate (column 0 as created by from_contours) + z_coords = [contour.z_position for contour in contours] + areas = [contour.area for contour in contours] + + contour_widths = np.diff(z_coords) + + # check that all widths are the same + if not np.allclose(contour_widths, contour_widths[0]): + raise ValueError("Contours must be equally spaced to calculate CC volume") + + contour_width_mm = abs(contour_widths[0]) + + # Define the slab boundaries centered on the midplane (z=0) + z_min, z_max = -width / 2.0, width / 2.0 + + volume = 0.0 + for i, z in enumerate(z_coords): + # Each contour represents a slab of contour_width_mm + # centered at its z position. + start = z - contour_width_mm / 2.0 + end = z + contour_width_mm / 2.0 + + # Intersection of [start, end] and [z_min, z_max] + effective_start = max(start, z_min) + effective_end = min(end, z_max) + + effective_width = max(0.0, effective_end - effective_start) + volume += areas[i] * effective_width + + return volume diff --git a/CorpusCallosum/shape/curvature.py b/CorpusCallosum/shape/curvature.py new file mode 100644 index 000000000..214ad1436 --- /dev/null +++ b/CorpusCallosum/shape/curvature.py @@ -0,0 +1,129 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +from CorpusCallosum.utils.types import ContourList, Points2dType + + +def compute_curvature(path: Points2dType) -> np.ndarray[tuple[int], np.dtype[np.float_]]: + """Compute curvature by computing edge angles. + + Parameters + ---------- + path : np.ndarray + Array of shape (N, 2) containing path coordinates. + + Returns + ------- + np.ndarray + Array of angle differences between consecutive edges. + """ + # compute curvature by computing edge angles + edges = np.diff(path, axis=0) + angles = np.arctan2(edges[:, 1], edges[:, 0]) + # compute angle differences between consecutive edges + angle_diffs = np.diff(angles) + # wrap angles to [-pi, pi] + angle_diffs = np.mod(angle_diffs + np.pi, 2 * np.pi) - np.pi + return angle_diffs + + +def compute_mean_curvature(path: Points2dType) -> float: + """Compute mean absolute curvature of a path in degrees. + + Parameters + ---------- + path : np.ndarray + Array of shape (N, 2) containing path coordinates. + + Returns + ------- + float + Mean absolute curvature of the path in degrees. + """ + curvature = compute_curvature(path) + if len(curvature) == 0: + return 0.0 + return np.mean(np.abs(np.degrees(curvature))).item() + + +def calculate_curvature_metrics( + midline: Points2dType, + split_points: np.ndarray | None = None, + split_contours: ContourList | None = None, +) -> tuple[float, float, np.ndarray]: + """ + Calculate curvature metrics for the CC midline, including overall mean, + body (central 65%), and subsegment curvatures. + + Parameters + ---------- + midline : Points2dType + Equidistant points along the midline. + split_points : np.ndarray, optional + Points on the midline where it was split (for orthogonal subdivision). + split_contours : ContourList, optional + List of split contours (for other subdivision methods). + + Returns + ------- + mean_curvature : float + Overall mean curvature. + curvature_body : float + Mean curvature of the central 65% of the midline. + curvature_subsegments : np.ndarray + Mean curvature for each subsegment. + """ + mean_curvature = compute_mean_curvature(midline) + + num_midline_points = len(midline) + # central 65% means we remove 17.5% from each end + start_idx_body = int(num_midline_points * 0.175) + end_idx_body = int(num_midline_points * 0.825) + curvature_body = compute_mean_curvature(midline[start_idx_body:end_idx_body]) + + # Find split indices on the midline for subsegment curvature + split_indices_midline = [0] + if split_points is not None: + for sp in split_points: + idx = np.argmin(np.linalg.norm(midline - sp, axis=1)) + split_indices_midline.append(idx) + elif split_contours is not None: + from CorpusCallosum.shape.subsegment_contour import get_unique_contour_points + unique_points = get_unique_contour_points(split_contours) + for line_pts in unique_points[1:]: + if len(line_pts) == 2: + # find where this line crosses the midline + # use the average of the two points and find closest point on midline + mid_pt = np.mean(line_pts, axis=0) + idx = np.argmin(np.linalg.norm(midline - mid_pt, axis=1)) + split_indices_midline.append(idx) + + split_indices_midline.append(len(midline) - 1) + split_indices_midline.sort() + + _curvature_subsegments = [] + for i in range(len(split_indices_midline) - 1): + s_idx = split_indices_midline[i] + e_idx = split_indices_midline[i + 1] + if e_idx - s_idx >= 2: # need at least 3 points for curvature + curv = compute_mean_curvature(midline[s_idx : e_idx + 1]) + else: + curv = 0.0 + _curvature_subsegments.append(curv) + curvature_subsegments = np.asarray(_curvature_subsegments) + + return mean_curvature, curvature_body, curvature_subsegments + diff --git a/CorpusCallosum/shape/endpoint_heuristic.py b/CorpusCallosum/shape/endpoint_heuristic.py new file mode 100644 index 000000000..09e8e128e --- /dev/null +++ b/CorpusCallosum/shape/endpoint_heuristic.py @@ -0,0 +1,229 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import skimage.measure +from scipy.ndimage import label + +from CorpusCallosum.utils.types import Points2dType, Polygon2dType +from FastSurferCNN.utils import Image2d, Mask2d, Vector2d + + +def smooth_contour(x: np.ndarray, y: np.ndarray, window_size: int) -> tuple[np.ndarray, np.ndarray]: + """Smooth a contour using a moving average filter. + + Parameters + ---------- + x : np.ndarray + X-coordinates of the contour points. + y : np.ndarray + Y-coordinates of the contour points. + window_size : int + Size of the smoothing window. Must be odd and > 2. + + Returns + ------- + tuple[np.ndarray, np.ndarray] + Smoothed x and y coordinates of the contour. + """ + # Ensure window_size is an integer + window_size = int(window_size) + + if window_size // 2 == 0: + raise ValueError(f"Smoothing window size of {window_size} is too small") + + # Ensure the window size is odd + if window_size % 2 == 0: + window_size += 1 + + # Create a padded version of the arrays to handle the edges + x_padded = np.pad(x, (window_size // 2, window_size // 2), mode="wrap") + y_padded = np.pad(y, (window_size // 2, window_size // 2), mode="wrap") + + # Apply moving average + x_smoothed = np.zeros_like(x) + y_smoothed = np.zeros_like(y) + + for i in range(len(x)): + x_smoothed[i] = np.mean(x_padded[i : i + window_size]) + y_smoothed[i] = np.mean(y_padded[i : i + window_size]) + + return x_smoothed, y_smoothed + + +def connect_diagonally_connected_components(cc_mask: Image2d) -> Image2d: + """Connect diagonally connected components in the CC mask. + + Parameters + ---------- + cc_mask : np.ndarray + Binary mask of the corpus callosum. + + Notes + ----- + Modifies the input mask in-place to connect diagonally connected components. + """ + + # Create padded mask to handle boundary conditions + padded_mask = np.pad(cc_mask, pad_width=1, mode='constant', constant_values=0) + + # Get center pixels and diagonal neighbors + center = padded_mask[1:-1, 1:-1] + + # Direct neighbors (4-connectivity) + left = padded_mask[1:-1, :-2] # left + right = padded_mask[1:-1, 2:] # right + up = padded_mask[:-2, 1:-1] # up + down = padded_mask[2:, 1:-1] # down + + # Diagonal neighbors + up_left = padded_mask[:-2, :-2] # up-left + up_right = padded_mask[:-2, 2:] # up-right + down_left = padded_mask[2:, :-2] # down-left + down_right = padded_mask[2:, 2:] # down-right + + potential_diagonal_gaps = (center == 0) & ( + ((up_left > 0) & ((right > 0) | (down > 0))) | + ((up_right > 0) & ((left > 0) | (down > 0))) | + ((down_left > 0) & ((right > 0) | (up > 0))) | + ((down_right > 0) & ((left > 0) | (up > 0))) + ) + + # Get connected components before filling using 4-connectivity + # This way, diagonal-only connections are treated as separate components + structure_4conn = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]]) + _, num_components_before = label(cc_mask, structure=structure_4conn) + + # For each potential gap, check if filling it would reduce the number of components + connects_diagonals = np.zeros_like(potential_diagonal_gaps) + gap_positions = np.where(potential_diagonal_gaps) + + if len(gap_positions[0]) > 0: + test_mask = cc_mask.copy() + # Fill all gap voxels, that by themselves would connect 2 components + for i, j in zip(gap_positions[0], gap_positions[1], strict=True): + # Temporarily fill this gap + test_mask[i, j] = 1 + # Check connected components after filling, this is relatively slow... + _, num_components_after = label(test_mask, structure=structure_4conn) + # Only fill if it actually connects previously disconnected components + if num_components_after < num_components_before: + connects_diagonals[i, j] = True + # Revert temporary fill + test_mask[i, j] = cc_mask[i, j] + + # Fill the identified diagonal gaps that actually improve connectivity + return np.where(connects_diagonals, 1, cc_mask) + + +def extract_cc_contour(cc_mask: Mask2d, contour_smoothing: int = 5) -> Polygon2dType: + """Extract the contour of the CC from the mask using a marching squares approach. + + Parameters + ---------- + cc_mask : np.ndarray + Binary mask of the corpus callosum. + contour_smoothing : int, default=5 + Window size for contour smoothing. + + Returns + ------- + lapy.Polygon + A lapy Polygon object with a closed polygon contour. + """ + cc_mask = connect_diagonally_connected_components(cc_mask) + + contour = skimage.measure.find_contours(cc_mask, level=0.5)[0].T + contour = np.array(smooth_contour(contour[0], contour[1], contour_smoothing)) + + return contour + + +def find_cc_endpoints( + contour: Points2dType, + ac_2d: Vector2d, + pc_2d: Vector2d, + plot: bool = False, +): + """Extracts the contour of the CC, rotates to AC-PC alignment, and determines closest points of CC to AC and PC. + + Parameters + ---------- + contour : np.ndarray of shape (2, N) + Points of the CC contour in AS (millimeter). + ac_2d : np.ndarray of shape (2,) and type float + 2D AS coordinates of the anterior commissure in millimeter. + pc_2d : np.ndarray of shape (2,) and type float + 2D AS coordinates of the posterior commissure in millimeter. + + Returns + ------- + anterior_posterior_point_indices : pair of ints + Indices of anterior and posterior points in the contour. + + Notes + ----- + Expects AS orientation of contour, ac_2d, and pc_2d. + """ + if contour.shape[0] != 2: + raise ValueError(f"contour must have shape (2, N), got {contour.shape}") + if any(p2d.shape != (2,) for p2d in (ac_2d, pc_2d)): + raise ValueError(f"ac_2d and pc_2d must have shape (2,), got {ac_2d.shape} and {pc_2d.shape}") + + # Calculate angle between AC-PC line and horizontal using numpy + ac_pc_vector = pc_2d - ac_2d + horizontal_vector = np.array([-20, 0]) + # Calculate angle using dot product formula: cos(theta) = (a·b)/(|a||b|) + dot_product = np.dot(ac_pc_vector, horizontal_vector) + norms = np.linalg.norm(ac_pc_vector) * np.linalg.norm(horizontal_vector) + # The sign of theta is the inverse of ac_pc_vector [ X ] + theta = np.sign(ac_pc_vector[0]) * np.arccos(dot_product / norms) + + rot_matrix_inv = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) + # move posterior commisure 5 mm posterior, 10 mm inferior + as_offset_pc = np.array([-5, -10], dtype=float) + posterior_anchor_2d = pc_2d.astype(float) + rot_matrix_inv @ as_offset_pc + # move anterior commisure 5 mm anterior + as_offset_ac = np.array([5, 0], dtype=float) + anterior_anchor_2d = ac_2d.astype(float) + rot_matrix_inv @ as_offset_ac + + # Find the endpoints of the CC shape relative to AC and PC coordinates + # find point in contour closest to AC + ac_startpoint_idx = np.argmin(np.linalg.norm(contour - anterior_anchor_2d[:, None], axis=0)) + # find point in contour closest to PC + pc_startpoint_idx = np.argmin(np.linalg.norm(contour - posterior_anchor_2d[:, None], axis=0)) + + if plot: # interactive debug plot of contour, ac, pc and endpoints + import matplotlib + import matplotlib.pyplot as plt + curr_backend = matplotlib.get_backend() + plt.switch_backend("qtagg") + plt.figure(figsize=(10, 8)) + plt.plot(contour[0, :], contour[1, :], 'b-', label='CC Contour', linewidth=2) + plt.plot(ac_2d[0], ac_2d[1], 'go', markersize=10, label='AC') + plt.plot(pc_2d[0], pc_2d[1], 'ro', markersize=10, label='PC') + plt.plot(anterior_anchor_2d[0], anterior_anchor_2d[1], 'g^', markersize=10, label='Anterior Anchor') + plt.plot(posterior_anchor_2d[0], posterior_anchor_2d[1], 'r^', markersize=10, label='Posterior Anchor') + plt.plot(contour[0, ac_startpoint_idx], contour[1, ac_startpoint_idx], 'g*', markersize=15, label='AC Endpoint') + plt.plot(contour[0, pc_startpoint_idx], contour[1, pc_startpoint_idx], 'r*', markersize=15, label='PC Endpoint') + plt.xlabel('A-S (mm)') + plt.ylabel('I-S (mm)') + plt.title('CC Contour with Endpoints') + plt.legend() + plt.axis('equal') + plt.grid(True, alpha=0.3) + plt.show() + plt.switch_backend(curr_backend) + + return ac_startpoint_idx, pc_startpoint_idx diff --git a/CorpusCallosum/shape/mesh.py b/CorpusCallosum/shape/mesh.py new file mode 100644 index 000000000..f3e5f22de --- /dev/null +++ b/CorpusCallosum/shape/mesh.py @@ -0,0 +1,819 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +from pathlib import Path +from typing import TypeVar + +import lapy +import nibabel as nib +import numpy as np +import plotly.graph_objects as go +from lapy import TriaMesh +from plotly.io import write_html as plotly_write_html +from scipy.ndimage import gaussian_filter1d + +import FastSurferCNN.utils.logging as logging +from CorpusCallosum.shape.contour import CCContour +from CorpusCallosum.shape.thickness import make_mesh_from_contour +from FastSurferCNN.utils import AffineMatrix4x4, nibabelImage +from FastSurferCNN.utils.common import suppress_stdout, update_docstring + +try: + from pyrr import Matrix44 + HAS_PYRR = True +except ImportError: + HAS_PYRR = False + class Matrix44(np.ndarray): + pass + +logger = logging.get_logger(__name__) + + + +def _create_cap( + points: np.ndarray, + contour: CCContour, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Create a cap mesh for one end of the corpus callosum. + + Parameters + ---------- + points : np.ndarray + Array of shape (N, 2) containing mesh points + trias : np.ndarray + Array of shape (M, 3) containing triangle indices + contour : CCContour + CCContour object to create cap for + + Returns + ------- + tuple[np.ndarray, np.ndarray, np.ndarray] + - level_vertices : Array of vertices for the cap mesh + - level_faces : Array of face indices for the cap mesh + - level_colors : Array of thickness values for each vertex + + Notes + ----- + The function: + 1. Creates level paths using _create_levelpaths + 2. Resamples level paths to fixed number of points + 3. Creates triangles between consecutive level paths + 4. Smooths thickness values for visualization + """ + levelpaths, thickness_values, _, _, _, _, _ = contour.create_levelpaths(num_points=len(points), inplace=False) + + # Create mesh from level paths + level_vertices = [] + level_faces = [] + level_colors = [] + vertex_counter = 0 + sorted_thickness_values = np.array(thickness_values) + + # smooth thickness values + for _ in range(3): + sorted_thickness_values = gaussian_filter1d(sorted_thickness_values, sigma=5) + + NUM_LEVELPOINTS = 50 + + assert len(sorted_thickness_values) == len(levelpaths) + + # TODO: handle gap between first/last levelpath and contour + for idx, levelpath1 in enumerate(levelpaths): + levelpath1 = lapy.TriaMesh._TriaMesh__iterative_resample_polygon(levelpath1, NUM_LEVELPOINTS) + level_vertices.append(levelpath1) + level_colors.append(np.full((len(levelpath1)), sorted_thickness_values[idx])) + if idx + 1 < len(levelpaths): + levelpath2 = lapy.TriaMesh._TriaMesh__iterative_resample_polygon(levelpaths[idx + 1], NUM_LEVELPOINTS) + + # Create faces between the two paths by connecting vertices + faces_between = [] + i, j = 0, 0 + + while i < len(levelpath1) - 1 and j < len(levelpath2) - 1: + faces_between.append([i, i + 1, len(levelpath1) + j]) + faces_between.append([i + 1, len(levelpath1) + j + 1, len(levelpath1) + j]) + + i += 1 + j += 1 + + while i < len(levelpath1) - 1: + faces_between.append([i, i + 1, len(levelpath1) + j]) + i += 1 + + while j < len(levelpath2) - 1: + faces_between.append([i, len(levelpath1) + j + 1, len(levelpath1) + j]) + j += 1 + + if faces_between: + faces_between = np.array(faces_between) + level_faces.append(faces_between + vertex_counter) + + vertex_counter += len(levelpath1) + + # Convert to numpy arrays + level_vertices = np.vstack(level_vertices) + # Add z-coordinate (column 0) to make vertices 3D + level_vertices = np.hstack([np.full((len(level_vertices), 1), contour.z_position), level_vertices]) + level_faces = np.vstack(level_faces) + level_colors = np.concatenate(level_colors) + + return level_vertices, level_faces, level_colors + + +def make_triangles_between_contours(contour1: np.ndarray, contour2: np.ndarray) -> np.ndarray: + """Create a triangular mesh between two contours using a robust method. + + Parameters + ---------- + contour1 : np.ndarray + First contour points of shape (N, 2). + contour2 : np.ndarray + Second contour points of shape (M, 2). + + Returns + ------- + np.ndarray + Array of triangle indices of shape (K, 3) where K is the number of triangles. + + Notes + ----- + The function: + 1. Finds closest point on contour2 to first point of contour1 + 2. Creates triangles by connecting corresponding points + 3. Handles contours with different numbers of points + 4. Creates two triangles to form a quad between each pair of points + """ + start_idx_c1 = 0 + # get closest point on contour2 to contour1[0] + start_idx_c2 = np.argmin(np.linalg.norm(contour2 - contour1[0], axis=1)) + + triangles = [] + n1 = len(contour1) + n2 = len(contour2) + + for i in range(n1): + # Current and next indices for contour1 + c1_curr = (start_idx_c1 + i) % n1 + c1_next = (start_idx_c1 + i + 1) % n1 + + # Current and next indices for contour2, offset by n1 to account for vertex stacking + c2_curr = ((start_idx_c2 + i) % n2) + n1 + c2_next = ((start_idx_c2 + i + 1) % n2) + n1 + + # Create two triangles to form a quad between the contours + triangles.append([c1_curr, c2_curr, c1_next]) + triangles.append([c2_curr, c2_next, c1_next]) + + return np.array(triangles) + + +Self = TypeVar('Self', bound='type[CCMesh]') + + +class CCMesh(lapy.TriaMesh): + """A class for representing and manipulating corpus callosum (CC) meshes. + + This class extends lapy.TriaMesh to provide specialized functionality for working with + corpus callosum meshes, including contour management, thickness measurements, and + visualization capabilities. + + The mesh can be constructed from a series of 2D contours representing slices of the + corpus callosum, with optional thickness measurements at various points along these + contours. + + Attributes + ---------- + v : np.ndarray + Vertex coordinates of the mesh. + t : np.ndarray + Triangle indices of the mesh. + mesh_vertex_colors : np.ndarray + Vertex values for each vertex (CC thickness values) + """ + + def __init__( + self, + vertices: list | np.ndarray, + faces: list | np.ndarray, + vertex_values: list | np.ndarray | None = None, + ): + """Initialize a CC_Mesh object. + + Parameters + ---------- + vertices : list or numpy.ndarray + List of vertex coordinates or array of shape (N, 3). + faces : list or numpy.ndarray + List of face indices or array of shape (M, 3). + vertex_values : list or numpy.ndarray, optional + Vertex values for each vertex (CC thickness values) + """ + super().__init__(np.vstack(vertices), np.vstack(faces)) + self.mesh_vertex_colors = vertex_values + + + def plot_mesh( + self, + output_path: Path | str | None = None, + colormap: str = "red_to_yellow", + thickness_overlay: bool = True, + show_grid: bool = False, + color_range: tuple[float, float] | None = None, + legend: str = "", + threshold: tuple[float, float] | None = None, + ): + """Plot the mesh using Plotly for better performance and interactivity. + + Creates an interactive 3D visualization of the mesh with optional features like + thickness overlay, contour display, and grid visualization. + + Parameters + ---------- + output_path : Path, str, optional + Path to save the plot. If None, displays the plot interactively. + colormap : str, optional + Which colormap to use, by default "red_to_yellow". + Options: + - "red_to_blue": Red -> Orange -> Grey -> Light Blue -> Blue + - "red_to_yellow": Red -> Yellow -> Light Blue -> Blue + - "yellow_to_red": Yellow -> Light Blue -> Blue -> Red + - "blue_to_red": Blue -> Light Blue -> Grey -> Orange -> Red + thickness_overlay : bool, optional + Whether to overlay thickness values on the mesh, by default True. + show_contours : bool, optional + Whether to show the contours, by default False. + show_grid : bool, optional + Whether to show the grid, by default False. + color_range : tuple[float, float], optional + Fixed range (min, max) for the colorbar, by default None. + show_mesh_edges : bool, optional + Whether to show the mesh edges, by default False. + legend : str, optional + Legend text for the colorbar, by default "". + threshold : tuple[float, float], optional + Values between these thresholds will be shown in grey, by default None. + + Notes + ----- + The plot can be saved to an HTML file or displayed in a web browser. + """ + assert self.v is not None and self.t is not None, "Mesh has not been created yet" + + if len(self.v) == 0: + logger.warning("Warning: No vertices in mesh to plot") + return + + if len(self.t) == 0: + logger.warning("Warning: No faces in mesh to plot") + return + + # Define available colormaps + colormaps = { + "red_to_blue": [ + [0.0, "rgb(255,0,0)"], # Bright red + [0.25, "rgb(255,165,0)"], # Light orange + [0.5, "rgb(150,150,150)"], # Dark grey for middle + [0.75, "rgb(173,216,230)"], # Light blue + [1.0, "rgb(0,0,255)"], # Bright blue + ], + "blue_to_red": [ + [0.0, "rgb(0,0,255)"], # Bright blue + [0.25, "rgb(173,216,230)"], # Light blue + [0.5, "rgb(150,150,150)"], # Dark grey for middle + [0.75, "rgb(255,165,0)"], # Light orange + [1.0, "rgb(255,0,0)"], # Bright red + ], + "red_to_yellow": [ + [0.0, "rgb(255,0,0)"], # Bright red + [0.33, "rgb(255,85,0)"], # Red-orange + [0.66, "rgb(255,170,0)"], # Orange + [1.0, "rgb(255,255,0)"], # Yellow + ], + "yellow_to_red": [ + [0.0, "rgb(255,255,0)"], # Yellow + [0.33, "rgb(255,170,0)"], # Orange + [0.66, "rgb(255,85,0)"], # Red-orange + [1.0, "rgb(255,0,0)"], # Bright red + ], + } + + # Select the colormap + if colormap not in colormaps: + logger.warning(f"Warning: Unknown colormap '{colormap}'. Using 'red_to_blue' instead.") + colormap = "red_to_blue" + + selected_colormap = colormaps[colormap] + + # If threshold is provided, modify the colormap to include grey region + if threshold is not None and thickness_overlay and hasattr(self, "mesh_vertex_colors"): + data_min = np.min(self.mesh_vertex_colors) if color_range is None else color_range[0] + data_max = np.max(self.mesh_vertex_colors) if color_range is None else color_range[1] + data_range = data_max - data_min + + # Calculate normalized threshold positions + thresh_low = (threshold[0] - data_min) / data_range + thresh_high = (threshold[1] - data_min) / data_range + + # Ensure thresholds are within [0,1] + thresh_low = max(0, min(1, thresh_low)) + thresh_high = max(0, min(1, thresh_high)) + + # Create new colormap with grey threshold region + grey_color = "rgb(150,150,150)" # Medium grey + new_colormap = [] + + # Add colors before threshold with adjusted positions + if thresh_low > 0: + for pos, color in selected_colormap: + if pos < 1: # Only use positions less than 1 + new_pos = pos * thresh_low + new_colormap.append([new_pos, color]) + + # Add threshold boundaries with grey + new_colormap.extend([[thresh_low, grey_color], [thresh_high, grey_color]]) + + # Add colors after threshold with adjusted positions + if thresh_high < 1: + remaining_range = 1 - thresh_high + for pos, color in selected_colormap: + if pos > 0: # Only use positions greater than 0 + new_pos = thresh_high + pos * remaining_range + if new_pos <= 1: # Ensure we don't exceed 1 + new_colormap.append([new_pos, color]) + + selected_colormap = new_colormap + + # Calculate data ranges and center + xyz_min = self.v.min(axis=0) + xyz_max = self.v.max(axis=0) + xyz_range = xyz_max - xyz_min + max_range = xyz_range.max() + center = (xyz_max + xyz_min) / 2 + + # Create mesh plot + fig = go.Figure() + + # Add the mesh as a surface + mesh_args = { + "x": self.v[:, 0], + "y": self.v[:, 1], + "z": self.v[:, 2], + "i": self.t[:, 0], # First vertex of each triangle + "j": self.t[:, 1], # Second vertex + "k": self.t[:, 2], # Third vertex + "hoverinfo": "skip", + "lighting": dict(ambient=0.9, diffuse=0.1, roughness=0.3), + } + + if thickness_overlay and hasattr(self, "mesh_vertex_colors"): + mesh_args.update( + { + "intensity": self.mesh_vertex_colors, # Add intensity values for colorbar + "showscale": True, + "colorbar": dict( + title=dict( + text=legend, + font=dict(size=35, color="white"), # Increase title font size and make white + side="right", # Place title on right side + ), + len=0.55, # Make colorbar shorter + thickness=35, # Make colorbar wider + tickfont=dict(size=30, color="white"), # Increase tick font size and make white + tickformat=".1f", # Show one decimal place + ), + "opacity": 1, + "colorscale": selected_colormap, + } + ) + + # Set the colorbar range + if color_range is not None: + mesh_args["cmin"] = color_range[0] + mesh_args["cmax"] = color_range[1] + else: + # Use data range if no explicit range is provided + mesh_args["cmin"] = np.min(self.mesh_vertex_colors) + mesh_args["cmax"] = np.max(self.mesh_vertex_colors) + else: + mesh_args["color"] = "lightsteelblue" + + fig.add_trace(go.Mesh3d(**mesh_args)) + + # Calculate axis ranges to maintain equal aspect ratio + ranges = [] + for i in range(3): + axis_range = [center[i] - max_range / 2, center[i] + max_range / 2] + ranges.append(axis_range) + + # Configure axes and grid visibility + axis_config = dict( + showgrid=show_grid, + showline=show_grid, + zeroline=show_grid, + showbackground=show_grid, + showticklabels=show_grid, + gridcolor="white", + tickfont=dict(color="white"), + title=dict(font=dict(color="white")), + ) + + fig.update_layout( + scene=dict( + xaxis=dict(range=ranges[0], **{**axis_config, "title": "LR" if show_grid else ""}), + yaxis=dict(range=ranges[1], **{**axis_config, "title": "AP" if show_grid else ""}), + zaxis=dict(range=ranges[2], **{**axis_config, "title": "SI" if show_grid else ""}), + camera=dict(eye=dict(x=1.5, y=1.5, z=1), up=dict(x=0, y=0, z=1)), + aspectmode="cube", # Force equal aspect ratio + aspectratio=dict(x=1, y=1, z=1), + bgcolor="black", + dragmode="orbit", # Enable orbital rotation by default + ), + showlegend=False, + margin=dict(l=0, r=100, t=0, b=0), # Increased right margin for colorbar + paper_bgcolor="black", + plot_bgcolor="black", + ) + + if output_path is not None: + self.__make_parent_folder(output_path) + plotly_write_html(fig, output_path, include_plotlyjs="cdn") # Save as interactive HTML + else: + # For non-interactive display, save to a temporary HTML and open in browser + import tempfile + import webbrowser + + temp_path = Path(tempfile.gettempdir()) / "cc_mesh_plot.html" + plotly_write_html(fig, temp_path, include_plotlyjs="cdn") + webbrowser.open(f"file://{temp_path}") + + + @staticmethod + def __create_cc_viewmat() -> "Matrix44": + """Create the view matrix for a nice view of the corpus callosum. + + Returns + ------- + Matrix44 + 4x4 view matrix that provides a standard view of the corpus callosum (from pyrr). + + Notes + ----- + The function: + 1. Creates a base view matrix looking from the left with top up + 2. Applies a series of rotations: + - -10 degrees around x-axis + - 35 degrees around y-axis + - -8 degrees around z-axis + 3. Adds a small translation for better centering + """ + + if not HAS_PYRR: + raise ImportError("Pyrr not installed, install pyrr with `pip install pyrr`.") + + viewLeft = np.array([[0, 0, -1, 0], [-1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1]]) # left w top up // right + transl = Matrix44.from_translation((0, 0, 0.4)) + viewmat = transl * viewLeft + + # rotate 10 degrees around x axis + rot = Matrix44.from_x_rotation(np.deg2rad(-10)) + viewmat = viewmat * rot + + # rotate 35 degrees around y axis + rot = Matrix44.from_y_rotation(np.deg2rad(35)) + viewmat = viewmat * rot + + # rotate 10 degrees around z axis + rot = Matrix44.from_z_rotation(np.deg2rad(-8)) + viewmat = viewmat * rot + + return viewmat + + def snap_cc_picture( + self, + output_path: Path | str, + fssurf_file: Path | str | None = None, + overlay_file: Path | str | None = None, + ref_image: Path | str | nibabelImage | None = None, + ) -> None: + """Snap a picture of the corpus callosum mesh. + + Parameters + ---------- + output_path : Path, str + Path where to save the snapshot image. + fssurf_file : Path, str, optional + Path to a FreeSurfer surface file to use for the snapshot. + If None, the mesh is saved to a temporary file. + overlay_file : Path, str, optional + Path to a FreeSurfer overlay file to use for the snapshot. + If None, the mesh is saved to a temporary file. + ref_image : Path, str, nibabelImage, optional + Path to reference image to use for tkr creation. If None, ignores the file for saving. + + Raises + ------ + Warning + If the mesh has no faces and cannot create a snapshot. + + Notes + ----- + The function: + 1. Creates temporary files for mesh and overlay data if needed. + 2. Uses whippersnappy to create a snapshot with: + - Custom view matrix for standard orientation. + - Ambient lighting and colorbar settings. + - Thickness overlay if available. + 3. Cleans up temporary files after use. + """ + try: + from whippersnappy.core import snap1 + except ImportError: + # whippersnappy not installed + raise RuntimeError( + "The snap_cc_picture method of CCMesh requires whippersnappy, but whippersnappy was not found. " + "Please install whippersnappy!" + ) from None + self.__make_parent_folder(output_path) + # Skip snapshot if there are no faces + if len(self.t) == 0: + logger.warning("Cannot create snapshot - no faces in mesh") + return + + # create temp file + if fssurf_file: + fssurf_file = Path(fssurf_file) + else: + fssurf_file = tempfile.NamedTemporaryFile(suffix=".fssurf", delete=True).name + + ref_image_arg = str(ref_image) if isinstance(ref_image, (Path, str)) else ref_image + self.write_fssurf(fssurf_file, image=ref_image_arg) + + if overlay_file: + overlay_file = Path(overlay_file) + else: + overlay_file = tempfile.NamedTemporaryFile(suffix=".w", delete=True).name + # Write thickness values in FreeSurfer '*.w' overlay format + self.write_morph_data(overlay_file) + + try: + with suppress_stdout(): + snap1( + fssurf_file, + overlaypath=overlay_file, + view=None, + viewmat=self.__create_cc_viewmat(), + width=3 * 500, + height=3 * 300, + outpath=output_path, + ambient=0.6, + colorbar_scale=0.5, + colorbar_y=0.88, + colorbar_x=0.19, + brain_scale=2.1, + fthresh=0, + caption="Corpus Callosum thickness (mm)", + caption_y=0.85, + caption_x=0.17, + caption_scale=0.5, + ) + except Exception as e: + raise e from None + + if fssurf_file and hasattr(fssurf_file, "close"): + fssurf_file.close() + if overlay_file and hasattr(overlay_file, "close"): + overlay_file.close() + + def smooth_(self, iterations: int = 1) -> None: + """Smooth the mesh while preserving the z-coordinates. + + Parameters + ---------- + iterations : int, optional + Number of smoothing iterations, by default 1. + + Notes + ----- + The function: + 1. Stores original z-coordinates. + 2. Applies Laplacian smoothing to x and y coordinates. + 3. Restores original z-coordinates to maintain slice structure. + """ + z_values = self.v[:, 0] + super().smooth_(iterations) + self.v[:, 0] = z_values + + + @staticmethod + def __make_parent_folder(filename: Path | str) -> None: + """Create the parent folder for a file if it doesn't exist. + + Parameters + ---------- + filename : Path, str + Path to the file whose parent folder should be created. + + Notes + ----- + Creates parent directory with parents=False to avoid creating + multiple levels of directories unintentionally. + """ + Path(filename).parent.mkdir(parents=False, exist_ok=True) + + def to_vox_coordinates( + self: Self, + mesh_ras2vox: AffineMatrix4x4, + ) -> Self: + """Convert mesh coordinates to FreeSurfer coordinate system. + + Parameters + ---------- + mesh_ras2vox : AffineMatrix4x4 + Transformation matrix from midplane mesh space (RAS centered on midplane) to voxel coordinates. + + Returns + ------- + CCMesh + A CCMesh object with vertices reoriented to FreeSurfer coordinates. + + Notes + ----- + Mesh coordinates are in ASR (Anterior-Superior-Right) orientation, with the coordinate system origin on + *the* midslice. The function *first* transforms from midslice ASR to LIA vox coordinates. + """ + from copy import copy + new_object = copy(self) + + # tkrRAS = Torig*[C R S 1]' + # Torig: mri_info --vox2ras-tkr orig.mgz + # https://surfer.nmr.mgh.harvard.edu/fswiki/CoordinateSystems + + new_object.v = (mesh_ras2vox[:3, :3] @ self.v.T).T + mesh_ras2vox[None, :3, 3] + return new_object + + @update_docstring(parent_doc=TriaMesh.write_fssurf.__doc__) + def write_fssurf(self, filename: Path | str, image: str | nibabelImage | None = None) -> None: + """{parent_doc} + Also creates parent directory if needed before writing the file. + """ + self.__make_parent_folder(filename) + return super().write_fssurf(filename, image=image) + + def write_morph_data(self, filename: Path | str) -> None: + """Write the thickness values as a FreeSurfer overlay file. + + Parameters + ---------- + filename : Path, str + Path where to save the overlay file. + + Notes + ----- + Creates parent directory if needed before writing the file. + """ + self.__make_parent_folder(filename) + return nib.freesurfer.write_morph_data(filename, self.mesh_vertex_colors) + + @classmethod + def from_contours( + cls: type[Self], + contours: list[CCContour], + lr_center: float = 0, + closed: bool = False, + smooth: int = 0, + ) -> Self: + """Create a surface mesh by triangulating between consecutive contours. + + Parameters + ---------- + contours : list[CCContour] + List of CCContour objects to create mesh from. + lr_center : float, default=0 + Center position in the left-right axis. + closed : bool, default=False + Whether to create a closed mesh by adding caps. + smooth : int, default=0 + Number of smoothing iterations to apply. + + Returns + ------- + CCMesh + The joined CCMesh object. + + Raises + ------ + Warning + If no valid contours are found. + + Notes + ----- + The function: + 1. Filters out None contours. + 2. Calculates z-coordinates for each slice. + 3. Creates triangles between adjacent contours. + 4. Optionally: + - Creates caps at both ends. + - Applies smoothing. + - Colors caps based on thickness values. + """ + # Check that all contours have the same resolution + z_coordinates = np.array([contour.z_position for contour in contours]) + same_z_position = np.isclose(z_coordinates[:, None], z_coordinates[None, :]) + # filter for diagonal and duplicates + unique_same_z_position = np.logical_and(same_z_position, np.tri(z_coordinates.shape[0], k=-1, dtype=bool).T) + if np.any(unique_same_z_position): + raise ValueError( + f"All contours must have different z_positions, but {np.array(np.where(unique_same_z_position)).T} " + f"have similar z_positions." + ) + + # Calculate z coordinates for each slice + # z_coordinates = (np.arange(len(contours)) - len(contours) // 2) * contours[0].resolution + lr_center + + # vertices list with z-coordinates and collect thickness values + vertices = [] + faces = [] + vertex_values_list = [] + vertex_start_indices = [] # Track starting index for each contour + current_index = 0 + previous_contour: CCContour | None = None + + for contour in contours: + vertex_start_indices.append(current_index) + vertices.append(np.hstack([np.full((len(contour.points), 1), contour.z_position), contour.points])) + if contour.thickness_values is not None: + vertex_values_list.append(contour.thickness_values) + + # Check if there's a next valid contour to connect to + if previous_contour is not None: + if len(previous_contour.points) != len(contour.points): + raise ValueError("The number of points of multiple contours must be the same!") + faces_between = make_triangles_between_contours(previous_contour.points, contour.points) + faces.append(faces_between + current_index) + + current_index += len(previous_contour.points) + + previous_contour = contour + + vertex_values = None + if len(vertex_values_list) == len(contours): + vertex_values = np.concatenate(vertex_values_list) + elif len(vertex_values_list) > 0: + logger.warning("Some contours have thickness values while others don't; skipping thickness overlay") + + if smooth > 0: + tmp_mesh = CCMesh(vertices, faces, vertex_values=vertex_values) + tmp_mesh.smooth_(smooth) + vertices = tmp_mesh.v + faces = tmp_mesh.t + vertex_values = tmp_mesh.mesh_vertex_colors + + if closed: + # this functionality is untested and not used + logger.warning("CCMesh.from_contours(closed=True) is untested and likely has errors.") + + # Close the mesh by creating caps on both ends + # Left cap (first slice) + left_side_points, left_side_trias = make_mesh_from_contour(vertices[0][..., 1:]) + left_side_points = np.hstack([np.full((len(left_side_points), 1), z_coordinates[0]), left_side_points]) + + # Right cap (last slice) + right_side_points, right_side_trias = make_mesh_from_contour(vertices[-1][..., 1:]) + right_side_points = np.hstack([np.full((len(right_side_points), 1), z_coordinates[-1]), right_side_points]) + + # color_sides is a legacy visualization option to allow caps to have thickness colors + color_sides = True + if color_sides: + left_side_points, left_side_trias, left_side_colors = _create_cap( + left_side_points, contours[0] + ) + right_side_points, right_side_trias, right_side_colors = _create_cap( + right_side_points, contours[-1] + ) + # reverse right side trias for proper orientation + right_side_trias = right_side_trias[:, ::-1] + + vertex_values = np.concatenate([vertex_values, left_side_colors, right_side_colors]) + + left_side_trias = left_side_trias + current_index + current_index += len(left_side_points) + + right_side_trias = right_side_trias + current_index + current_index += len(right_side_points) + + vertices.extend([left_side_points, right_side_points]) + faces.extend([left_side_trias, right_side_trias]) + + return cls(vertices, faces, vertex_values=vertex_values) diff --git a/CorpusCallosum/shape/metrics.py b/CorpusCallosum/shape/metrics.py new file mode 100644 index 000000000..4003b7b7f --- /dev/null +++ b/CorpusCallosum/shape/metrics.py @@ -0,0 +1,333 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +import FastSurferCNN.utils.logging as logging + +logger = logging.get_logger(__name__) + + +# TODO: we could make this more robust by standardizing orientation with AC/PC and smoothing the contour + +def _line_segment_intersection( + line_point: np.ndarray, + line_dir: np.ndarray, + seg_start: np.ndarray, + seg_end: np.ndarray, + tol: float = 1e-10, +) -> np.ndarray | None: + """Compute intersection between an infinite line and a line segment. + + Uses the parametric form: + - Line: P = line_point + t * line_dir + - Segment: Q = seg_start + s * (seg_end - seg_start), where s ∈ [0, 1] + + Parameters + ---------- + line_point : np.ndarray + A point on the infinite line, shape (2,). + line_dir : np.ndarray + Direction vector of the line, shape (2,). + seg_start : np.ndarray + Start point of the segment, shape (2,). + seg_end : np.ndarray + End point of the segment, shape (2,). + tol : float + Tolerance for numerical comparisons. + + Returns + ------- + np.ndarray | None + Intersection point as shape (2,) array, or None if no intersection. + """ + seg_dir = seg_end - seg_start + + # Build the linear system: [line_dir, -seg_dir] @ [t, s].T = seg_start - line_point + # Matrix A = [[line_dir[0], -seg_dir[0]], [line_dir[1], -seg_dir[1]]] + A = np.array([[line_dir[0], -seg_dir[0]], + [line_dir[1], -seg_dir[1]]]) + b = seg_start - line_point + + # Check if lines are parallel (determinant ≈ 0) + det = A[0, 0] * A[1, 1] - A[0, 1] * A[1, 0] + if abs(det) < tol: + return None + + # Solve for t and s using Cramer's rule (faster than linalg.solve for 2x2) + t = (b[0] * A[1, 1] - b[1] * A[0, 1]) / det + s = (A[0, 0] * b[1] - A[1, 0] * b[0]) / det + + # Check if intersection is within the segment [0, 1] + if -tol <= s <= 1.0 + tol: + return line_point + t * line_dir + return None + + +def get_intersections( + contour: np.ndarray, start_point: np.ndarray, direction: np.ndarray +) -> np.ndarray: + """Find intersection points between an infinite line and a closed contour. + + Parameters + ---------- + contour : np.ndarray + Array of shape (2, N) containing contour points in ACPC space. + start_point : np.ndarray + A point on the line, shape (2,). + direction : np.ndarray + Direction vector of the line, shape (2,). + + Returns + ------- + np.ndarray + Array of shape (M, 2) containing intersection points, sorted along the direction. + """ + start_point = np.asarray(start_point, dtype=float) + direction = np.asarray(direction, dtype=float) + + # Normalize direction + dir_norm = np.linalg.norm(direction) + if dir_norm < 1e-10: + return np.empty((0, 2)) + direction = direction / dir_norm + + n_points = contour.shape[1] + intersections = [] + + # Check intersection with each segment of the closed contour + for i in range(n_points): + seg_start = contour[:, i] + seg_end = contour[:, (i + 1) % n_points] # Wrap around to close the contour + + intersection = _line_segment_intersection( + start_point, direction, seg_start, seg_end + ) + if intersection is not None: + intersections.append(intersection) + + if not intersections: + return np.empty((0, 2)) + + points = np.array(intersections) + + # Remove duplicate points (can occur at contour vertices) + if len(points) > 1: + # Project onto line direction and find unique points + projections = np.dot(points - start_point, direction) + # Sort and remove duplicates within tolerance + sorted_idx = np.argsort(projections) + points = points[sorted_idx] + projections = projections[sorted_idx] + + # Keep points that are sufficiently far apart + mask = np.ones(len(points), dtype=bool) + for i in range(1, len(points)): + if abs(projections[i] - projections[i - 1]) < 1e-8: + mask[i] = False + points = points[mask] + + return points + + +def calculate_cc_index(cc_contour: np.ndarray, plot: bool = False) -> float: + """Calculate CC index based on three thickness measurements. + + The AP line intersects the contour 4 times. The measurements are: + - Anterior thickness: distance between intersection points 1 and 2 + - Posterior thickness: distance between intersection points 3 and 4 + - Middle thickness: perpendicular line through midpoint of AP line + + The CC index is: (anterior + posterior + middle) / AP_length + + Parameters + ---------- + cc_contour : np.ndarray + Array of shape (2, N) containing contour points in ACPC space. + plot : bool, optional + Whether to generate a debug plot. Default is True. + + Returns + ------- + cc_index : float + The CC index, which is the sum of thicknesses at three measurement points divided by AP length. + """ + # Get anterior and posterior points (extremes along x-axis) + # In ACPC space, X is Anterior-Posterior direction, where Anterior is positive + posterior_idx = np.argmin(cc_contour[0]) # Minimum X is Posterior + anterior_idx = np.argmax(cc_contour[0]) # Maximum X is Anterior + + anterior_pt = cc_contour[:, anterior_idx] + posterior_pt = cc_contour[:, posterior_idx] + + # AP line vector from anterior to posterior + ap_vector = posterior_pt - anterior_pt + ap_length = np.linalg.norm(ap_vector) + ap_unit = ap_vector / ap_length + + # Perpendicular direction (90 degrees rotated) + perp_unit = np.array([-ap_unit[1], ap_unit[0]]) + + # Find where AP line intersects the contour (should be 4 points) + ap_intersections = get_intersections( + contour=cc_contour, start_point=anterior_pt, direction=ap_unit + ) + + if len(ap_intersections) != 4: + logger.error( + f"AP line should intersect contour exactly 4 times, " + f"but found {len(ap_intersections)} intersections" + ) + return 0.0 + + # Measurement 1: anterior thickness (between intersection points 1 and 2) + anterior_thickness = np.linalg.norm(ap_intersections[0] - ap_intersections[1]) + + # Measurement 2: posterior thickness (between intersection points 3 and 4) + posterior_thickness = np.linalg.norm(ap_intersections[2] - ap_intersections[3]) + + # AP distance is between outermost intersection points (1 and 4) + ap_distance = np.linalg.norm(ap_intersections[0] - ap_intersections[3]) + + # Midpoint of AP line (between points 1 and 4, or between anterior and posterior extremes) + midpoint = (ap_intersections[0] + ap_intersections[3]) / 2 + + # Measurement 3: perpendicular line through midpoint + middle_intersections = get_intersections( + contour=cc_contour, start_point=midpoint, direction=perp_unit + ) + + middle_thickness = np.linalg.norm(middle_intersections[0] - middle_intersections[-1]) + + # Calculate CC index + cc_index = (anterior_thickness + posterior_thickness + middle_thickness) / ap_distance + + if plot: + import matplotlib + import matplotlib.pyplot as plt + curr_backend = matplotlib.get_backend() + plt.switch_backend("qtagg") + + fig, ax = plt.subplots(figsize=(8, 6)) + plot_cc_index_calculation( + ax, + cc_contour, + anterior_idx, + posterior_idx, + ap_intersections, + middle_intersections, + midpoint, + ) + ax.legend() + plt.show() + plt.switch_backend(curr_backend) + + return cc_index + + +def plot_cc_index_calculation( + ax, + cc_contour: np.ndarray, + anterior_idx: int, + posterior_idx: int, + ap_intersections: np.ndarray, + middle_intersections: np.ndarray, + midpoint: np.ndarray, +) -> None: + """Plot the CC index measurements. + + Parameters + ---------- + ax : matplotlib.axes.Axes + The axes to plot on. + cc_contour : np.ndarray + Array of shape (2, N) containing contour points in ACPC space. + anterior_idx : int + Index of the anterior point on the contour. + posterior_idx : int + Index of the posterior point on the contour. + ap_intersections : np.ndarray + Array of shape (4, 2) containing the 4 intersection points of the AP line with the contour. + middle_intersections : np.ndarray + Array of shape (2, 2) containing middle perpendicular intersection points. + midpoint : np.ndarray + Array of shape (2,) containing the midpoint of the AP line. + """ + from matplotlib.patches import PathPatch + from matplotlib.path import Path + + # Plot the CC contour (closed) + ax.plot(cc_contour[0], cc_contour[1], "k-", linewidth=1) + ax.plot( + [cc_contour[0, -1], cc_contour[0, 0]], + [cc_contour[1, -1], cc_contour[1, 0]], + "k-", + linewidth=1, + ) + + # Plot AP line through all 4 intersection points + ax.plot( + [ap_intersections[0, 0], ap_intersections[3, 0]], + [ap_intersections[0, 1], ap_intersections[3, 1]], + "r--", + linewidth=1, + label="AP line", + ) + + # Mark all 4 intersection points + for i, pt in enumerate(ap_intersections): + ax.scatter([pt[0]], [pt[1]], s=40, zorder=5) + ax.annotate(f"{i+1}", (pt[0], pt[1]), textcoords="offset points", + xytext=(5, 5), fontsize=10) + + # Plot anterior thickness (points 1-2) + ax.plot( + [ap_intersections[0, 0], ap_intersections[1, 0]], + [ap_intersections[0, 1], ap_intersections[1, 1]], + "b-", + linewidth=3, + label="Anterior thickness (1-2)", + ) + + # Plot posterior thickness (points 3-4) + ax.plot( + [ap_intersections[2, 0], ap_intersections[3, 0]], + [ap_intersections[2, 1], ap_intersections[3, 1]], + "c-", + linewidth=3, + label="Posterior thickness (3-4)", + ) + + # Plot middle thickness measurement (perpendicular) + ax.plot( + [middle_intersections[0, 0], middle_intersections[-1, 0]], + [middle_intersections[0, 1], middle_intersections[-1, 1]], + "g-", + linewidth=3, + label="Middle thickness", + ) + + # Mark midpoint + ax.scatter([midpoint[0]], [midpoint[1]], color="red", s=50, zorder=5, + marker="x", label="Midpoint") + + ax.set_aspect("equal") + + # Fill the contour with gray + contour_path = Path(cc_contour.T) + patch = PathPatch(contour_path, facecolor="gray", alpha=0.2, edgecolor=None) + ax.add_patch(patch) + + ax.invert_xaxis() + ax.axis("off") diff --git a/CorpusCallosum/shape/postprocessing.py b/CorpusCallosum/shape/postprocessing.py new file mode 100644 index 000000000..cfdb974ac --- /dev/null +++ b/CorpusCallosum/shape/postprocessing.py @@ -0,0 +1,614 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import concurrent.futures +from copy import copy +from functools import partial +from pathlib import Path +from typing import get_args + +import numpy as np +from nibabel.freesurfer.mghformat import MGHHeader +from numpy import typing as npt + +import FastSurferCNN.utils.logging as logging +from CorpusCallosum.data.constants import CC_LABEL, SUBSEGMENT_LABELS +from CorpusCallosum.shape.contour import CCContour +from CorpusCallosum.shape.curvature import calculate_curvature_metrics +from CorpusCallosum.shape.endpoint_heuristic import connect_diagonally_connected_components +from CorpusCallosum.shape.mesh import CCMesh +from CorpusCallosum.shape.metrics import calculate_cc_index +from CorpusCallosum.shape.subsegment_contour import ( + ContourList, + get_primary_eigenvector, + get_unique_contour_points, + hampel_subdivide_contour, + subdivide_contour, + subsegment_midline_orthogonal, + transform_to_acpc_standard, +) +from CorpusCallosum.utils.types import ( + CCMeasuresDict, + Points2dType, + SliceSelection, + SubdivisionMethod, +) +from CorpusCallosum.utils.visualization import plot_contours +from FastSurferCNN.utils import AffineMatrix4x4, Image3d, Mask2d, Shape2d, Shape3d, Vector2d +from FastSurferCNN.utils.common import SubjectDirectory, update_docstring +from FastSurferCNN.utils.parallel import process_executor, thread_executor + +logger = logging.get_logger(__name__) + +# assert LIA orientation +LIA_ORIENTATION = np.zeros((3,3)) +LIA_ORIENTATION[0,0] = -1 +LIA_ORIENTATION[1,2] = 1 +LIA_ORIENTATION[2,1] = -1 + + +def offset_affine(offset: npt.ArrayLike) -> AffineMatrix4x4: + """Generate an affine transformation matrix that only constitutes an offset (vector). + + Parameters + ---------- + offset : array_like + A 3-dimensional offset vector (shape (3,)) to offset with. + + Returns + ------- + np.ndarray + Modified 4x4 affine transformation matrix with the specific offset. + + Raises + ------ + TypeError + If offset is not a + """ + _offset = np.asarray(offset) + if not isinstance(_offset, np.ndarray) or _offset.shape != (3,): + raise TypeError("offset must convert to a ndarray of shape (3,)!") + vox2vox: AffineMatrix4x4 = np.eye(4, dtype=float) + vox2vox[0:3, 3] = _offset + return vox2vox + + +@update_docstring(SubdivisionMethod=str(get_args(SubdivisionMethod))[1:-1]) +def recon_cc_surf_measures_multi( + segmentation: np.ndarray[Shape3d, np.dtype[np.int_]], + slice_selection: SliceSelection, + upright_header: MGHHeader, + fsavg2midslab_vox2vox: AffineMatrix4x4, + fsavg_vox2ras: AffineMatrix4x4, + orig2fsavg_vox2vox: AffineMatrix4x4, + midslices: Image3d, + ac_coords_vox: Vector2d, + pc_coords_vox: Vector2d, + num_thickness_points: int, + subdivisions: list[float], + subdivision_method: SubdivisionMethod, + contour_smoothing: int, + subject_dir: SubjectDirectory, +) -> tuple[list[CCMeasuresDict], list[concurrent.futures.Future], list[CCContour], CCMesh | None]: + """Surface reconstruction and metrics computation of corpus callosum slices based on selection mode. + + Parameters + ---------- + segmentation : np.ndarray + 3D segmentation array in LIA orientation. + slice_selection : str + Which slices to process ('middle', 'all', or slice number). + upright_header : MGHHeader + The header of the upright image. + fsavg2midslab_vox2vox : AffineMatrix4x4 + The vox2vox transformation matrix from fsaverage (upright) space to the segmentation slab. + fsavg_vox2ras : np.ndarray + Base affine transformation matrix (fsaverage, upright space). + orig2fsavg_vox2vox : AffineMatrix4x4 + The transformation matrix from orig to fsaverage in voxel space. + midslices : np.ndarray + Array of mid-sagittal slices. + ac_coords_vox : np.ndarray + AC voxel coordinates with shape (2,) containing its [y,x] positions. + pc_coords_vox : np.ndarray + PC voxel coordinates with shape (2,) containing its [y,x] positions. + num_thickness_points : int + Number of points for thickness estimation. + subdivisions : list[float] + List of fractions for anatomical subdivisions. + subdivision_method : {SubdivisionMethod} + Method for contour subdivision. + contour_smoothing : int + Gaussian sigma for contour smoothing. + subject_dir : SubjectDirectory + The SubjectDirectory object managing file names in the subject directory. + + Returns + ------- + list of CCMeasuresDict + List of slice processing results. + list of concurrent.futures.Future + List of background IO processes. + list of CCContour + List of CC contours. + CCMesh, None + The CC mesh or None if no mesh was created. + """ + slice_cc_measures: list[CCMeasuresDict] = [] + io_futures = [] + + if subdivision_method == "angular" and not np.allclose(np.diff(subdivisions), np.diff(subdivisions)[0]): + raise ValueError( + f"Angular subdivision method (Hampel) only supports equidistant subdivision, " + f"but got: {subdivisions}. No measures are computed.", + ) + + _each_slice = partial( + recon_cc_surf_measure, + segmentation, + ac_coords_vox=ac_coords_vox, + pc_coords_vox=pc_coords_vox, + num_thickness_points=num_thickness_points, + subdivisions=subdivisions, + subdivision_method=subdivision_method, + contour_smoothing=contour_smoothing, + ) + + # Process multiple slices or specific slice + if slice_selection == "middle": + num_slices = 1 + # Process only the middle slice + slices_to_recon = [segmentation.shape[0] // 2] + elif slice_selection == "all": + num_slices = segmentation.shape[0] + start_slice = 0 + end_slice = segmentation.shape[0] + slices_to_recon = range(start_slice, end_slice) + else: # specific slice number + num_slices = 1 + slices_to_recon = [int(slice_selection)] + + def _gen_slice2slab_vox2vox(_slice_idx: int) -> AffineMatrix4x4: + # The slice_idx offset must be negative, because we are going from left to right. + return offset_affine([_slice_idx, 0, 0]) + + fsavg_midslab_vox2ras = fsavg_vox2ras @ np.linalg.inv(fsavg2midslab_vox2vox) + per_slice_vox2ras = fsavg_midslab_vox2ras @ np.stack(list(map(_gen_slice2slab_vox2vox, slices_to_recon)), axis=0) + + per_slice_recon = process_executor().map(_each_slice, slices_to_recon, per_slice_vox2ras, chunksize=1) + cc_contours = [] + + run = thread_executor().submit + wants_output = subject_dir.has_attribute + output_path = subject_dir.filename_by_attribute + slice_iterator = zip(slices_to_recon, per_slice_vox2ras, per_slice_recon, strict=True) + for i, (slice_idx, this_slice_vox2ras, _results) in enumerate(slice_iterator): + progress = f" ({i+1} of {num_slices})" if num_slices > 1 else "" + logger.info(f"Calculating CC measurements for slice {slice_idx+1}{progress}") + # unpack values from _results + cc_measures: CCMeasuresDict = _results[0] + _contour: CCContour = _results[1] + + cc_contours.append(_contour) + if cc_measures is None: + # this should not happen, but just in case + logger.warning(f"Slice index {slice_idx+1}{progress} returned result `None`") + + slice_cc_measures.append(cc_measures) + is_debug = logger.getEffectiveLevel() <= logging.DEBUG + is_midslice = slice_idx == num_slices // 2 + if wants_output("cc_qc_image") and (is_debug or is_midslice): + qc_imgs: list[Path] = [output_path("cc_qc_image")] + if is_debug: + qc_slice_img = qc_imgs[0].with_suffix(f".slice_{slice_idx}.png") + qc_imgs = (qc_imgs if is_midslice else []) + [qc_slice_img] + + logger.info(f"Saving segmentation qc image to {', '.join(map(str, qc_imgs))}") + current_slice_in_volume = midslices.shape[0] // 2 - num_slices // 2 + slice_idx + # Create visualization for this slice + io_futures.append( + run( + plot_contours, + # select the data of the current slice + slice_or_slab=midslices[[current_slice_in_volume]], + # the following need to be in voxel coordinates... + split_contours=cc_measures["split_contours"], + midline_equidistant=cc_measures["midline_equidistant"], + levelpaths=cc_measures["levelpaths"], + output_path=qc_imgs, + ac_coords_vox=ac_coords_vox, + pc_coords_vox=pc_coords_vox, + vox2ras=this_slice_vox2ras, + title=f"CC Subsegmentation by {subdivision_method} (Slice {slice_idx + 1})", + ) + ) + + if wants_output("save_template_dir"): + template_dir = output_path("save_template_dir") + # ensure directory exists + template_dir.mkdir(parents=True, exist_ok=True) + logger.info("Saving template files (contours.txt, thickness_values.txt, " + f"thickness_measurement_points.txt) to {template_dir}") + for j in range(len(cc_contours)): + io_futures.append(run(cc_contours[j].save_contour, template_dir / f"contour_{j}.txt")) + io_futures.append(run(cc_contours[j].save_thickness_values, template_dir / f"thickness_values_{j}.txt")) + + mesh_outputs = ("html", "mesh", "thickness_overlay", "surf", "thickness_image") + if len(cc_contours) > 1 and any(wants_output(f"cc_{n}") for n in mesh_outputs): + _cc_contours = thread_executor().map(_resample_thickness, cc_contours) + cc_mesh = CCMesh.from_contours(list(_cc_contours), smooth=1) + if wants_output("cc_html"): + logger.info(f"Saving CC 3D visualization to {output_path('cc_html')}") + io_futures.append(run(cc_mesh.plot_mesh, output_path=output_path("cc_html"))) + + if wants_output("cc_mesh"): + vtk_file_path = output_path("cc_mesh") + logger.info(f"Saving vtk file to {vtk_file_path}") + io_futures.append(run(cc_mesh.write_vtk, vtk_file_path)) + + if wants_output("cc_thickness_overlay") and not wants_output("cc_thickness_image"): + overlay_file_path = output_path("cc_thickness_overlay") + logger.info(f"Saving overlay file to {overlay_file_path}") + io_futures.append(run(cc_mesh.write_morph_data, overlay_file_path)) + + if any(wants_output(f"cc_{n}") for n in ("thickness_image", "surf")): + import nibabel as nib + up_data: Image3d[np.uint8] = np.empty(upright_header["dims"][:3], dtype=upright_header.get_data_dtype()) + upright_img = nib.MGHImage(up_data, fsavg_vox2ras, upright_header) + # the mesh is generated in upright coordinates, so we need to also transform to orig coordinates + # Mesh is fsavg_midplane (RAS); we need to transform to voxel coordinates + # fsavg ras is also on the midslice, so this is fine and we multiply in the IA and SP offsets + cc_mesh = cc_mesh.to_vox_coordinates(mesh_ras2vox=np.linalg.inv(fsavg_vox2ras @ orig2fsavg_vox2vox)) + if wants_output("cc_thickness_image"): + # this will also write overlay and surface + thickness_image_path = output_path("cc_thickness_image") + logger.info(f"Saving thickness image to {thickness_image_path}") + kwargs = { + "fssurf_file": output_path("cc_surf") if wants_output("cc_surf") else None, + "overlay_file": output_path("cc_thickness_overlay") + if wants_output("cc_thickness_overlay") else None, + "ref_image": upright_img, + } + cc_mesh.snap_cc_picture(thickness_image_path, **kwargs) + elif wants_output("cc_surf"): + surf_file_path = output_path("cc_surf") + logger.info(f"Saving surf file to {surf_file_path}") + io_futures.append(run(cc_mesh.write_fssurf, str(surf_file_path), image=upright_img)) + + if not slice_cc_measures: + logger.error("Error: No valid slices were found for postprocessing") + raise ValueError("No valid slices were found for postprocessing") + + return slice_cc_measures, io_futures, cc_contours, cc_mesh if len(cc_contours) > 1 else None + + +def _resample_thickness(contour: CCContour) -> CCContour: + """Resamples the thickness values of contour.""" + _c = copy(contour) + _c.fill_thickness_values() + return _c + + +def recon_cc_surf_measure( + segmentation: np.ndarray[Shape3d, np.dtype[np.int_]], + slice_idx: int, + slice_lia_vox2midslice_ras: AffineMatrix4x4, + ac_coords_vox: Vector2d, + pc_coords_vox: Vector2d, + num_thickness_points: int, + subdivisions: list[float], + subdivision_method: SubdivisionMethod, + contour_smoothing: int, +) -> tuple[CCMeasuresDict, CCContour]: + """Reconstruct surfaces and compute measures for a single slice for the corpus callosum. + + Parameters + ---------- + segmentation : np.ndarray + 3D segmentation array. + slice_idx : int + Index of the slice to process. + slice_lia_vox2midslice_ras : AffineMatrix4x4 + 4x4 affine transformation matrix. + ac_coords_vox : np.ndarray + AC voxel coordinates with shape (2,) containing its [y,x] positions. + pc_coords_vox : np.ndarray + PC voxel coordinates with shape (2,) containing its [y,x] positions. + num_thickness_points : int + Number of points for thickness estimation. + subdivisions : list[float] + List of fractions for anatomical subdivisions. + subdivision_method : SubdivisionMethod + Method for contour subdivision ('shape', 'vertical', 'angular', or 'eigenvector'). + contour_smoothing : int + Gaussian sigma for contour smoothing. + + Returns + ------- + measures : CCMeasuresDict + Dictionary containing measurements if successful. + contour : CCContour + The contour object containing points, thickness values, and endpoint indices. + + Raises + ------ + ValueError + If no CC is found in the specified slice. + + Notes + ----- + The function performs the following steps: + 1. Extracts CC contour and identifies endpoints. + 2. Converts coordinates to RAS space. + 3. Calculates thickness profile using Laplace equation. + 4. Computes shape metrics and subdivisions. + 5. Generates visualization data. + """ + cc_mask_slice: Mask2d = np.equal(segmentation[slice_idx], CC_LABEL) + if not np.any(cc_mask_slice): + raise ValueError(f"No CC found in slice {slice_idx}") + # clean up cc mask + cc_mask = connect_diagonally_connected_components(cc_mask_slice) + # create a CCContour from the cc_mask and transform to RAS coordinates + # - R coordinate is stored in _contour.z_position + # - AS coordinates are stored in _contour.points + _contour = CCContour.from_mask_and_acpc( + cc_mask, ac_coords_vox, pc_coords_vox, + slice_vox2ras=slice_lia_vox2midslice_ras, contour_smoothing=contour_smoothing, + ) + + levelpaths, thickness, midline_len, midline_equi, contour_with_thickness, endpoint_idxs, curvature = \ + _contour.create_levelpaths(num_thickness_points, inplace=True) + + contour_as = _contour.points.T + # thickness values in contour_with_thickness is not equally sampled, different shape + # to compute length of paths: diff between consecutive points (N-1, 2) => norm (N-1,) => sum (1,) + thickness_profile = np.stack([np.sum(np.linalg.norm(np.diff(x[:, :2], axis=0), axis=1)) for x in levelpaths]) + + acpc_contour_coords_as = contour_as[:, list(endpoint_idxs)].T + contour_in_acpc_space, ac_pt_acpc, pc_pt_acpc, rotate_back_acpc = transform_to_acpc_standard( + contour_as, + *acpc_contour_coords_as, + ) + cc_index = calculate_cc_index(contour_in_acpc_space) + + # Apply different subdivision methods based on user choice + split_contours: ContourList + split_points_midline: np.ndarray | None = None + if subdivision_method == "shape": + _subdivisions = np.asarray(subdivisions) + areas, split_contours, split_points_midline = subsegment_midline_orthogonal( + midline_equi, _subdivisions, contour_as, plot=False + ) + split_contours = [ + transform_to_acpc_standard(split_contour, *acpc_contour_coords_as)[0] + for split_contour in split_contours + ] + elif subdivision_method == "vertical": + areas, split_contours = subdivide_contour(contour_in_acpc_space, subdivisions, plot=False) + elif subdivision_method == "angular": + if not np.allclose(np.diff(subdivisions), np.diff(subdivisions)[0]): + raise ValueError( + f"Angular subdivision method (Hampel) only supports equidistant subdivision, " + f"but got: {subdivisions}. No measures are computed.", + ) + areas, split_contours = hampel_subdivide_contour(contour_in_acpc_space, num_rays=len(subdivisions), plot=False) + elif subdivision_method == "eigenvector": + pt0, pt1 = get_primary_eigenvector(contour_in_acpc_space) + contour_eigen, _, _, rotate_back_eigen = transform_to_acpc_standard(contour_in_acpc_space, pt0, pt1) + ac_pt_eigen, _, _, _ = transform_to_acpc_standard(ac_pt_acpc[:, None], pt0, pt1) + ac_pt_eigen = ac_pt_eigen[:, 0] + areas, split_contours = subdivide_contour(contour_eigen, subdivisions, oriented=True, hline_anchor=ac_pt_eigen) + split_contours = [rotate_back_eigen(split_contour) for split_contour in split_contours] + else: + raise ValueError(f"Invalid subdivision method {subdivision_method}") + + # order areas anterior to posterior + areas = areas[::-1] + + total_area = np.sum(areas) + # total_perimeter should include the edge from last to first point + contour_closed = np.concatenate([contour_as, contour_as[:, :1]], axis=1) + total_perimeter = np.sum(np.linalg.norm(np.diff(contour_closed, axis=1), axis=0)) + circularity = 4 * np.pi * total_area / (total_perimeter**2) + + # Transform split contours back to original space + split_contours = [rotate_back_acpc(split_contour) for split_contour in split_contours] + + # Calculate curvature metrics + curvature, curvature_body, curvature_subsegments = calculate_curvature_metrics( + midline_equi, split_points=split_points_midline, split_contours=split_contours + ) + + measures: CCMeasuresDict = { + "cc_index": cc_index, + "circularity": circularity, + "areas": np.asarray(areas), + "midline_length": midline_len, + "thickness": thickness, + "curvature": curvature, + "curvature_subsegments": curvature_subsegments, + "curvature_body": curvature_body, + "thickness_profile": thickness_profile, + "total_area": total_area, + "total_perimeter": total_perimeter, + "split_contours": split_contours, + "midline_equidistant": midline_equi, + "levelpaths": levelpaths, + "slice_index": slice_idx + } + return measures, _contour + + +def test_right_of_line( + coords: Points2dType, + line_start: Vector2d, + line_end: Vector2d, +) -> np.ndarray[tuple[int], np.dtype[np.bool_]]: + """Test whether points in coords are to the right of the line (line_start->line_end). + + Parameters + ---------- + coords : np.ndarray + Array of coordinates of shape (..., N). + line_start : array-like + [x, y] coordinates of line start point (N,). + line_end : array-like + [x, y] coordinates of line end point (N,). + + Returns + ------- + np.ndarray + Boolean array where True means point is to the left of the line of shape coords.shape[:-1]. + """ + # Vector from line_start to line_end + line_start_arr = np.expand_dims(line_start, axis=np.arange(line_start.ndim, coords.ndim).tolist()) + line_vec = np.expand_dims(line_end, axis=np.arange(line_end.ndim, coords.ndim).tolist()) - line_start_arr + + # Vectors from line_start to all points (vectorized) + point_vec = np.moveaxis(coords, -1, 0) - line_start_arr + + # Cross product (vectorized): positive means point is to the left of the line + cross_products = line_vec[0] * point_vec[1] - line_vec[1] * point_vec[0] + + return np.greater(cross_products, 0) + + +def make_subdivision_mask( + slice_shape: Shape2d, + split_contours: ContourList, + vox2ras: AffineMatrix4x4, + plot: bool = False, +) -> np.ndarray[Shape2d, np.dtype[np.int_]]: + """Create a mask for subdividing the corpus callosum based on split contours. + + Parameters + ---------- + slice_shape : pair of ints + Shape of the slice (rows, cols). + split_contours : ContourList + List of contours defining the subdivisions. + Each contour is a tuple of x and y coordinates. + vox2ras : AffineMatrix4x4 + The vox2ras transformation matrix for the requested shape. + plot : bool, default=False + Whether to plot the subdivision mask. + + Returns + ------- + np.ndarray + A mask of shape slice_shape where each pixel is labeled with a value + from SUBSEGEMNT_LABELS indicating which subdivision segment it belongs to. + + Notes + ----- + The function: + 1. Extracts unique contour points at subdivision boundaries. + 2. Creates coordinate grids for all points in the slice. + 3. Initializes mask with first segment label. + 4. For each subdivision line: + - Tests which points lie to the right of the line. + - Updates labels for those points. + """ + from nibabel.affines import apply_affine + + # unique_contour_points are the points where sub-division lines were inserted + unique_contour_points: list[Points2dType] = get_unique_contour_points(split_contours) # shape (N, 2) + subdivision_segments = unique_contour_points[1:] + + for s in subdivision_segments: + if len(s) != 2: + logger.error(f"Subdivision segment {s} has {len(s)} points, expected 2") + + # Create coordinate grids for all points in the slice + rows, cols = slice_shape + coords_vox = np.stack(np.mgrid[0:1, 0:rows, 0:cols], axis=-1) + coords_ras = apply_affine(vox2ras, coords_vox) + + # Use only as many labels as needed based on the number of subdivisions + # Number of regions = number of division lines + 1 + num_labels_needed = len(subdivision_segments) + 1 + cc_labels_posterior_to_anterior = SUBSEGMENT_LABELS[:num_labels_needed] + + # Initialize with first segment label + subdivision_mask = np.full(slice_shape, cc_labels_posterior_to_anterior[0], dtype=np.int32) + + # Process each subdivision line, subdivision_segments has for each division line the two points that are on the + # contour and divide the subsegments + for label, segment_points in zip(cc_labels_posterior_to_anterior[1:], reversed(subdivision_segments), strict=True): + # line_start and line_end are the intersection points of the CC subsegmentation boundary and the contour line + line_start, line_end = segment_points + + # --> find all voxels posterior to the line in question + # Vectorized test: find all points to the right of line (line_start->line_end) + # right_of_line == posterior to line + points_right_of_line = test_right_of_line(coords_ras[0, ..., 1:], line_start, line_end) + + # All points to the right of this line belong to the next segment or beyond + subdivision_mask[points_right_of_line] = label + + if plot: # interactive debug plot + import matplotlib + import matplotlib.pyplot as plt + curr_backend = matplotlib.get_backend() + plt.switch_backend("qtagg") + plt.figure(figsize=(10, 8)) + plt.imshow(subdivision_mask, cmap='tab10') + plt.colorbar(label='Subdivision') + plt.title('CC Subdivision Mask') + plt.xlabel('X') + plt.ylabel('Y') + plt.tight_layout() + plt.show() + plt.switch_backend(curr_backend) + return subdivision_mask + + +def check_area_changes(contours: list[np.ndarray], threshold: float = 0.3) -> bool: + """Check for large changes between consecutive CC areas and issue warnings. + + Parameters + ---------- + contours : list[np.ndarray] + List of contours (2, N). + threshold : float, default=0.3 + Threshold for relative change. + + Returns + ------- + bool + True if no large area changes are detected, False otherwise. + """ + + areas = np.asarray([np.abs(np.trapz(c[1], c[0])) for c in contours]) + + assert len(areas) > 1, "At least two areas are required to check for area changes" + + if np.any(areas == 0): + # One area is zero, the other is not - this is a 100% change + logger.warning(f"Areas {np.where(areas == 0)[0].tolist()} are zero mm²") + return False + + # Calculate relative change + relative_change = np.abs(np.diff(areas)) / areas[:-1] + + if np.any(where_change := relative_change > threshold): + indices = np.where(where_change)[0] + percent_change = relative_change[where_change] * 100 + logger.info( + f"Large corpus callosum area change after slices {indices.tolist()} detected: " + + ", ".join(f"areas {(i,i+1)} = ({areas[i]:.2f},{areas[i+1]:.2f}) mm² ({p:.1f}% change)" + for i, p in zip(indices, percent_change, strict=True)) + ) + return False + return True diff --git a/CorpusCallosum/shape/subsegment_contour.py b/CorpusCallosum/shape/subsegment_contour.py new file mode 100644 index 000000000..33fbd5a38 --- /dev/null +++ b/CorpusCallosum/shape/subsegment_contour.py @@ -0,0 +1,947 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Callable +from typing import Literal + +import matplotlib.pyplot as plt +import numpy as np +from scipy.spatial import ConvexHull + +from CorpusCallosum.utils.types import ContourList, Points2dType, Polygon2dType, Polygon3dType +from FastSurferCNN.utils import ScalarType, Vector2d + + +def minimum_bounding_rectangle(points: Points2dType) -> np.ndarray[tuple[Literal[4], Literal[2]], np.dtype[ScalarType]]: + """Find the smallest bounding rectangle for a set of points. + + Parameters + ---------- + points : array + An array of shape (N, 2) containing point coordinates. + + Returns + ------- + np.ndarray + Array of shape (4, 2) containing coordinates of the bounding box corners. + """ + pi2 = np.pi / 2.0 + points = np.asarray(points).T + + # get the convex hull for the points + hull_points = points[ConvexHull(points).vertices] + + # calculate edge angles + # including the edge that closes the loop from last to first point + edges = np.vstack([hull_points[1:] - hull_points[:-1], hull_points[0] - hull_points[-1]]) + + angles = np.arctan2(edges[:, 1], edges[:, 0]) + + angles = np.abs(np.mod(angles, pi2)) + angles = np.unique(angles) + + # find rotation matrices + rotations = np.vstack([np.cos(angles), np.cos(angles - pi2), np.cos(angles + pi2), np.cos(angles)]).T + rotations = rotations.reshape((-1, 2, 2)) + + # apply rotations to the hull + rot_points = np.dot(rotations, hull_points.T) + + # find the bounding points + min_x = np.nanmin(rot_points[:, 0], axis=1) + max_x = np.nanmax(rot_points[:, 0], axis=1) + min_y = np.nanmin(rot_points[:, 1], axis=1) + max_y = np.nanmax(rot_points[:, 1], axis=1) + + # find the box with the best area + areas = (max_x - min_x) * (max_y - min_y) + best_idx = np.argmin(areas) + + # return the best box + x1 = max_x[best_idx] + x2 = min_x[best_idx] + y1 = max_y[best_idx] + y2 = min_y[best_idx] + r = rotations[best_idx] + + rval = np.zeros((4, 2)) + rval[0] = np.dot([x1, y2], r) + rval[1] = np.dot([x2, y2], r) + rval[2] = np.dot([x2, y1], r) + rval[3] = np.dot([x1, y1], r) + + return rval + + +def calc_subsegment_areas(split_contours: ContourList) -> np.ndarray[tuple[int], np.dtype[ScalarType]]: + """Calculate area of each subsegment using the shoelace formula. + + Parameters + ---------- + split_contours : list of np.ndarray + List of contour arrays, each of shape (2, N). The list should contain + a set of nested contours (cumulative subsegments) and the full contour. + + Returns + ------- + subsegment_areas : array of floats + Array containing the area of each incremental subsegment. + """ + # calculate area of each split contour using the shoelace formula + # we use the absolute value because the orientation of the contour may vary + areas_cum = np.abs([np.trapz(c[1], c[0]) for c in split_contours]) + if len(areas_cum) == 1: + return np.asarray(areas_cum[0]) + + # Sort areas to ensure they are in increasing order of size + # This handles both cases where subsegments were provided in increasing or decreasing order + # The set of areas represents a sequence of nested shapes. + sorted_areas = np.sort(areas_cum) + + # Calculate the incremental pieces by taking differences between consecutive sizes + return np.diff(sorted_areas, prepend=0) + + +def subsegment_midline_orthogonal( + midline: Points2dType, + area_weights: np.ndarray[tuple[int], np.dtype[np.float_]], + contour: Polygon2dType, + plot: bool = True, + ax=None, + extremes=None, +) -> tuple[np.ndarray[tuple[int], np.dtype[ScalarType]], ContourList, np.ndarray]: + """Subsegment contour orthogonally to the midline based on area weights. + + Parameters + ---------- + midline : array of floats + Array of shape (N, 2) containing midline points. + area_weights : array of floats + Array of weights for area-based subdivision. + contour : array of floats + Array of shape (2, M) containing contour points in as space. + plot : bool, optional + Whether to plot the results, by default True. + ax : matplotlib.axes.Axes, optional + Axes for plotting, by default None. + extremes : tuple, optional + Tuple of extreme points, by default None. + + Returns + ------- + subsegment_areas : array of floats + List of subsegment areas. + split_contours : list of np.ndarray + List of contour arrays for each subsegment. + split_points : np.ndarray + Array of shape (K, 2) containing points where the midline was split. + + Notes + ----- + Subsegments include all previous segments. This means subsegment contour two is the outline of the union + of subsegment one and subsegment two. + """ + # FIXME: Here and in other places, the order of dimensions is pretty inconsistent, for example: midline is (N, 2), + # but contours are (2, N)... + + midline_end_idx = np.argmin(np.linalg.norm(contour.T - midline[-1], axis=1)) + # roll contour start to midline end + contour = np.roll(contour, -midline_end_idx, axis=1) + + # Calculate edge indices and fractions for splitting the midline + # We use len(midline) - 1 because we are looking for intervals between points + edge_idx_float = (len(midline) - 1) * np.array(area_weights) + edge_idx, edge_frac = np.divmod(edge_idx_float, 1) + edge_idx = edge_idx.astype(int) + + # Handle cases where area_weights might reach 1.0, which would lead to an out-of-bounds access + at_end = edge_idx >= len(midline) - 1 + edge_idx[at_end] = len(midline) - 2 + edge_frac[at_end] = 1.0 + + split_points = midline[edge_idx] + (midline[edge_idx + 1] - midline[edge_idx]) * edge_frac[:, None] + + # get edge for each split point + edge_directions = midline[edge_idx] - midline[edge_idx + 1] + # get vector perpendicular to each midline edge + edge_ortho_vectors = np.column_stack((-edge_directions[:, 1], edge_directions[:, 0])) + edge_ortho_vectors = edge_ortho_vectors / np.linalg.norm(edge_ortho_vectors, axis=1)[:, None] + + # Calculate intersections between the perpendicular lines and the contour + # vectors from split points to all contour points + vectors = contour.T[None, :, :] - split_points[:, None, :] # (K, M, 2) + + # Calculate cross product with ortho vectors to find side of the line (numerator of t) + # x*oy - y*ox + side = vectors[:, :, 0] * edge_ortho_vectors[:, None, 1] - vectors[:, :, 1] * edge_ortho_vectors[:, None, 0] + + # Find where the side changes sign, indicating an intersection + # Handle wrap-around by appending the first side value to the end + side_wrapped = np.hstack([side, side[:, 0:1]]) + sign_change = (side_wrapped[:, :-1] * side_wrapped[:, 1:]) <= 0 + + split_contours: ContourList = [contour] + + for pt_idx, split_point in enumerate(split_points): + # Indices of contour segments that have sign changes for this split point + seg_indices = np.where(sign_change[pt_idx])[0] + + intersections = [] + num_points = contour.shape[1] + for i in seg_indices: + s0 = side[pt_idx, i] + s1 = side[pt_idx, (i + 1) % num_points] + if s0 == s1: + t = 0.5 + else: + t = s0 / (s0 - s1) + + # intersection point on the segment + p0 = contour[:, i] + p1 = contour[:, (i + 1) % num_points] + intersection_point = p0 + t * (p1 - p0) + intersections.append((i, intersection_point)) + + + # get the two intersections closest to split_point + intersections.sort(key=lambda x: np.linalg.norm(x[1] - split_point)) + + # Create new contours by splitting at intersections + if len(intersections) >= 2: + first_index, first_intersection = intersections[1] + second_index, second_intersection = intersections[0] + + if first_index > second_index: + first_index, second_index = second_index, first_index + first_intersection, second_intersection = second_intersection, first_intersection + + first_index += 1 + + # connect first and second half + start_to_cutoff = np.hstack( + ( + contour[:, :first_index], + first_intersection[:, None], + second_intersection[:, None], + contour[:, second_index + 1 :], + ) + ) + split_contours.append(start_to_cutoff) + else: + raise ValueError(f"No intersections found for split point {pt_idx}, this should not happen") + + # plot contour to first index, then split point, then contour to second index + + # import matplotlib.pyplot as plt + # plt.close() + # fig, ax = plt.subplots(1,1) + # ax.plot(contour[:, :first_index][0], contour[:, :first_index][1], '-', linewidth=2, color='grey', + # label='Contour to first index') + # ax.plot(first_intersection[0], first_intersection[1], 'o', markersize=8, color='red', + # label='First intersection') + # ax.plot(second_intersection[0], second_intersection[1], 'o', markersize=8, color='red', + # label='Second intersection') + # ax.plot(contour[:, second_index + 1:][0], contour[:, second_index + 1:][1], '-', linewidth=2, color='red', + # label='Contour to second index') + # ax.legend() + # ax.set_title('Split Contours') + # ax.set_aspect('equal') + # ax.axis('off') + # plt.show() + + if plot: + extremes = [midline[0], midline[-1]] + + import matplotlib.pyplot as plt + + if ax is None: + SHOW = True + fig, ax = plt.subplots(1, 1, figsize=(8, 6)) + ax.axis("equal") + else: + SHOW = False + # pretty plot with areas filled in the polygon and overall area annotated + colors = plt.cm.Spectral(np.linspace(0.2, 0.8, len(split_contours))) + for color, split_contour in zip(colors, split_contours, strict=True): + ax.fill(split_contour[0], split_contour[1], alpha=0.5, color=color) + # ax.text(np.mean(split_contour[0]), np.mean(split_contour[1]), f'{area_out[i]:.2f}', + # olor='black', fontsize=12) + # plot contour + ax.plot(contour[0], contour[1], "-", linewidth=2, color="grey") + # put text between split points + # add endpoints to split_points + split_points = split_points.tolist() + split_points.insert(0, extremes[0]) + split_points.append(extremes[1]) + # ax.scatter(np.array(split_points)[:,0], np.array(split_points)[:,1], color='black', s=20) + ax.plot(midline[:, 0], midline[:, 1], "k--", linewidth=2) + + # plot edge orthogonal to each split point + for i in range(0, len(edge_ortho_vectors)): + pt = split_points[i + 1] + length = 0.4 + ax.plot( + [pt[0] - edge_ortho_vectors[i][0] * length, pt[0] + edge_ortho_vectors[i][0] * length], + [pt[1] - edge_ortho_vectors[i][1] * length, pt[1] + edge_ortho_vectors[i][1] * length], + "k-", + linewidth=2, + ) + + # convert area_weights into fraction of total line length + # e.g. area_weights=[1/6, 1/2, 2/3, 3/4] to ['1/6', '2/3', ...] + # cumulative difference + area_weights_diff = [area_weights[0]] + for i in range(1, len(area_weights)): + area_weights_diff.append(area_weights[i] - area_weights[i - 1]) + area_weights_diff.append(1 - area_weights[-1]) + + for i in range(len(split_points) - 1): + # get_index of split_points[i] in midline + sp1_midline_idx = np.argmin(np.linalg.norm(midline - split_points[i], axis=1)) + sp2_midline_idx = np.argmin(np.linalg.norm(midline - split_points[i + 1], axis=1)) + + # get midpoint on midline + midpoint_idx = (sp1_midline_idx + sp2_midline_idx) // 2 + midpoint = midline[midpoint_idx] + + # get vector perpendicular to line between split points + vector = np.array(split_points[i + 1]) - np.array(split_points[i]) + vector = vector / np.linalg.norm(vector) + vector = np.array([-vector[1], vector[0]]) + + midpoint = midpoint - vector * 3 + # ax.text(midpoint[0]-5, midpoint[1]-5, f'{area_out[i]:.2f}', color='black', fontsize=12) + # ax.text(midpoint[0], midpoint[1], f'{area_weights_txt[i]}', color='black', fontsize=12, + # horizontalalignment='center', verticalalignment='center') + + # start point & end point + ax.plot(extremes[0][0], extremes[0][1], marker="o", markersize=8, color="black") + ax.plot(extremes[1][0], extremes[1][1], marker="o", markersize=8, color="black") + + # plot contour point 0 + # ax.scatter(contour[0,0], contour[1,0], color='red', s=120) + ax.set_title("Split Contours") + + if SHOW: + ax.axis("off") + ax.invert_xaxis() + ax.axis("equal") + plt.show() + + return calc_subsegment_areas(split_contours), split_contours, split_points + + +def get_unique_contour_points(split_contours: ContourList) -> list[Points2dType]: + """Get unique contour points from the split contours. + + Parameters + ---------- + split_contours : ContourList + List of split contours (subsegmentations), each containing x and y coordinates, each of shape (2, N). + + Returns + ------- + list[np.ndarray] + List of unique contour points for each subsegment, each of shape (N, 2). + + Notes + ----- + This is a workaround to retrospectively add voxel-based subdivision. + In the future, we could keep track of the subdivision lines for + every subdivision scheme. + + The function: + 1. Processes each contour point. + 2. Checks if it appears in other contours (with small tolerance). + 3. Collects points unique to each subsegment. + """ + # For each contour point, check if it appears in other contours + # initialize with values for first_contour, which are by definition just "the contour" (empty) + unique_contour_points: list[Points2dType] = [np.zeros((0, 2))] + first_contour = split_contours[0] + # Check each point against all other contours + for contour in split_contours[1:]: + # 0: coord-axis, 1: contour-axis, 2: first_contour_axis + contour_comparison = np.isclose(first_contour[:, None], contour[:, :, None], atol=1e-6) + # mask of contour points, that are also in first_contour (axis 1 after all) + contour_points_in_first_contour_mask = np.any(np.all(contour_comparison, axis=0), axis=1) + unique_contour_points.append(contour[:, ~contour_points_in_first_contour_mask].T) + + return unique_contour_points + + +def hampel_subdivide_contour(contour: Polygon2dType, num_rays: int, plot: bool = False, ax=None) \ + -> tuple[np.ndarray[tuple[int], np.dtype[np.float_]], ContourList]: + """Subdivide contour based on area weights using equally spaced rays. + + Parameters + ---------- + contour : np.ndarray + Array of shape (2, N) containing contour points. + num_rays : int + Number of rays to use for subdivision. + plot : bool, optional + Whether to plot the results, by default False. + ax : matplotlib.axes.Axes, optional + Axes for plotting, by default None. + + Returns + ------- + areas : np.ndarray + Array of areas for each subsegment. + split_contours : list[np.ndarray] + List of contour arrays for each subsegment. + + Notes + ----- + The subdivision process: + 1. Finds extreme points in x-direction. + 2. Creates minimal bounding rectangle around contour. + 3. Creates equally spaced rays from lower edge of rectangle. + 4. Finds intersections of rays with contour. + 5. Creates new contours by splitting at intersections. + 6. Returns areas and split contours. + """ + + # Find the extreme points in the x-direction + min_x_index = np.argmin(contour[0]) + contour = np.roll(contour, -min_x_index, axis=1) + + # get minimal bounding box around contour + min_bounding_rectangle = minimum_bounding_rectangle(contour) + + # get long edges of rectangle + rectangle_duplicate_last = np.vstack((min_bounding_rectangle, min_bounding_rectangle[0])) + long_edges = np.diff(rectangle_duplicate_last, axis=0) + long_edges = np.linalg.norm(long_edges, axis=1) + long_edges_idx = np.argpartition(long_edges, -2)[-2:] + + # select lower long edge + min_val = np.inf + min_idx = None + for i in long_edges_idx: + if rectangle_duplicate_last[i][1] < min_val: + min_val = rectangle_duplicate_last[i][1] + min_idx = i + + if rectangle_duplicate_last[i + 1][1] < min_val: + min_val = rectangle_duplicate_last[i + 1][1] + min_idx = i + + lowest_points = rectangle_duplicate_last[[min_idx, min_idx + 1]] + + # sort lowest points by x coordinate + if lowest_points[0, 0] < lowest_points[1, 0]: + lowest_points = lowest_points[::-1] + + # get midpoint of lower edge of rectangle + midpoint_lower_edge = np.mean(lowest_points, axis=0) + + # get angle of lower edge of rectangle to x-axis + angle_lower_edge = np.arctan2( + lowest_points[1, 1] - lowest_points[0, 1], lowest_points[1, 0] - lowest_points[0, 0] + ) + + # get angles for equally spaced rays + angles = np.linspace(-angle_lower_edge, -angle_lower_edge + np.pi, num_rays + 2, endpoint=True) # + np.pi *3 + angles = angles[1:-1] + + # create ray vectors + ray_vectors = np.vstack((np.cos(angles), np.sin(angles))) + # make ray vectors unit length + ray_vectors = ray_vectors / np.linalg.norm(ray_vectors, axis=0) + + # invert x of ray vectors + ray_vectors[0] = -ray_vectors[0] + + # Subdivision logic + split_contours: ContourList = [] + num_points = contour.shape[1] + for ray_vector in ray_vectors.T: + intersections = [] + for i in range(num_points): + segment_start = contour[:, i] + segment_end = contour[:, (i + 1) % num_points] + segment_vector = segment_end - segment_start + + # Check for intersection with the ray + matrix = np.array([segment_vector, -ray_vector]).T + if np.linalg.matrix_rank(matrix) < 2: + continue # Skip parallel lines + + # Solve for intersection + # matrix * [t, s]^T = midpoint_lower_edge - segment_start + try: + t, s = np.linalg.solve(matrix, midpoint_lower_edge - segment_start) + if 0 <= t < 1: # Use half-open interval to avoid double-counting vertices + intersection_point = segment_start + t * segment_vector + intersections.append((i, intersection_point)) + except np.linalg.LinAlgError: + continue + + # Create new contours by splitting at intersections + if len(intersections) >= 2: + # Sort intersections by their position along the contour (index) + intersections.sort(key=lambda x: x[0]) + + # For HAMPEL (radial rays), we usually expect two intersections. + # If there are more, we pick the first and last along the contour. + first_index, first_intersection = intersections[0] + second_index, second_intersection = intersections[-1] + + start_to_cutoff = np.hstack( + ( + contour[:, :first_index], + first_intersection[:, None], + second_intersection[:, None], + contour[:, second_index + 1 :], + ) + ) + + # connect first and second half + split_contours.append(start_to_cutoff) + else: + raise ValueError("No intersections found, this should not happen") + + split_contours = [contour] + split_contours + + # Plotting logic + if plot: + import matplotlib.pyplot as plt + + if ax is None: + fig, ax = plt.subplots(1, 1, figsize=(8, 6)) + ax.axis("equal") + SHOW = True + else: + SHOW = False + min_bounding_rectangle_plot = np.vstack((min_bounding_rectangle, min_bounding_rectangle[0])) + # ax.plot(contour[0], contour[1], 'b-', label='Original Contour') + ax.plot(min_bounding_rectangle_plot[:, 0], min_bounding_rectangle_plot[:, 1], "k--") + ax.plot(midpoint_lower_edge[0], midpoint_lower_edge[1], "ko", markersize=8) + for ray_vector in ray_vectors.T: + ray_length = 15 + ray_vector *= -ray_length + ax.plot( + [midpoint_lower_edge[0], midpoint_lower_edge[0] + ray_vector[0]], + [midpoint_lower_edge[1], midpoint_lower_edge[1] + ray_vector[1]], + "k--", + ) + # pretty plot with areas files in the polygon and overall area annotated + colors = plt.cm.Spectral(np.linspace(0.2, 0.8, len(split_contours))) + for color, split_contour in zip(colors, split_contours, strict=True): + ax.fill(split_contour[0], split_contour[1], alpha=0.5, color=color) + ax.plot(contour[0], contour[1], "-", linewidth=2, color="grey") + + ax.set_title("Split Contours") + ax.axis("off") + if SHOW: + ax.axis("equal") + plt.show() + + return calc_subsegment_areas(split_contours), split_contours + + +def subdivide_contour( + contour: Polygon2dType, + area_weights: list[float], + plot: bool = False, + ax: plt.Axes | None = None, + plot_transform: Callable | None = None, + oriented: bool = False, + hline_anchor: np.ndarray | None = None +) -> tuple[np.ndarray[tuple[int], np.dtype[np.float_]], ContourList]: + """Subdivide contour based on area weights using vertical lines. + + Divides the contour into segments by drawing vertical lines at positions + determined by the area weights. The lines are drawn perpendicular to a + reference line connecting the extreme points of the contour. + + Parameters + ---------- + contour : np.ndarray + Array of shape (2, N) containing contour points. + area_weights : list[float] + List of weights for area-based subdivision. + plot : bool, optional + Whether to plot the results, by default False. + ax : matplotlib.axes.Axes, optional + Axes for plotting, by default None. + plot_transform : callable, optional + Function to transform points before plotting, by default None. + oriented : bool, optional + If True, use fixed horizontal reference line, by default False. + hline_anchor : np.ndarray, optional + Point to anchor horizontal reference line, by default None. + + Returns + ------- + areas : np.ndarray + Array of areas for each subsegment. + split_contours : list[np.ndarray] + List of contour arrays for each subsegment. + + Notes + ----- + The subdivision process: + 1. Finds extreme points in x-direction. + 2. Creates reference line between extremes. + 3. Calculates split points based on area weights. + 4. Divides contour using perpendicular lines at split points. + + """ + # Find the extreme points in the x-direction + min_x_index = np.argmin(contour[0]) + contour = np.roll(contour, -min_x_index, axis=1) + + min_x_index = 0 + max_x_index = np.argmax(contour[0]) + + if oriented: + contour_x_sorted = np.sort(contour[0]) + min_x = contour_x_sorted[0] + max_x = contour_x_sorted[-1] + extremes = (np.array([min_x, 0]), np.array([max_x, 0])) + + if hline_anchor is not None: + extremes = (np.array([min_x, hline_anchor[1]]), np.array([max_x, hline_anchor[1]])) + else: + extremes = (contour[:, min_x_index].copy(), contour[:, max_x_index].copy()) + # Calculate the line between the extreme points + start_point, end_point = extremes + line_vector = end_point - start_point + line_length = np.linalg.norm(line_vector) + + # Normalize the line vector + line_unit_vector = line_vector / line_length + + # Calculate the perpendicular vector + perp_vector = np.array([-line_unit_vector[1], line_unit_vector[0]]) + perp_vector = perp_vector / np.linalg.norm(perp_vector) + + if hline_anchor is None: + most_inferior_point = np.min(contour[1]) + # move extreme 1 down 5 mm below inferior point and extreme 2 the + # same distance (so the angle stays the same) + down_distance = (extremes[1][1] - most_inferior_point) * 1.3 + start_point = extremes[0] + down_distance * perp_vector + end_point = extremes[1] + down_distance * perp_vector + + else: + # get closest point on line to hline_anchor + intersection = start_point + line_unit_vector * np.dot(hline_anchor - start_point, line_unit_vector) + # get distance closest point on line to hline_anchor + distance = np.linalg.norm(intersection - hline_anchor) + # move start and end point the same distance + start_point = extremes[0] + distance * perp_vector + end_point = extremes[1] + distance * perp_vector + + extremes = (start_point, end_point) + + # Calculate the line between the extreme points + start_point, end_point = extremes + line_vector = end_point - start_point + line_length = np.linalg.norm(line_vector) + + # Normalize the line vector + line_unit_vector = line_vector / line_length + + # Calculate the perpendicular vector + perp_vector = np.array([-line_unit_vector[1], line_unit_vector[0]]) + + # Calculate split points based on area weights + split_points = [] + for weight in area_weights: + # current_weight = np.sum(area_weights[:i]) + split_distance = weight * line_length + split_point = start_point + split_distance * line_unit_vector + split_points.append(split_point) + + # Split the contour at the calculated split points + split_contours = [] + split_contours.append(contour) + for split_point in split_points: + intersections = [] + for i in range(contour.shape[1] - 1): + segment_start = contour[:, i] + segment_end = contour[:, i + 1] + segment_vector = segment_end - segment_start + + # Check for intersection with the perpendicular line + matrix = np.array([segment_vector, -perp_vector]).T + if np.linalg.matrix_rank(matrix) < 2: + continue # Skip parallel lines + + # Solve for intersection + t, s = np.linalg.solve(matrix, split_point - segment_start) + if 0 <= t <= 1: + intersection_point = segment_start + t * segment_vector + intersections.append((i, intersection_point)) + + # Sort intersections by their position along the contour + # intersections.sort() + + # get the two intersections that have the highest y coordinate + intersections.sort(key=lambda x: x[1][1], reverse=True) + + # Create new contours by splitting at intersections + if intersections: + first_index, first_intersection = intersections[1] + second_index, second_intersection = intersections[0] + + if first_index > second_index: + first_index, second_index = second_index, first_index + first_intersection, second_intersection = second_intersection, first_intersection + + first_index += 1 + + # connect first and second half to create a closed cumulative loop + # that includes the start point of the contour (Posterior end) + start_to_cutoff = np.hstack( + ( + contour[:, :first_index], + first_intersection[:, None], + second_intersection[:, None], + contour[:, second_index + 1 :], + ) + ) + + # add cumulative subsegment + split_contours.append(start_to_cutoff) + else: + raise ValueError("No intersections found, this should not happen") + + if plot: + # make vline at every split point + split_points_vlines_start = (np.array(split_points) - perp_vector * 1).T + split_points_vlines_end = (np.array(split_points) + perp_vector * 1).T + + if oriented: + # make another vline at start point and end point, this time not + # perpendicular to line but perpendicular to x-axis + start_point_vline = np.array([start_point, np.array([start_point[0], start_point[1] + 8])]) + end_point_vline = np.array([end_point, np.array([end_point[0], end_point[1] + 8])]) + else: + start_point_vline = np.array([start_point, start_point - perp_vector * 8]) + end_point_vline = np.array([end_point, end_point - perp_vector * 8]) + + if plot_transform is not None: + split_contours = [plot_transform(split_contour) for split_contour in split_contours] + contour = plot_transform(contour) + extremes = [plot_transform(extreme[:, None]) for extreme in extremes] + split_points = [plot_transform(split_point[:, None]) for split_point in split_points] + split_points_vlines_start = plot_transform(split_points_vlines_start) + split_points_vlines_end = plot_transform(split_points_vlines_end) + start_point_vline = plot_transform(start_point_vline.T).T + end_point_vline = plot_transform(end_point_vline.T).T + + import matplotlib.pyplot as plt + + if ax is None: + SHOW = True + fig, ax = plt.subplots(1, 1, figsize=(8, 6)) + ax.axis("equal") + else: + SHOW = False + # pretty plot with areas filled in the polygon and overall area annotated + colors = plt.cm.Spectral(np.linspace(0.2, 0.8, len(split_contours))) + for color, split_contour in zip(colors, split_contours, strict=True): + ax.fill(split_contour[0], split_contour[1], alpha=0.5, color=color) + # ax.text(np.mean(split_contour[0]), np.mean(split_contour[1]), + # f'{area_out[i]:.2f}', color='black', fontsize=12) + # plot contour + ax.plot(contour[0], contour[1], "-", linewidth=2, color="grey") + # dashed line between start point & end point + ax.plot( + np.vstack((extremes[0][0], extremes[1][0])), + np.vstack((extremes[0][1], extremes[1][1])), + "--", + linewidth=2, + color="grey", + ) + # markers at every split point + for i in range(split_points_vlines_start.shape[1]): + ax.plot( + np.vstack((split_points_vlines_start[:, i][0], split_points_vlines_end[:, i][0])), + np.vstack((split_points_vlines_start[:, i][1], split_points_vlines_end[:, i][1])), + "k-", + linewidth=2, + ) + + ax.plot(start_point_vline[:, 0], start_point_vline[:, 1], "--", linewidth=2, color="grey") + ax.plot(end_point_vline[:, 0], end_point_vline[:, 1], "--", linewidth=2, color="grey") + # put text between split points + # add endpoints to split_points + split_points.insert(0, extremes[0]) + split_points.append(extremes[1]) + # convert area_weights into fraction of total line length + # e.g. area_weights=[1/6, 1/2, 2/3, 3/4] to ['1/6', '2/3', ...] + # cumulative difference + area_weights_diff = [] + area_weights_diff.append(area_weights[0]) + for i in range(1, len(area_weights)): + area_weights_diff.append(area_weights[i] - area_weights[i - 1]) + area_weights_diff.append(1 - area_weights[-1]) + + # area_weights_txt = area_weights_txt / area_weights_txt[-1] + from fractions import Fraction + + area_weights_txt = [ + Fraction(area_weights_diff[i]).limit_denominator(1000) for i in range(len(area_weights_diff)) + ] + + for i in range(len(split_points) - 1): + midpoint = np.mean([split_points[i], split_points[i + 1]], axis=0) + # ax.text(midpoint[0]-5, midpoint[1]-5, f'{area_out[i]:.2f}', color='black', fontsize=12) + ax.text( + midpoint[0], + midpoint[1] - 5, + f"{area_weights_txt[i]}", + color="black", + fontsize=11, + horizontalalignment="center", + ) + + # start point & end point + ax.plot(extremes[0][0], extremes[0][1], marker="o", markersize=8, color="black") + ax.plot(extremes[1][0], extremes[1][1], marker="o", markersize=8, color="black") + + # plot contour 0 point + # ax.scatter(contour[0,0], contour[1,0], color='red', s=100) + + ax.set_title("Split Contours") + # ax.set_xlabel('X') + # ax.set_ylabel('Y') + + # axis off + ax.axis("off") + if SHOW: + ax.axis("equal") + plt.show() + + return calc_subsegment_areas(split_contours), split_contours + + +def transform_to_acpc_standard( + contour_ras: Polygon2dType | Polygon3dType, + ac_pt_ras: Vector2d, + pc_pt_ras: Vector2d, +) -> tuple[Polygon2dType, Vector2d, Vector2d, Callable[[Polygon2dType], Polygon2dType]]: + """Transform contour coordinates to AC-PC standard space. + + Transforms the contour coordinates by: + 1. Translating AC point to origin. + 2. Rotating to align PC point with posterior direction. + 3. Scaling to maintain AC-PC distance. + + Parameters + ---------- + contour_ras : array of floats + Array of shape (2, N) or (3, N) containing contour points in RAS space. + ac_pt_ras : np.ndarray + Anterior commissure point coordinates in AS space. + pc_pt_ras : np.ndarray + Posterior commissure point coordinates in AS space. + + Returns + ------- + contour_acpc : np.ndarray + Transformed contour points in AC-PC space. + ac_pt_acpc : np.ndarray + AC point in AC-PC space (origin). + pc_pt_acpc : np.ndarray + PC point in AC-PC space. + rotate_back : callable + Function to transform points back to RAS space. + """ + # translate AC to the origin and PC to (0, ac_pc_dist) + translation_matrix = np.array([[1, 0, -ac_pt_ras[0]], [0, 1, -ac_pt_ras[1]], [0, 0, 1]]) + + ac_pc_vec: Vector2d = pc_pt_ras - ac_pt_ras + ac_pc_dist = np.linalg.norm(ac_pc_vec) + + posterior_vector: Vector2d = np.array([-ac_pc_dist, 0], dtype=float) + + # get angle between ac_pc_vec and posterior_vector + dot_product = np.dot(ac_pc_vec, posterior_vector) + norms_product = np.linalg.norm(ac_pc_vec) * np.linalg.norm(posterior_vector) + theta = np.arccos(dot_product / norms_product) + + # Determine the sign of the angle using cross product + cross_product = np.cross(ac_pc_vec, posterior_vector) + if cross_product < 0: + theta = -theta + + # create rotation matrix for theta + rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0], [np.sin(theta), np.cos(theta), 0], [0, 0, 1]]) + + # apply translation and rotation + if contour_ras.shape[0] == 2: + contour_ras_homogeneous = np.vstack([contour_ras, np.ones(contour_ras.shape[1])]) + else: + contour_ras_homogeneous = contour_ras + + contour_acpc: Polygon2dType = (rotation_matrix @ translation_matrix) @ contour_ras_homogeneous + contour_acpc = contour_acpc[:2, :] + + def rotate_back(x: Polygon2dType) -> Polygon2dType: + return (np.linalg.inv(rotation_matrix @ translation_matrix) @ np.vstack([x, np.ones(x.shape[1])]))[:2, :] + + return contour_acpc, np.array([0, 0], dtype=float), np.array([-ac_pc_dist, 0], dtype=float), rotate_back + + +def get_primary_eigenvector(contour_ras: Polygon2dType) -> tuple[Vector2d, Vector2d]: + """Calculate primary eigenvector of contour points using PCA. + + Computes the principal direction of the contour by: + 1. Centering the points + 2. Computing covariance matrix + 3. Finding eigenvectors + 4. Selecting primary direction + + Parameters + ---------- + contour_ras : np.ndarray + Array of shape (2, N) containing contour points in RAS space. + + Returns + ------- + pt0 : np.ndarray + Starting point for eigenvector line. + pt1 : np.ndarray + End point for eigenvector line. + + """ + # Center the data by subtracting mean + contour_mean = np.mean(contour_ras, axis=1, keepdims=True) + contour_centered = contour_ras - contour_mean + + # Calculate covariance matrix + cov_matrix = np.cov(contour_centered) + + # Get eigenvalues and eigenvectors using PCA + eigenvalues, eigenvectors = np.linalg.eigh(cov_matrix) + + # Sort in descending order + idx = eigenvalues.argsort()[::-1] + eigenvectors = eigenvectors[:, idx] + + # make first eigenvector unit length + primary_eigenvector = eigenvectors[:, 0] / np.linalg.norm(eigenvectors[:, 0]) + pt0 = np.mean(contour_ras, axis=1) + pt0 -= np.array([0, 5]) + pt1 = pt0 + primary_eigenvector * 100 + + return pt0, pt1 + diff --git a/CorpusCallosum/shape/thickness.py b/CorpusCallosum/shape/thickness.py new file mode 100644 index 000000000..ff6ecb059 --- /dev/null +++ b/CorpusCallosum/shape/thickness.py @@ -0,0 +1,345 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Literal, overload + +import numpy as np +import scipy.interpolate +from lapy import Solver, TriaMesh +from lapy.diffgeo import compute_rotated_f +from meshpy import triangle + +from CorpusCallosum.shape.curvature import compute_mean_curvature +from CorpusCallosum.utils.types import ContourThickness, Points2dType +from FastSurferCNN.utils.common import suppress_stdout + + +def set_contour_zero_idx(contour, idx, anterior_endpoint_idx, posterior_endpoint_idx): + """Roll contour points to set a new zero index, while keeping track of CC endpoints. + + Parameters + ---------- + contour : np.ndarray + Array of contour points. + idx : int + New zero index. + anterior_endpoint_idx : int + Index of anterior endpoint. + posterior_endpoint_idx : int + Index of posterior endpoint. + + Returns + ------- + contour : np.ndarray + Rolled contour points. + anterior_endpoint_idx : int + Updated anterior endpoint index. + posterior_endpoint_idx : int + Updated posterior endpoint index. +""" + contour = np.roll(contour, -idx, axis=0) + anterior_endpoint_idx = (anterior_endpoint_idx - idx) % contour.shape[0] + posterior_endpoint_idx = (posterior_endpoint_idx - idx) % contour.shape[0] + return contour, anterior_endpoint_idx, posterior_endpoint_idx + + +def find_closest_edge(point, contour): + """Find the index of the edge closest to the given point. + + Parameters + ---------- + point : np.ndarray + 2D point coordinates. + contour : np.ndarray + Array of shape (N, 2) containing contour points. + + Returns + ------- + int + Index of the closest edge. + """ + edges_start = contour[:, :2] # N x 2 + edges_end = np.roll(contour[:, :2], -1, axis=0) # N x 2 + edges_vec = edges_end - edges_start # N x 2 + + # Calculate projection coefficient for all edges at once + # (p-a)·(b-a) / |b-a|² + edge_lengths_sq = np.sum(edges_vec * edges_vec, axis=1) + # Avoid division by zero for degenerate edges + valid_edges = edge_lengths_sq > 1e-10 + t = np.zeros(len(edges_start)) + t[valid_edges] = ( + np.sum((point - edges_start[valid_edges]) * edges_vec[valid_edges], axis=1) + / edge_lengths_sq[valid_edges] + ) + t = np.clip(t, 0, 1) # Clamp to edge endpoints + + # Get closest points on all edges + closest_points = edges_start + t[:, None] * edges_vec + + # Calculate distances to all edges + distances = np.linalg.norm(point - closest_points, axis=1) + + # Return index of closest edge + return np.argmin(distances) + + +@overload +def insert_point_with_thickness( + contour_in_as_space: np.ndarray, + contour_thickness: np.ndarray, + point: np.ndarray, + thickness_value: float, + return_index: Literal[False] = False, +) -> tuple[np.ndarray, np.ndarray]: ... + + +@overload +def insert_point_with_thickness( + contour_in_as_space: np.ndarray, + contour_thickness: np.ndarray, + point: np.ndarray, + thickness_value: float, + return_index: Literal[True], +) -> tuple[np.ndarray, np.ndarray, int] | list[np.ndarray, np.ndarray]: + ... + + +def insert_point_with_thickness( + contour_in_as_space: np.ndarray, + contour_thickness: np.ndarray, + point: np.ndarray, + thickness_value: float, + return_index: bool = False +) -> tuple[np.ndarray, np.ndarray, int] | tuple[np.ndarray, np.ndarray]: + """Inserts a point and its thickness value into the contour. + + Parameters + ---------- + contour_in_as_space : np.ndarray + Array of coordinates of the contour in AS space, shape (N, 2). + contour_thickness : np.ndarray + Array of thickness values of the contour, shape (N,). + point : np.ndarray + 2D point to insert, shape (2,). + thickness_value : float + Thickness value corresponding to the point. + return_index : bool, default=False + If True, return the index where point was inserted, by default False. + + Returns + ------- + contour_in_as_space : np.ndarray + Updated contour of shape (N+1, 2). + contour_thickness : np.ndarray + Updated thickness values of shape (N+1,). + insertion_index : int + The index, where the point was inserted (only if return_index is True). + """ + # Find closest edge for the point + edge_idx = find_closest_edge(point, contour_in_as_space) + + # Insert point between edge endpoints + contour_in_as_space = np.insert(contour_in_as_space, edge_idx + 1, point, axis=0) + contour_thickness = np.insert(contour_thickness, edge_idx + 1, thickness_value) + + if return_index: + return contour_in_as_space, contour_thickness, edge_idx + 1 + else: + return contour_in_as_space, contour_thickness + + +def make_mesh_from_contour( + contour_2d: np.ndarray, + max_volume: float = 0.5, + min_angle: float = 25, + verbose: bool = False +) -> tuple[Points2dType[np.float_], np.ndarray[tuple[int, Literal[3]], np.dtype[np.int_]]]: + """Create a triangular mesh from a 2D contour. + + Parameters + ---------- + contour_2d : np.ndarray + Array of shape (N, 2) containing contour points. + max_volume : float, optional + Maximum triangle area, by default 0.5. + min_angle : float, optional + Minimum angle in triangles (degrees), by default 25. + verbose : bool, optional + Whether to print mesh generation info, by default False. + + Returns + ------- + mesh_points : np.ndarray + Array of shape (M, 2) containing mesh vertices. + mesh_trias : np.ndarray + Array of shape (K, 3) containing triangle indices. + + Notes + ----- + Uses meshpy.triangle to create a constrained Delaunay triangulation + of the contour. The contour must not have duplicate points. + """ + + facets = np.vstack((np.arange(len(contour_2d)), ((np.arange(len(contour_2d)) + 1) % len(contour_2d)))).T + + # use meshpy to create mesh + info = triangle.MeshInfo() + info.set_points(contour_2d) # needs to be (N, D) + info.set_facets(facets) + # NOTE: crashes if contour has duplicate points !! + mesh = triangle.build(info, max_volume=max_volume, min_angle=min_angle, verbose=verbose) + + mesh_points: Points2dType[np.float_] = np.array(mesh.points, dtype=float) + mesh_trias: np.ndarray[tuple[int, Literal[3]], np.dtype[np.int_]] = np.array(mesh.elements, dtype=int) + + return mesh_points, mesh_trias + + +def cc_thickness( + contour_2d: Points2dType, + endpoint_idx: tuple[int, int], + n_points: int = 100, +) -> tuple[float, float, float, Points2dType , list[Points2dType], ContourThickness, tuple[int, int]]: + """Calculate corpus callosum thickness using Laplace equation. + + Parameters + ---------- + contour_2d : np.ndarray + Array of shape (N, 2) containing contour points. + endpoint_idx : pair of ints + Indices of anterior and posterior endpoints in contour. + n_points : int, default=100 + Number of points for thickness measurement. + + Returns + ------- + midline_length : float + Total length of the midline. + thickness : float + Mean thickness across all level paths. + curvature : float + Mean absolute curvature in degrees. + midline_equidistant : np.ndarray + Equidistant points along the midline in same space as contour2d of shape (N, 2). + levelpaths : list[np.ndarray] + Level paths for thickness measurement in same space as contour2d, each of shape (N, 2). + contour_with_thickness : np.ndarray + Contour coordinates with thickness information in same space as contour2d of shape (N+2, 3). + endpoint_indices : pair of ints + Pair of updated indices of anterior and posterior endpoint. + + Notes + ----- + Uses the Laplace equation to compute thickness by: + 1. Creating a triangular mesh from the contour + 2. Setting boundary conditions (0 at endpoints, ±1 on sides) + 3. Solving Laplace equation to get level sets + 4. Computing thickness along level sets + """ + anterior_endpoint_idx, posterior_endpoint_idx = endpoint_idx + + # standardize contour indices to start at anterior_endpoint_idx, to get consistent levelpath directions + contour_2d, anterior_endpoint_idx, posterior_endpoint_idx = set_contour_zero_idx( + contour_2d, anterior_endpoint_idx, anterior_endpoint_idx, posterior_endpoint_idx, + ) + + mesh_points_contour_space, mesh_trias = make_mesh_from_contour(contour_2d) + + # make points 3D by appending z=0, asz space therefore is the contour space (usually AS space) with a zero z-dim + mesh_points_asz = np.append(mesh_points_contour_space, np.zeros((mesh_points_contour_space.shape[0], 1)), axis=1) + + # compute poisson + with suppress_stdout(): + tria_asz = TriaMesh(mesh_points_asz, mesh_trias) + # extract boundary curve + bdr = np.array(tria_asz.boundary_loops()[0]) + + # find index of endpoints in bdr list + iidx1 = np.where(bdr == anterior_endpoint_idx)[0][0] + iidx2 = np.where(bdr == posterior_endpoint_idx)[0][0] + + # create boundary condition (0 at endpoints, -1 on one side, 1 on the other): + if iidx1 > iidx2: + iidx1, iidx2 = iidx2, iidx1 + dcond = np.ones(bdr.shape) + dcond[iidx1] = 0 + dcond[iidx2] = 0 + dcond[iidx1 + 1 : iidx2] = -1 + + # Extract path + fem = Solver(tria_asz) + vfunc = fem.poisson(0, (bdr, dcond)) + midline_length: float + midline_equidistant_asz, midline_length = tria_asz.level_path(vfunc, level=0., n_points=n_points + 2) + midline_equidistant_contour_space: np.ndarray = midline_equidistant_asz[:, :2] + + gf = compute_rotated_f(tria_asz, vfunc) + + # interpolate midline to get levels to evaluate + level_of_rotated_laplace_contour_space = scipy.interpolate.griddata( + tria_asz.v[:, 0:2], gf, midline_equidistant_asz[:, 0:2], method="cubic", + ) + + # get levels to evaluate + levelpaths_contour_space: list[Points2dType] = [] + levelpath_lengths = [] + levelpath_tria_idx = [] + + # now, on the rotated laplace function, sample equally spaced (on midline: level_of_rotated_laplace) levelpaths + contour_thickness = np.full(contour_2d.shape[0], np.nan) + for current_level in level_of_rotated_laplace_contour_space[1:-1]: + # levelpath starts at index zero + levelpath_asz, lvlpath_length, tria_idx = tria_asz.level_path(gf, current_level, get_tria_idx=True) + + levelpaths_contour_space.append(levelpath_asz[:, :2]) + levelpath_lengths.append(lvlpath_length) + levelpath_tria_idx.append(tria_idx) + + levelpath_start = levelpath_asz[0, :2] + levelpath_end = levelpath_asz[-1, :2] + + contour_2d, contour_thickness, inserted_idx_start = insert_point_with_thickness( + contour_2d, contour_thickness, levelpath_start, lvlpath_length, return_index=True, + ) + # keep track of start index + if inserted_idx_start <= anterior_endpoint_idx: + anterior_endpoint_idx += 1 + if inserted_idx_start <= posterior_endpoint_idx: + posterior_endpoint_idx += 1 + + contour_2d, contour_thickness, inserted_idx_end = insert_point_with_thickness( + contour_2d, contour_thickness, levelpath_end, lvlpath_length, return_index=True, + ) + # keep track of end index + if inserted_idx_end <= anterior_endpoint_idx: + anterior_endpoint_idx += 1 + if inserted_idx_end <= posterior_endpoint_idx: + posterior_endpoint_idx += 1 + + contour_2d_with_thickness = np.concatenate([contour_2d, contour_thickness[:, None]], axis=1) + + # get curvature of path3d_resampled + mean_curvature: float = compute_mean_curvature(midline_equidistant_contour_space) + mean_thickness: float = np.mean(levelpath_lengths).item() + endpoints: tuple[int, int] = (anterior_endpoint_idx, posterior_endpoint_idx) + + return ( + midline_length, + mean_thickness, + mean_curvature, + midline_equidistant_contour_space, + levelpaths_contour_space, + contour_2d_with_thickness, + endpoints, + ) diff --git a/CorpusCallosum/transforms/__init__.py b/CorpusCallosum/transforms/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/CorpusCallosum/transforms/localization.py b/CorpusCallosum/transforms/localization.py new file mode 100644 index 000000000..e129fc820 --- /dev/null +++ b/CorpusCallosum/transforms/localization.py @@ -0,0 +1,153 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from logging import getLogger + +import numpy as np +import torch +from monai.transforms import MapTransform, RandomizableTransform + + +class CropAroundACPCFixedSize(RandomizableTransform, MapTransform): + """Crop image around AC-PC points with fixed size. + + A transform that crops the input image around the midpoint between + AC and PC points with a fixed size window and optional random translation. + + Parameters + ---------- + keys : list[str] + Keys of the data dictionary to apply the transform to. + fixed_size : tuple[int, int] + Fixed size of the crop window (width, height). + allow_missing_keys : bool, optional + Whether to allow missing keys in the data dictionary, by default False. + random_translate : int, default=0 + Maximum random translation in voxels. + + Raises + ------ + ValueError + If the crop boundaries extend outside the image dimensions. + + Notes + ----- + The transform expects the following keys in the data dictionary: + + - AC_center : np.ndarray + Coordinates of anterior commissure + - PC_center : np.ndarray + Coordinates of posterior commissure + - image : np.ndarray + Input image to crop + + """ + + def __init__( + self, + keys: list[str], + fixed_size: tuple[int, int], + allow_missing_keys: bool = False, + random_translate: int = 0, + ) -> None: + MapTransform.__init__(self, keys, allow_missing_keys) + RandomizableTransform.__init__(self) + self.random_translate = random_translate + self.fixed_size = fixed_size + + def __call__(self, data: dict) -> dict: + """Apply the 2D crop transform to the data. + + Parameters + ---------- + data : dict + Dictionary containing the data to transform AND keys AC_center and PC_center, each of shape (B, 2). + + Returns + ------- + dict + Transformed data dictionary with cropped images and updated coordinates. + Also includes crop boundary information: + - crop_left : list[int] + - crop_right : list[int] + - crop_top : list[int] + - crop_bottom : list[int] + + Raises + ------ + ValueError + If crop boundaries extend outside the image dimensions + """ + d = dict(data) + + expected_keys = {"PC_center", "AC_center"} | set(self.keys) if not self.allow_missing_keys else {} + + if expected_keys & set(d.keys()) != expected_keys: + raise ValueError(f"The following keys are missing in the data dictionary: {expected_keys - set(d.keys())}!") + + if any(d[k].ndim != 2 or d[k].shape[1] != 2 for k in ["PC_center", "AC_center"]): + raise ValueError("Shape of AC_center or PC_center incorrect, must be (B, 2)!") + + if any(d[k].ndim != 4 for k in self.keys if k in d.keys()): + raise ValueError(f"At least one key of {self.keys} does not have a 4-dimensional tensor.") + + # calculate center point between AC and PC + center_point = ((d['AC_center'] + d['PC_center']) / 2).astype(int) + + # Calculate voxel padding based on mm padding + voxel_padding = np.asarray(self.fixed_size) // 2 + + existing_keys = set(self.keys) & set(d.keys()) + if len(existing_keys) == 0: + getLogger(__name__).warning(f"None of the keys in {self.keys} are present in the data dictionary!") + return d + + first_key = tuple(existing_keys)[0] + + # Calculate crop boundaries with padding and random translation + crops = center_point - voxel_padding + + # Add random translation if specified + if self.random_translate > 0: + crops += np.random.randint( + -self.random_translate, + self.random_translate + 1, + size=(d[first_key].shape[0], 2), + ) + + # Ensure crop boundaries are within image + img_shape = np.asarray(d[first_key].shape[2:]) # Get spatial dimensions + if any(np.any(img_shape != d[k].shape[2:]) for k in self.keys if k in d.keys()): + raise ValueError(f"At least one key of {self.keys} does not have the expected shape.") + + patch_size_with_batch_dim = np.asarray(self.fixed_size)[None] + crops = np.maximum(0, np.minimum(img_shape, crops + patch_size_with_batch_dim) - patch_size_with_batch_dim) + d["crop_left"], d["crop_top"] = crops.T.tolist() + d["crop_right"], d["crop_bottom"] = (crops_end := crops + patch_size_with_batch_dim).T.tolist() + + # raise error if crop boundaries are out of image + if np.any(crops < 0) or np.any(crops_end > np.asarray([d[first_key].shape[2:]])): + raise ValueError("Crop boundaries are out of image") + + # Apply crop to image + for key in self.keys: + if key not in d.keys() and self.allow_missing_keys: + continue + arr = [v[:, cl:cr, ct:cb] for v, cl, ct, cr, cb in zip(d[key], *crops.T, *crops_end.T, strict=True)] + d[key] = torch.stack(arr, dim=0) if torch.is_tensor(arr[0]) else np.stack(arr, axis=0) + + # Update point coordinates relative to cropped image + d["PC_center"] = d["PC_center"] - crops + d["AC_center"] = d["AC_center"] - crops + return d diff --git a/CorpusCallosum/transforms/segmentation.py b/CorpusCallosum/transforms/segmentation.py new file mode 100644 index 000000000..26943e255 --- /dev/null +++ b/CorpusCallosum/transforms/segmentation.py @@ -0,0 +1,180 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import Literal + +import numpy as np +from monai.transforms import MapTransform, RandomizableTransform + + +class CropAroundACPC(RandomizableTransform, MapTransform): + """Crop image around anterior and posterior commissure points. + + A transform that crops the input image around the AC and PC points with + optional padding and random translation. + + Parameters + ---------- + keys : list[str] + Keys of the data dictionary to apply the transform to. + allow_missing_keys : bool, default=False + Whether to allow missing keys in the data dictionary. + padding_mm : float, default=10.0 + Padding around AC-PC region in millimeters. + random_translate : float, default=0 + Maximum random translation in voxels, off by default. + + Notes + ----- + The transform expects the following keys in the data dictionary: + + - AC_center : np.ndarray + Coordinates of anterior commissure + - PC_center : np.ndarray + Coordinates of posterior commissure + - res : float + Voxel resolution in mm + + """ + + def __init__(self, keys: list[str], allow_missing_keys: bool = False, + padding_mm: float = 10, random_translate: float = 0) -> None: + MapTransform.__init__(self, keys, allow_missing_keys) + RandomizableTransform.__init__(self, prob=1, do_transform=True) + self.padding_mm = padding_mm + self.random_translate = random_translate + + def __call__(self, data: dict) -> dict: + """Apply the transform to the data. + + Parameters + ---------- + data : dict + Dictionary containing the data to transform. + + Returns + ------- + dict + Transformed data dictionary. + """ + d = dict(data) + + if "AC_center_original" not in d: + d["AC_center_original"] = d["AC_center"].copy() + if "PC_center_original" not in d: + d["PC_center_original"] = d["PC_center"].copy() + + if self.random_translate > 0: + random_translate = np.random.randint(-self.random_translate, self.random_translate, size=2) + else: + random_translate = (0,0,0) + + pc_center = d["PC_center"] + ac_center = d["AC_center"] + + ac_pc = np.stack([ac_center, pc_center], axis=0) + + ac_pc_bottomleft = np.min(ac_pc, axis=0).astype(int) + ac_pc_topright = np.max(ac_pc, axis=0).astype(int) + + voxel_padding: np.ndarray[tuple[Literal[2]], np.dtype[np.int_]] = np.round( + self.padding_mm / d["res"]).astype(int) + + crop_left = ac_pc_bottomleft[1] - int(voxel_padding[0] * 1.5) + random_translate[0] + crop_right = ac_pc_topright[1] + voxel_padding[0] // 2 + random_translate[0] + crop_top = ac_pc_bottomleft[2] - voxel_padding[1] + random_translate[1] + crop_bottom = ac_pc_topright[2] + voxel_padding[1] + random_translate[1] + + keys_to_process = [key for key in self.keys if key in d.keys()] + + if not self.allow_missing_keys and set(keys_to_process) != set(self.keys): + raise ValueError("Some keys are missing in the data dictionary.") + + if len(keys_to_process) == 0: + logging.getLogger(__name__).warning("No keys to process.") + return d + + first_key = keys_to_process[0] + d["to_pad"] = crop_left, d[first_key].shape[2] - crop_right, crop_top, d[first_key].shape[3] - crop_bottom + + for key in keys_to_process: + d[key] = d[key][:, :, crop_left:crop_right, crop_top:crop_bottom] + + return d + + +class CropAroundACPCtrack(CropAroundACPC): + """Crop image around AC-PC points and update their coordinates. + + Extends CropAroundACPC to also adjust the AC and PC center coordinates + after cropping to maintain their correct positions in the cropped image. + + Parameters + ---------- + keys : list[str] + Keys of the data dictionary to apply the transform to. + allow_missing_keys : bool, optional + Whether to allow missing keys in the data dictionary, by default False. + padding_mm : float, optional + Padding around AC-PC region in millimeters, by default 10. + random_translate : float, optional + Maximum random translation in voxels, by default 0. + + Notes + ----- + The transform expects the following keys in the data dictionary: + + - AC_center : np.ndarray + Coordinates of anterior commissure + - PC_center : np.ndarray + Coordinates of posterior commissure + - AC_center_original : np.ndarray + Original coordinates of anterior commissure + - PC_center_original : np.ndarray + Original coordinates of posterior commissure + + """ + + def __call__(self, data: dict) -> dict: + """Apply the transform to the data. + + Parameters + ---------- + data : dict + Dictionary containing the data to transform. + + Returns + ------- + dict + Transformed data dictionary with updated AC and PC coordinates. + """ + + + # First call parent class to get cropped image + d = super().__call__(data) + + # Get the crop coordinates that were used + pad_left, pad_right, pad_top, pad_bottom = d["to_pad"] + + # Adjust AC and PC center coordinates based on cropping + if "AC_center" in d: + d["AC_center"][1] = d["AC_center_original"][1] - pad_left.item() + d["AC_center"][2] = d["AC_center_original"][2] - pad_top.item() + + if "PC_center" in d: + d["PC_center"][1] = d["PC_center_original"][1] - pad_left.item() + d["PC_center"][2] = d["PC_center_original"][2] - pad_top.item() + + return d + diff --git a/CorpusCallosum/utils/__init__.py b/CorpusCallosum/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/CorpusCallosum/utils/checkpoint.py b/CorpusCallosum/utils/checkpoint.py new file mode 100644 index 000000000..355542bd1 --- /dev/null +++ b/CorpusCallosum/utils/checkpoint.py @@ -0,0 +1,17 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from FastSurferCNN.utils.parser_defaults import FASTSURFER_ROOT + +YAML_DEFAULT = FASTSURFER_ROOT / "CorpusCallosum/config/checkpoint_paths.yaml" diff --git a/CorpusCallosum/utils/mapping_helpers.py b/CorpusCallosum/utils/mapping_helpers.py new file mode 100644 index 000000000..1d76cd557 --- /dev/null +++ b/CorpusCallosum/utils/mapping_helpers.py @@ -0,0 +1,397 @@ +from pathlib import Path +from typing import overload + +import nibabel as nib +import numpy as np +import SimpleITK as sitk +from scipy.ndimage import affine_transform + +from CorpusCallosum.data.constants import CC_LABEL, FORNIX_LABEL +from CorpusCallosum.utils.types import Polygon3dType +from FastSurferCNN.utils import ( + AffineMatrix4x4, + Image2d, + Image3d, + Image4d, + RotationMatrix3x3, + Shape3d, + Vector2d, + Vector3d, + logging, + nibabelImage, +) +from FastSurferCNN.utils.parallel import thread_executor + +logger = logging.get_logger(__name__) + + +def make_midplane_affine( + orig_affine: AffineMatrix4x4, + slices_to_analyze: int = 1, + offset: int = 4, + ) -> AffineMatrix4x4: + """Create affine transformation matrix for midplane slices. + + Parameters + ---------- + orig_affine : AffineMatrix4x4 + Original image affine matrix (4x4). + slices_to_analyze : int, default=1 + Number of slices to analyze around midplane. + offset : int, default=4 + Additional offset in x direction. + + Returns + ------- + AffineMatrix4x4 + 4x4 affine matrix for midplane slices. + """ + # Create translation matrix to center on midplane + orig_to_seg = np.eye(4) + orig_to_seg[0, 3] = -256 // 2 + slices_to_analyze // 2 + offset + + # Combine with original affine + seg_affine = orig_affine @ np.linalg.inv(orig_to_seg) + + return seg_affine + + +def correct_nodding(ac_pt: Vector2d, pc_pt: Vector2d) -> RotationMatrix3x3: + """Calculate rotation matrix to correct head nodding. + + Calculates rotation matrix to align AC-PC line with posterior direction, + correcting for head nodding based on AC-PC line orientation. + + Parameters + ---------- + ac_pt : Vector2d + 2D coordinates of the anterior commissure point. + pc_pt : Vector2d + 2D coordinates of the posterior commissure point. + + Returns + ------- + RotationMatrix + 3x3 rotation matrix to align AC-PC line with posterior direction. + """ + ac_pc_vec = pc_pt - ac_pt + ac_pc_dist = np.linalg.norm(ac_pc_vec) + + posterior_vector = np.array([0, -ac_pc_dist]) + + # get angle between ac_pc_vec and posterior_vector + dot_product = np.dot(ac_pc_vec, posterior_vector) + norms_product = np.linalg.norm(ac_pc_vec) * np.linalg.norm(posterior_vector) + theta = np.arccos(dot_product / norms_product) + + # Determine the sign of the angle using cross product + cross_product = np.cross(ac_pc_vec, posterior_vector) + if cross_product < 0: + theta = -theta + + # create rotation matrix for theta + rotation_matrix: RotationMatrix3x3 = np.array( + [ + [np.cos(theta), -np.sin(theta), 0], + [np.sin(theta), np.cos(theta), 0], + [0, 0, 1], + ] + ) + + return rotation_matrix + + +@overload +def apply_transform_to_pt(pts: Vector3d, T: AffineMatrix4x4, inv: bool = False) -> Vector3d: ... + +@overload +def apply_transform_to_pt(pts: Polygon3dType, T: AffineMatrix4x4, inv: bool = False) -> Polygon3dType: ... + +def apply_transform_to_pt(pts: Vector3d | Polygon3dType, T: AffineMatrix4x4, inv: bool = False): + """Apply homogeneous transformation matrix to points. + + Parameters + ---------- + pts : np.ndarray + Point coordinates to transform, shape (3,) or (3, N). + T : np.ndarray + 4x4 homogeneous transformation matrix. + inv : bool, default=False + If True, applies inverse of transformation. + + Returns + ------- + np.ndarray + Transformed point coordinates, shape (3,) or (3, N). + """ + if inv: + T = np.linalg.inv(T) + + if pts.ndim == 1: + return nib.affines.apply_affine(T, pts) + else: + return nib.affines.apply_affine(T, pts.T).T + + +def calc_mapping_to_standard_space( + orig: nibabelImage, + ac_coords_3d: Vector3d, + pc_coords_3d: Vector3d, + orig_fsaverage_vox2vox: AffineMatrix4x4, +) -> tuple[AffineMatrix4x4, Vector3d, Vector3d, Vector3d, Vector3d]: + """Get transformations to map image to standard space. + + Parameters + ---------- + orig : nibabelImage + Original image. + ac_coords_3d : np.ndarray + AC coordinates in 3D space. + pc_coords_3d : np.ndarray + PC coordinates in 3D space. + orig_fsaverage_vox2vox : AffineMatrix4x4 + Transformation matrix from original to fsaverage space. + + Returns + ------- + standardized_to_orig_vox2vox : AffineMatrix4x4 + The vox2vox transformation matrix from standard space to original space. + ac_coords_standardized : Vector3d + AC coordinates in standard space. + pc_coords_standardized : Vector3d + PC coordinates in standard space. + ac_coords_orig : Vector3d + AC coordinates in original space. + pc_coords_orig : Vector3d + PC coordinates in original space. + """ + image_center = np.array(orig.shape) / 2 + + # correct nodding + nod_correct_2d = correct_nodding(ac_coords_3d[1:3], pc_coords_3d[1:3]) + + # convert 2D nodding correction to 3D transformation matrix + nod_correct_3d: AffineMatrix4x4 = np.eye(4, dtype=float) + nod_correct_3d[1:3, 1:3] = nod_correct_2d[:2, :2] # Copy rotation part to y,z axes + # Copy translation part to y,z axes (usually no translation) + nod_correct_3d[1:3, 3] = nod_correct_2d[:2, 2] + + ac_coords_after_nodding: Vector3d = apply_transform_to_pt( + ac_coords_3d, nod_correct_3d, inv=False, + ) + pc_coords_after_nodding: Vector3d = apply_transform_to_pt( + pc_coords_3d, nod_correct_3d, inv=False, + ) + + ac_to_center_translation: AffineMatrix4x4 = np.eye(4, dtype=float) + ac_to_center_translation[:3, 3] = image_center - ac_coords_after_nodding + + # correct nodding + ac_coords_standardized: Vector3d = apply_transform_to_pt( + ac_coords_after_nodding, ac_to_center_translation, inv=False, + ) + pc_coords_standardized: Vector3d = apply_transform_to_pt( + pc_coords_after_nodding, ac_to_center_translation, inv=False, + ) + + standardized_to_orig_vox2vox: AffineMatrix4x4 = ( + np.linalg.inv(orig_fsaverage_vox2vox) + @ np.linalg.inv(nod_correct_3d) + @ np.linalg.inv(ac_to_center_translation) + ) + + # calculate ac & pc in space of mri input image + ac_coords_orig: Vector3d = apply_transform_to_pt( + ac_coords_standardized, standardized_to_orig_vox2vox, inv=False, + ) + pc_coords_orig: Vector3d = apply_transform_to_pt( + pc_coords_standardized, standardized_to_orig_vox2vox, inv=False, + ) + return standardized_to_orig_vox2vox, ac_coords_standardized, pc_coords_standardized, ac_coords_orig, pc_coords_orig + + +def apply_transform_to_volume( + orig_image: nibabelImage, + interp_vox2vox: AffineMatrix4x4, + save_vox2ras: AffineMatrix4x4 | None = None, + header: nib.freesurfer.mghformat.MGHHeader | None = None, + output_path: str | Path | None = None, + output_size: np.ndarray | None = None, + order: int = 1 +) -> Image3d[np.float_]: + """Apply transformation to a volume and save the result. + + Parameters + ---------- + orig_image : nibabelImage + Input volume. + interp_vox2vox : np.ndarray + Transformation matrix to apply to the data, this is from input-to-output space. + save_vox2ras : AffineMatrix4x4, optional + The vox2ras matrix of the output image, only relevant if output_path is given. + header : nibabelHeader, optional + Header for the output image, only relevant if output_path is given, if None will default to orig_image header. + output_path : str or Path, optional + If output_path is provided, saves the result under this path using the dtype of header (or orig_image). + output_size : np.ndarray, optional + Size of output volume, uses input size by default `None`. + order : int, default=1 + Order of interpolation. + + Returns + ------- + np.ndarray + Transformed volume data of shape `output_size` and type float. + + Notes + ----- + Uses `scipy.ndimage.affine_transform` for the transformation, and inverts vox2vox internally as required by + `affine_transform`. + """ + if output_size is None: + output_size = np.array(orig_image.shape) + if header is None: + header = orig_image.header + if save_vox2ras is None: + save_vox2ras = orig_image.affine @ interp_vox2vox + # transform / resample the volume with vox2vox, note this needs to be the inverse of input2output vox2vox! + # affine_transform definition is: input_coord = matrix @ output_coord + offset ( == MATRIX_HOM @ output_coord_hom) + # --> output_coord = inv(matrix) @ (input_coord - offset) ( == inv(MATRIX_HOM) @ input_coord_hom) + resampled = affine_transform( + orig_image.get_fdata(), + np.linalg.inv(interp_vox2vox), + output_shape=output_size, + order=order, + ) + if output_path is not None: + logger.info(f"Saving transformed volume to {output_path}") + resampled_typecast = resampled.astype((header if header else orig_image).get_data_dtype()) + nib.save(nib.MGHImage(resampled_typecast, save_vox2ras, header), output_path) + return resampled + + +def make_affine(simpleITKImage: sitk.Image) -> AffineMatrix4x4: + """Create an affine transformation matrix from a SimpleITK image. + + Parameters + ---------- + simpleITKImage : sitk.Image + Input SimpleITK image. + + Returns + ------- + np.ndarray + 4x4 affine transformation matrix in RAS coordinates. + + Notes + ----- + The function: + 1. Gets affine transform in LPS coordinates + 2. Converts to RAS coordinates to match nibabel + 3. Returns the final 4x4 transformation matrix + """ + # get affine transform in LPS + c = [simpleITKImage.TransformContinuousIndexToPhysicalPoint(p) for p in np.eye(4)[:, :3]] + c = np.array(c) + affine = np.append(np.append(c[0:3] - c[3:], c[3:], axis=0), np.eye(4)[3], axis=1) + affine = np.transpose(affine) + # convert to RAS to match nibabel + affine = np.matmul(np.diag([-1.0, -1.0, 1.0, 1.0]), affine) + return affine + + +@overload +def map_softlabels_to_orig( + cc_fn_softlabels: Image4d, + orig: nibabelImage, + orig2slab_vox2vox: AffineMatrix4x4, + cc_subseg_midslice: None = None, + orig2midslice_vox2vox: None = None, + orig_space_segmentation_path: str | Path | None = None, +) -> np.ndarray[Shape3d, np.dtype[np.int_]]: ... + + +@overload +def map_softlabels_to_orig( + cc_fn_softlabels: Image4d, + orig: nibabelImage, + orig2slab_vox2vox: AffineMatrix4x4, + cc_subseg_midslice: Image2d, + orig2midslice_vox2vox: AffineMatrix4x4, + orig_space_segmentation_path: str | Path | None = None, +) -> np.ndarray[Shape3d, np.dtype[np.int_]]: ... + + +def map_softlabels_to_orig( + cc_fn_softlabels: Image4d, + orig: nibabelImage, + orig2slab_vox2vox: AffineMatrix4x4, + cc_subseg_midslice: Image2d | None = None, + orig2midslice_vox2vox: AffineMatrix4x4 | None = None, + orig_space_segmentation_path: str | Path | None = None, +) -> np.ndarray[Shape3d, np.dtype[np.int_]]: + """Map soft labels back to original image space and apply post-processing. + + Parameters + ---------- + cc_fn_softlabels : np.ndarray + Soft label predictions of shape (H, W, D, C=3). + orig : nibabelImage + Original image. + orig2slab_vox2vox : AffineMatrix4x4 + The vox2vox transformation matrix from orig to the slab. + cc_subseg_midslice : np.ndarray, optional + Mask for subdividing regions of shape (H, D) (only paired with orig2midslice_vox2vox). + orig2midslice_vox2vox : AffineMatrix4x4, optional + The vox2vox transformation matrix from orig to the midslice (only paired with cc_subseg_midslice). + orig_space_segmentation_path : str or Path, optional + Path to save segmentation in original space. + + Returns + ------- + np.ndarray + Final segmentation in original image space. + + Notes + ----- + The function: + 1. Transforms background, cc, and fornix label channels separately. + 2. Transform CC subsegmentation from midslice to orig and paint into segmentation if `cc_subseg_midslice` is passed. + 4. Saves result to `orig_space_segmentation_path` if passed. + """ + # map softlabels to original image + def _map_softlabel_to_orig(data: Image3d, fill: int) -> Image3d: + # # Note: affine_transforms requires the inverse of the intended direction -> orig2slab + return affine_transform(data, orig2slab_vox2vox, output_shape=orig.shape, order=1, cval=fill) + + if cc_subseg_midslice is not None and orig2midslice_vox2vox is not None: + # map subdivision mask to orig space, this will also expand the labels into left-right direction + cc_subseg_orig_space_fut = thread_executor().submit( + affine_transform, + cc_subseg_midslice[None], + orig2midslice_vox2vox, # Note: affine_transforms requires the inverse of the intended direction + output_shape=orig.shape, + order=0, + mode="nearest", + ) + else: + cc_subseg_orig_space_fut = None + + _softlabels = np.moveaxis(cc_fn_softlabels, -1, 0) + softlabels_iter = thread_executor().map(_map_softlabel_to_orig, _softlabels, [1., 0., 0.]) + softlabels_orig_space = np.stack(list(softlabels_iter), axis=-1) + # map to freesurfer labels + seg_lut = np.asarray([0, CC_LABEL, FORNIX_LABEL]) + seg_orig_space = seg_lut[np.argmax(softlabels_orig_space, axis=-1)] + + if cc_subseg_orig_space_fut is not None: + # replace CC_LABEL by subsegmentation labels + seg_orig_space = np.where(seg_orig_space == CC_LABEL, cc_subseg_orig_space_fut.result(), seg_orig_space) + + if orig_space_segmentation_path is not None: + logger.info(f"Saving segmentation in original space to {orig_space_segmentation_path}") + nib.save( + nib.MGHImage(seg_orig_space, orig.affine, orig.header), + orig_space_segmentation_path, + ) + return seg_orig_space diff --git a/CorpusCallosum/utils/types.py b/CorpusCallosum/utils/types.py new file mode 100644 index 000000000..78ad3b81a --- /dev/null +++ b/CorpusCallosum/utils/types.py @@ -0,0 +1,74 @@ +from typing import Literal, TypedDict + +from numpy import dtype, float_, ndarray + +from FastSurferCNN.utils import ScalarType + +__all__ = [ + "CCMeasuresDict", + "ContourList", + "ContourThickness", + "Points2dType", + "Points3dType", + "Polygon2dType", + "Polygon3dType", + "SliceSelection", + "SubdivisionMethod", +] + +Polygon2dType = ndarray[tuple[Literal[2], int], dtype[ScalarType]] +Polygon3dType = ndarray[tuple[Literal[3], int], dtype[ScalarType]] +Points2dType = ndarray[tuple[int, Literal[2]], dtype[ScalarType]] +Points3dType = ndarray[tuple[int, Literal[3]], dtype[ScalarType]] +ContourList = list[type[Polygon2dType]] +ContourThickness = ndarray[tuple[Literal[3], int], dtype[ScalarType]] +SliceSelection = Literal["middle", "all"] | int +SubdivisionMethod = Literal["shape", "vertical", "angular", "eigenvector"] + +class CCMeasuresDict(TypedDict): + """TypedDict for corpus callosum measures. + + Attributes + ---------- + cc_index : float + Corpus callosum shape index. + circularity : float + Shape circularity measure. + areas : np.ndarray + Areas of subdivided regions. + midline_length : float + Length along the midline. + thickness : float + Array of thickness measurements. + curvature : float + Array of curvature measurements. + thickness_profile : np.ndarray of type float + Thickness measurements along the contour. + total_area : float + Total area of the CC. + total_perimeter : float + Total perimeter length. + split_contours : list of np.ndarray + Subdivided contour segments in AS-slice coordinates. + midline_equidistant : np.ndarray + Equidistant points along midline in AS-slice coordinates. + levelpaths : list of np.ndarray + Paths for thickness measurements in AS-slice coordinates. + slice_index : int + Index of the processed slice. + """ + cc_index: float + circularity: float + areas: ndarray + midline_length: float + thickness: float + curvature: float + thickness_profile: ndarray[tuple[int], dtype[float_]] + total_area: float + total_perimeter: float + split_contours: ContourList + midline_equidistant: ndarray + curvature_subsegments: ndarray + curvature_body: float + levelpaths: list[ndarray] + slice_index: int diff --git a/CorpusCallosum/utils/visualization.py b/CorpusCallosum/utils/visualization.py new file mode 100644 index 000000000..f40db0ecb --- /dev/null +++ b/CorpusCallosum/utils/visualization.py @@ -0,0 +1,239 @@ +# Copyright 2025 AI in Medical Imaging, German Center for Neurodegenerative Diseases(DZNE), Bonn +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path + +import matplotlib +import matplotlib.pyplot as plt +import nibabel as nib +import numpy as np + +from CorpusCallosum.utils.types import ContourList, Polygon2dType +from FastSurferCNN.utils import AffineMatrix4x4, Image3d, Vector2d + + +def plot_standardized_space( + ax_row: list[plt.Axes], + vol: np.ndarray, + ac_coords: np.ndarray, + pc_coords: np.ndarray +) -> None: + """Plot standardized space visualization across three views. + + Parameters + ---------- + ax_row : list[plt.Axes] + Row of axes to plot on (should be length 3). + vol : np.ndarray + Volume data to visualize. + ac_coords : np.ndarray + AC coordinates in standardized space. + pc_coords : np.ndarray + PC coordinates in standardized space. + + Notes + ----- + Creates three views: + - Axial (top view) + - Sagittal (side view) + - Coronal (front view) + """ + ax_row[0].set_title("Standardized") + + for i, (a, b, _) in ((2, 1, "Axial"), (2, 0, "Sagittal"), (1, 0, "Coronal")): + ax_row[i].scatter(ac_coords[a], ac_coords[b], color="red", marker="x") + ax_row[i].scatter(pc_coords[a], pc_coords[b], color="blue", marker="x") + ax_row[i].imshow(vol[(slice(None),) * i + (vol.shape[i] // 2,)], cmap="gray") + + +def visualize_coordinate_spaces( + orig: "nib.Nifti1Image", + upright: np.ndarray, + standardized: np.ndarray, + ac_coords_orig: np.ndarray, + pc_coords_orig: np.ndarray, + ac_coords_3d: np.ndarray, + pc_coords_3d: np.ndarray, + ac_coords_standardized: np.ndarray, + pc_coords_standardized: np.ndarray, + output_plot_path: str | Path, +) -> None: + """Visualize the AC and PC coordinates in different coordinate spaces. + + Creates a figure showing the anterior and posterior commissure points + in three different coordinate spaces for testing/debugging. + + Parameters + ---------- + orig : nibabel.Nifti1Image + Original image volume. + upright : np.ndarray + Volume in fsaverage space. + standardized : np.ndarray + Volume in standardized space. + ac_coords_orig : np.ndarray + AC coordinates in original space. + pc_coords_orig : np.ndarray + PC coordinates in original space. + ac_coords_3d : np.ndarray + AC coordinates in fsaverage space. + pc_coords_3d : np.ndarray + PC coordinates in fsaverage space. + ac_coords_standardized : np.ndarray + AC coordinates in standardized space. + pc_coords_standardized : np.ndarray + PC coordinates in standardized space. + output_plot_path : str or Path + Directory to save visualization. + + Notes + ----- + Saves a visualization of the anterior (red) and posterior (blue) commisure in three different view: + 1. the orig image (orig), + 2. fs-average standardized image space, and + 3. standardized image space + as a single image named 'ac_pc_spaces.png' in `output_dir`. + """ + fig, ax = plt.subplots(3, 4) + ax = ax.T + + # Original space - using plot_standardized_space + plot_standardized_space(ax[0], np.asarray(orig.dataobj), ac_coords_orig, pc_coords_orig) + ax[0, 0].set_title("Orig") + + # Fsaverage space + plot_standardized_space(ax[1], upright, ac_coords_3d, pc_coords_3d) + ax[1, 0].set_title("Fsaverage") + + # Standardized space + plot_standardized_space(ax[2], standardized, ac_coords_standardized, pc_coords_standardized) + ax[2, 0].set_title("Standardized") + # Format all subplots + for a in ax.flatten(): + a.set_aspect("equal", adjustable="box") + a.axis("off") + + plt.savefig(output_plot_path, dpi=300, bbox_inches="tight") + plt.show() + plt.close() + + +def plot_contours( + slice_or_slab: Image3d, + split_contours: ContourList | None = None, + midline_equidistant: Polygon2dType | None = None, + levelpaths: list[Polygon2dType] | None = None, + output_path: str | Path | list[Path | str] | None = None, + ac_coords_vox: Vector2d | None = None, + pc_coords_vox: Vector2d | None = None, + vox2ras: AffineMatrix4x4 | None = None, + title: str = "", +) -> None: + """Creates a figure of the contours (shape) and the subdivisions of the corpus callosum. + + Parameters + ---------- + slice_or_slab : np.ndarray + Intensities of the current slice, midslice or midslab (will plot middle slice). + split_contours : list[np.ndarray], optional + List of contour arrays for each subdivision (ignore contours on None) in upright AS coordinates each with shape + (N, 2). + midline_equidistant : np.ndarray, optional + Midline points at equidistant spacing (ignore midline on None) in upright AS coordinates with shape (2, N). + levelpaths : list[np.ndarray], optional + List of level paths for visualization (ignore level paths on None) in upright AS coordinates each with shape + (2, N). + output_path : str or Path or list of Paths, optional + Path to save the plot (show and do not save on None). + ac_coords_vox : np.ndarray, optional + AC coordinates for visualization (ignore AC on None) in LIA voxel coordinates. + pc_coords_vox : np.ndarray, optional + PC coordinates for visualization (ignore PC on None) in LIA voxel coordinates. + vox2ras : AffineMatrix4x4, optional + Slice vox2ras transformation matrix. + title : str, default="" + Title for the plot. + + Notes + ----- + Creates a visualization of the corpus callosum contours and their subdivisions. + If output_path is provided, saves the plot to that location. + """ + from functools import partial + + from nibabel.affines import apply_affine + + if vox2ras is None and None in (split_contours, midline_equidistant, levelpaths): + raise ValueError("vox_size must be provided if split_contours, midline_equidistant, or levelpaths are given.") + + if output_path is not None: + matplotlib.use('Agg') # Use non-GUI backend + + # convert vox_size from LIA to AS + ras2vox = partial(apply_affine, np.linalg.inv(vox2ras)[1:, 1:]) + + # scale contour data by vox_size to convert from AS to AS-aligned voxel space + _split_contours = [] if split_contours is None else [ras2vox(sp.T).T for sp in split_contours] + _midline_equi = np.zeros((0, 2)) if midline_equidistant is None else ras2vox(midline_equidistant) + _levelpaths = [] if levelpaths is None else [ras2vox(lp) for lp in levelpaths] + + has_first_plot = not (len(_split_contours) == 0 and ac_coords_vox is None and pc_coords_vox is None) + num_plots = 1 + int(has_first_plot) + + fig, ax = plt.subplots(1, num_plots, sharex=True, sharey=True, figsize=(15, 10)) + + # NOTE: For all plots imshow shows y inverted + current_plot = 0 + + # This visualization uses voxel coordinates in fsaverage space... + if has_first_plot: + ax[current_plot].imshow(slice_or_slab[slice_or_slab.shape[0] // 2], cmap="gray") + ax[current_plot].set_title(title) + if _split_contours: + for this_contour in _split_contours: + ax[current_plot].fill(this_contour[1, :], this_contour[0, :], color="steelblue", alpha=0.25) + kwargs = {"color": "mediumblue", "linewidth": 0.7, "linestyle": "solid"} + ax[current_plot].plot(this_contour[1, :], this_contour[0, :], **kwargs) + if ac_coords_vox is not None: + ax[current_plot].scatter(ac_coords_vox[1], ac_coords_vox[0], color="red", marker="x") + if pc_coords_vox is not None: + ax[current_plot].scatter(pc_coords_vox[1], pc_coords_vox[0], color="blue", marker="x") + current_plot += int(has_first_plot) + + ax[current_plot].imshow(slice_or_slab[slice_or_slab.shape[0] // 2], cmap="gray") + for this_path in _levelpaths: + ax[current_plot].plot(this_path[:, 1], this_path[:, 0], color="brown", linewidth=0.8) + ax[current_plot].set_title("Midline & Levelpaths") + if _midline_equi.shape[0] > 0: + ax[current_plot].plot(_midline_equi[:, 1], _midline_equi[:, 0], color="red") + if _split_contours: + reference_contour = _split_contours[0] + ax[current_plot].plot(reference_contour[1, :], reference_contour[0, :], color="red", linewidth=0.5) + + padding = 30 + for a in ax.flatten(): + a.set_aspect("equal", adjustable="box") + a.axis("off") + if _split_contours: + reference_contour = _split_contours[0] + # get bounding box of contours + a.set_xlim(reference_contour[1, :].min() - padding, reference_contour[1, :].max() + padding) + a.set_ylim((reference_contour[0, :]).max() + padding, (reference_contour[0, :]).min() - padding) + + if output_path is None: + return plt.show() + for _output_path in (output_path if isinstance(output_path, (list, tuple)) else [output_path]): + Path(_output_path).parent.mkdir(parents=True, exist_ok=True) + fig.savefig(_output_path, dpi=300, bbox_inches="tight") + return None diff --git a/FastSurferCNN/download_checkpoints.py b/FastSurferCNN/download_checkpoints.py index 23f65febd..35492d79e 100644 --- a/FastSurferCNN/download_checkpoints.py +++ b/FastSurferCNN/download_checkpoints.py @@ -17,6 +17,7 @@ from CerebNet.utils.checkpoint import ( YAML_DEFAULT as CEREBNET_YAML, ) +from CorpusCallosum.utils.checkpoint import YAML_DEFAULT as CC_YAML from FastSurferCNN.utils import PLANES from FastSurferCNN.utils.checkpoint import ( YAML_DEFAULT as VINN_YAML, @@ -26,9 +27,7 @@ get_checkpoints, load_checkpoint_config_defaults, ) -from HypVINN.utils.checkpoint import ( - YAML_DEFAULT as HYPVINN_YAML, -) +from HypVINN.utils.checkpoint import YAML_DEFAULT as HYPVINN_YAML class ConfigCache: @@ -40,9 +39,12 @@ def cerebnet_url(self): def hypvinn_url(self): return load_checkpoint_config_defaults("url", filename=HYPVINN_YAML) + + def cc_url(self): + return load_checkpoint_config_defaults("url", filename=CC_YAML) def all_urls(self): - return self.vinn_url() + self.cerebnet_url() + self.hypvinn_url() + return self.vinn_url() + self.cerebnet_url() + self.hypvinn_url() + self.cc_url() defaults = ConfigCache() @@ -72,6 +74,12 @@ def make_parser(): action="store_true", help="Check and download CerebNet default checkpoints", ) + parser.add_argument( + "--cc", + default=False, + action="store_true", + help="Check and download Corpus Callosum default checkpoints", + ) parser.add_argument( "--hypvinn", @@ -99,14 +107,15 @@ def make_parser(): def main( - vinn: bool, - cerebnet: bool, - hypvinn: bool, - all: bool, - files: list[str], + vinn: bool = False, + cerebnet: bool = False, + hypvinn: bool = False, + cc: bool = False, + all: bool = False, + files: list[str] = (), url: str | None = None, ) -> int | str: - if not vinn and not files and not cerebnet and not hypvinn and not all: + if not vinn and not files and not cerebnet and not hypvinn and not cc and not all: return ("Specify either files to download or --vinn, --cerebnet, " "--hypvinn, or --all, see help -h.") @@ -141,6 +150,16 @@ def main( *(hypvinn_config[plane] for plane in PLANES), urls=defaults.hypvinn_url() if url is None else [url], ) + # Corpus Callosum checkpoints + if cc or all: + cc_config = load_checkpoint_config_defaults( + "checkpoint", + filename=CC_YAML, + ) + get_checkpoints( + *(cc_config[model] for model in cc_config.keys()), + urls=defaults.cc_url() if url is None else [url], + ) for fname in files: check_and_download_ckpts( fname, diff --git a/README.md b/README.md index eda2acd25..8ff862d8f 100644 --- a/README.md +++ b/README.md @@ -24,16 +24,20 @@ Modules (all run by default): - the core, outputs anatomical segmentation and cortical parcellation and statistics of 95 classes, mimics FreeSurfer’s DKTatlas. - requires a T1w image ([notes on input images](#requirements-to-input-images)), supports high-res (up to 0.7mm, experimental beyond that). - performs bias-field correction and calculates volume statistics corrected for partial volume effects (skipped if `--no_biasfield` is passed). -2. `cereb:` [CerebNet](CerebNet/README.md) for cerebellum sub-segmentation (deactivate with `--no_cereb`) +2. `cc`: [CorpusCallosum](CorpusCallosum/README.md) for corpus callosum segmentation and shape analysis (deactivate with `--no_cc`) + - requires `asegdkt_segfile` (segmentation) and conformed mri (orig.mgz), outputs CC segmentation, thickness, and shape metrics. + - standardizes brain orientation based on AC/PC landmarks (orient_volume.lta). +3. `cereb:` [CerebNet](CerebNet/README.md) for cerebellum sub-segmentation (deactivate with `--no_cereb`) - requires `asegdkt_segfile`, outputs cerebellar sub-segmentation with detailed WM/GM delineation. - requires a T1w image ([notes on input images](#requirements-to-input-images)), which will be resampled to 1mm isotropic images (no native high-res support). - calculates volume statistics corrected for partial volume effects (skipped if `--no_biasfield` is passed). -3. `hypothal`: [HypVINN](HypVINN/README.md) for hypothalamus subsegmentation (deactivate with `--no_hypothal`) +4. `hypothal`: [HypVINN](HypVINN/README.md) for hypothalamus subsegmentation (deactivate with `--no_hypothal`) - outputs a hypothalamic subsegmentation including 3rd ventricle, c. mammilare, fornix and optic tracts. - a T1w image is highly recommended ([notes on input images](#requirements-to-input-images)), supports high-res (up to 0.7mm, but experimental beyond that). - allows the additional passing of a T2w image with `--t2 `, which will be registered to the T1w image (see `--reg_mode` option). - calculates volume statistics corrected for partial volume effects based on the T1w image (skipped if `--no_bias_field` is passed). + ### Surface reconstruction - approximately 60-90 minutes, `--surf_only` runs only [the surface part](recon_surf/README.md). - supports high-resolution images (up to 0.7mm, experimental beyond that). @@ -125,6 +129,8 @@ All the examples can be found here: [FASTSURFER_EXAMPLES](doc/overview/EXAMPLES. Modules output can be found here: [FastSurfer_Output_Files](doc/overview/OUTPUT_FILES.md) - [Segmentation module](doc/overview/OUTPUT_FILES.md#segmentation-module) - [Cerebnet module](doc/overview/OUTPUT_FILES.md#cerebnet-module) +- [HypVINN module](doc/overview/OUTPUT_FILES.md#hypvinn-module) +- [Corpus Callosum module](doc/overview/OUTPUT_FILES.md#corpus-callosum-module) - [Surface module](doc/overview/OUTPUT_FILES.md#surface-module) @@ -146,7 +152,7 @@ The default device is the GPU. The view-aggregation device can be switched to CP ## Expert usage Individual modules and the surface pipeline can be run independently of the full pipeline script documented in this documentation. -This is documented in READMEs in subfolders, for example: [whole brain segmentation only with FastSurferVINN](FastSurferCNN/README.md), [cerebellum sub-segmentation](CerebNet/README.md), [hypothalamic sub-segmentation](HypVINN/README.md) and [surface pipeline only (recon-surf)](recon_surf/README.md). +This is documented in READMEs in subfolders, for example: [whole brain segmentation only with FastSurferVINN](FastSurferCNN/README.md), [cerebellum sub-segmentation](CerebNet/README.md), [hypothalamic sub-segmentation](HypVINN/README.md), [corpus callosum analysis](CorpusCallosum/README.md) and [surface pipeline only (recon-surf)](recon_surf/README.md). Specifically, the segmentation modules feature options for optimized parallelization of batch processing. diff --git a/doc/api/CorpusCallosum.data.rst b/doc/api/CorpusCallosum.data.rst new file mode 100644 index 000000000..a89128e20 --- /dev/null +++ b/doc/api/CorpusCallosum.data.rst @@ -0,0 +1,11 @@ +CorpusCallosum.data +=================== + +.. currentmodule:: CorpusCallosum.data + +.. autosummary:: + :toctree: generated/ + + constants + fsaverage_cc_template + read_write diff --git a/doc/api/CorpusCallosum.localization.rst b/doc/api/CorpusCallosum.localization.rst new file mode 100644 index 000000000..9c6c3b400 --- /dev/null +++ b/doc/api/CorpusCallosum.localization.rst @@ -0,0 +1,9 @@ +CorpusCallosum.localization +============================= + +.. currentmodule:: CorpusCallosum.localization + +.. autosummary:: + :toctree: generated/ + + inference diff --git a/doc/api/CorpusCallosum.rst b/doc/api/CorpusCallosum.rst new file mode 100644 index 000000000..7d9152e5b --- /dev/null +++ b/doc/api/CorpusCallosum.rst @@ -0,0 +1,11 @@ +CorpusCallosum +============== + +.. currentmodule:: CorpusCallosum + +.. autosummary:: + :toctree: generated/ + + fastsurfer_cc + cc_visualization + paint_cc_into_pred diff --git a/doc/api/CorpusCallosum.segmentation.rst b/doc/api/CorpusCallosum.segmentation.rst new file mode 100644 index 000000000..0269688bf --- /dev/null +++ b/doc/api/CorpusCallosum.segmentation.rst @@ -0,0 +1,10 @@ +CorpusCallosum.segmentation +============================ + +.. currentmodule:: CorpusCallosum.segmentation + +.. autosummary:: + :toctree: generated/ + + inference + segmentation_postprocessing diff --git a/doc/api/CorpusCallosum.shape.rst b/doc/api/CorpusCallosum.shape.rst new file mode 100644 index 000000000..f4c059e3f --- /dev/null +++ b/doc/api/CorpusCallosum.shape.rst @@ -0,0 +1,15 @@ +CorpusCallosum.shape +==================== + +.. currentmodule:: CorpusCallosum.shape + +.. autosummary:: + :toctree: generated/ + + postprocessing + mesh + metrics + thickness + subsegment_contour + endpoint_heuristic + contour diff --git a/doc/api/CorpusCallosum.transforms.rst b/doc/api/CorpusCallosum.transforms.rst new file mode 100644 index 000000000..14756a92e --- /dev/null +++ b/doc/api/CorpusCallosum.transforms.rst @@ -0,0 +1,10 @@ +CorpusCallosum.transforms +=========================== + +.. currentmodule:: CorpusCallosum.transforms + +.. autosummary:: + :toctree: generated/ + + localization + segmentation diff --git a/doc/api/CorpusCallosum.utils.rst b/doc/api/CorpusCallosum.utils.rst new file mode 100644 index 000000000..33fe5e045 --- /dev/null +++ b/doc/api/CorpusCallosum.utils.rst @@ -0,0 +1,12 @@ +CorpusCallosum.utils +==================== + +.. currentmodule:: CorpusCallosum.utils + +.. autosummary:: + :toctree: generated/ + + checkpoint + mapping_helpers + types + visualization diff --git a/doc/api/index.rst b/doc/api/index.rst index 546cdf4fa..fd606a8ba 100644 --- a/doc/api/index.rst +++ b/doc/api/index.rst @@ -16,6 +16,13 @@ FastSurfer API CerebNet.datasets.rst CerebNet.models.rst CerebNet.utils.rst + CorpusCallosum.rst + CorpusCallosum.data.rst + CorpusCallosum.localization.rst + CorpusCallosum.segmentation.rst + CorpusCallosum.shape.rst + CorpusCallosum.transforms.rst + CorpusCallosum.utils.rst HypVINN.rst HypVINN.dataloader.rst HypVINN.models.rst diff --git a/doc/api/recon_surf.rst b/doc/api/recon_surf.rst index 4e19a65bb..0387d24ed 100644 --- a/doc/api/recon_surf.rst +++ b/doc/api/recon_surf.rst @@ -13,7 +13,6 @@ recon_surf fs_balabels map_surf_label N4_bias_correct - paint_cc_into_pred rewrite_oriented_surface rewrite_mc_surface rotate_sphere diff --git a/doc/conf.py b/doc/conf.py index faa0d8df0..0921c6b23 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -12,8 +12,7 @@ import os from pathlib import Path -# here i added the relative path because sphinx was not able -# to locate FastSurferCNN module directly for autosummary +# relative path so sphinx can locate the different modules directly for autosummary sys.path.append(os.path.dirname(__file__) + "/..") sys.path.append(os.path.dirname(__file__) + "/../recon_surf") sys.path.append(os.path.dirname(__file__) + "/sphinx_ext") @@ -227,7 +226,7 @@ def import_from_path(module_name, file_path): linkcode_resolve = LinkCodeResolver(gh_url, branch) -_re_script_dirs = "fastsurfercnn|cerebnet|recon_surf|hypvinn" +_re_script_dirs = "fastsurfercnn|cerebnet|recon_surf|hypvinn|corpuscallosum" _up = "^/\\.\\./" _end = "(\\.md)?(#.*)?$" diff --git a/doc/overview/FLAGS.md b/doc/overview/FLAGS.md index 3f06d74dd..735136fbc 100644 --- a/doc/overview/FLAGS.md +++ b/doc/overview/FLAGS.md @@ -6,7 +6,7 @@ The `*fastsurfer-flags*` will usually at least include the subject directory (`- ```bash ... --sd /output --sid test_subject --t1 /data/test_subject_t1.nii.gz --3T ``` -Additionally, you can use `--seg_only` or `--surf_only` to only run a part of the pipeline or `--no_biasfield`, `--no_cereb` and `--no_asegdkt` to switch off individual segmentation modules. +Additionally, you can use `--seg_only` or `--surf_only` to only run a part of the pipeline or `--no_biasfield`, `--no_cereb`, `--no_hypothal`, `--no_cc`, and `--no_asegdkt` to switch off individual segmentation modules. Here, we have also added the `--3T` flag, which tells FastSurfer to register against the 3T atlas which is only relevant for the ICV estimation (eTIV). In the following, we give an overview of the most important options. You can view a [full list of options](FLAGS.md#full-list-of-flags) with @@ -30,6 +30,8 @@ In the following, we give an overview of the most important options. You can vie * `--device`: Select device for neural network segmentation (_auto_, _cpu_, _cuda_, _cuda:_, _mps_), where cuda means Nvidia GPU, you can select which one e.g. "cuda:1". Default: "auto", check GPU and then CPU. "mps" is for native MAC installs to use the Apple silicon (M-chip) GPU. * `--asegdkt_segfile`: Name of the segmentation file, which includes the aparc+DKTatlas-aseg segmentations. Requires an ABSOLUTE Path! Default location: \$SUBJECTS_DIR/\$sid/mri/aparc.DKTatlas+aseg.deep.mgz * `--no_cereb`: Switch off the cerebellum sub-segmentation. +* `--no_hypothal`: Skip the hypothalamus segmentation. +* `--no_cc`: Skip the segmentation and analysis of the corpus callosum. * `--cereb_segfile`: Name of the cerebellum segmentation file. If not provided, this intermediate DL-based segmentation will not be stored, but only the merged segmentation will be stored (see --main_segfile ). Requires an ABSOLUTE Path! Default location: \$SUBJECTS_DIR/\$sid/mri/cerebellum.CerebNet.nii.gz * `--no_biasfield`: Deactivate the biasfield correction and calculation of partial volume-corrected statistics in the segmentation modules. * `--native_image` or `--keepgeom`: **Only supported for `--seg_only`**, segment in native image space (keep orientation, image size and voxel size of the input image), this also includes experimental support for anisotropic images (no extreme anisotropy). diff --git a/doc/overview/OUTPUT_FILES.md b/doc/overview/OUTPUT_FILES.md index 87416ed30..8d7ab7b2a 100644 --- a/doc/overview/OUTPUT_FILES.md +++ b/doc/overview/OUTPUT_FILES.md @@ -15,6 +15,31 @@ The segmentation module outputs the files shown in the table below. The two prim | scripts | deep-seg.log | asegdkt | logfile | | stats | aseg+DKT.stats | asegdkt | table of cortical and subcortical segmentation statistics | + +## Corpus Callosum module + +The Corpus Callosum module outputs the files in the table shown below. It creates detailed segmentations and shape analysis of the corpus callosum. + +| directory | filename | module | description | +|:----------------|--------------------------------|--------|--------------------------------------------------------------------------------------------------------------| +| mri | callosum_seg_upright.mgz | cc | corpus callosum segmentation in upright space | +| mri | callosum_seg_aseg_space.mgz | cc | corpus callosum segmentation in conformed image orientation | +| mri | callosum_seg_soft.mgz | cc | corpus callosum soft labels | +| mri | fornix_seg_soft.mgz | cc | fornix soft labels | +| mri | background_seg_soft.mgz | cc | background soft labels | +| mri/transforms | cc_up.lta | cc | transform from original to upright space | +| mri/transforms | orient_volume.lta | cc | transform to standardized space | +| stats | callosum.CC.midslice.json | cc | measurements from the middle sagittal slice (landmarks, area, thickness, etc.) | +| stats | callosum.CC.all_slices.json | cc | comprehensive per-slice analysis (only when using `--slice_selection all`) | +| qc_snapshots | callosum.png | cc | debug visualization of contours and thickness | +| qc_snapshots | callosum_thickness.png | cc | 3D thickness visualization (with `--slice_selection all`) | +| qc_snapshots | corpus_callosum.html | cc | interactive 3D mesh visualization (with `--slice_selection all`) | +| surf | callosum.surf | cc | FreeSurfer surface format (with `--slice_selection all`) | +| surf | callosum.thickness.w | cc | FreeSurfer overlay file containing thickness values (with `--slice_selection all`) | +| surf | callosum_mesh.vtk | cc | VTK format mesh file for 3D visualization (with `--slice_selection all`) | + + + ## Cerebnet module The cerebellum module outputs the files in the table shown below. Unless switched off by the `--no_cereb` argument, this module is automatically run whenever the segmentation module is run. It adds two files, an image with the sub-segmentation of the cerebellum and a text file with summary statistics. @@ -73,4 +98,4 @@ The primary output files are pial, white, and inflated surface files, the thickn | stats | lh.aparc.DKTatlas.mapped.stats, rh.aparc.DKTatlas.mapped.stats | surface | table of cortical parcellation statistics, mapped from ASEGDKT segmentation to the surface | | stats | lh.curv.stats, rh.curv.stats | surface | table of curvature statistics | | stats | wmparc.DKTatlas.mapped.stats | surface | table of white matter segmentation statistics | -| scripts | recon-all.log | surface | logfile | \ No newline at end of file +| scripts | recon-all.log | surface | logfile | diff --git a/doc/overview/index.rst b/doc/overview/index.rst index e41f65932..2fca45ff3 100644 --- a/doc/overview/index.rst +++ b/doc/overview/index.rst @@ -10,6 +10,7 @@ User Guide EXAMPLES.md FLAGS.md OUTPUT_FILES.md + modules/index docker SINGULARITY.md MACOS.md diff --git a/doc/overview/modules/CC.md b/doc/overview/modules/CC.md new file mode 100644 index 000000000..b08d5b884 --- /dev/null +++ b/doc/overview/modules/CC.md @@ -0,0 +1,127 @@ +# Corpus Callosum Pipeline + +A deep learning-based pipeline for automated segmentation, analysis, and shape analysis of the corpus callosum in brain MRI scans. +Also segments the fornix, localizes the anterior and posterior commissure (AC and PC) and standardizes the orientation of the brain. + +## Overview + +This pipeline combines localization and segmentation deep learning models to: +1. Detect AC (Anterior Commissure) and PC (Posterior Commissure) points +2. Extract and align midplane slices +3. Segment the corpus callosum +4. Perform advanced morphometry for corpus callosum, including subdivision, thickness analysis, and various shape metrics +5. Generate visualizations and measurements + +## Analysis Modes + +The pipeline supports different analysis modes that determine the type of template data generated. + +### 3D Analysis + +When running the main pipeline with `--slice_selection all` and `--save_template_dir`, a complete 3D template is generated: + +```bash +# Generate 3D template data +python3 fastsurfer_cc.py --sd /data/subjects --sid sub001 \ + --slice_selection all \ + --save_template_dir /data/templates/sub001 +``` + +This creates: +- `contour_.txt`: Multi-slice contour data for 3D reconstruction +- `thickness_values_.txt`: Thickness measurements across all slices +- `thickness_measurement_points_.txt`: 3D vertex indices for thickness measurements + +**Benefits:** +- Enables volumetric thickness analysis +- Supports advanced 3D visualizations with proper surface topology +- Creates FreeSurfer-compatible overlay files for integration with other tools + +For visualization instructions and outputs, see the [cc_visualization.py documentation](../../scripts/cc_visualization.rst). + +### 2D Analysis + +When using `--slice_selection middle` or a specific slice number with `--save_template_dir`: + +```bash +# Generate 2D template data (middle slice) +python3 fastsurfer_cc.py --sd /data/subjects --sid sub001 \ + --slice_selection middle \ + --save_template_dir /data/templates/sub001 +``` + +**Benefits:** +- Faster processing for single-slice analysis +- 2D visualization is most suitable for displaying downstream statistics +- Compatibility with classical corpus callosum studies + +For 2D visualization instructions and outputs, see the [cc_visualization.py documentation](../../scripts/cc_visualization.rst). + +### Choosing Analysis Mode + +**Use 3D Analysis (`--slice_selection all`) when:** +- You need complete volumetric analysis +- Surface-based visualization is required +- Integration with FreeSurfer workflows is needed +- Comprehensive thickness mapping across the entire corpus callosum is desired + +**Use 2D Analysis (`--slice_selection middle` or specific slice) when:** +- Traditional single-slice morphometry is sufficient +- Faster processing is preferred +- Focus is on mid-sagittal cross-sectional measurements +- Compatibility with classical corpus callosum studies is needed + +**Note:** The default behavior is `--slice_selection all` for comprehensive 3D analysis. Use `--slice_selection middle` to process only the middle slice for faster, traditional 2D analysis. + +## JSON Output Structure + +The pipeline generates two main JSON files with detailed measurements and analysis results: + +### `stats/callosum.CC.midslice.json` (Middle Slice Analysis) + +This file contains measurements from the middle sagittal slice and includes: + +**Shape Measurements (single values):** +- `total_area`: Total corpus callosum area (mm²) +- `total_perimeter`: Total perimeter length (mm) +- `circularity`: Shape circularity measure (4π × area / perimeter²) +- `cc_index`: Corpus callosum shape index (length/width ratio) +- `midline_length`: Length along the corpus callosum midline (mm) +- `curvature`: Average curve of the midline (degrees), measured by angle between it's sub-segements + +**Subdivisions** +- `areas`: Areas of CC using an improved Hofer-Frahm sub-division method (mm²). This gives more consistent sub-segemnts while preserving the original ratios. + +**Thickness Analysis:** +- `thickness`: Average corpus callosum thickness (mm) +- `thickness_profile`: Thickness profile (mm) of the corpus callosum slice (100 thickness values by default, listed from anterior to posterior CC ends) + + +**Volume Measurements (when multiple slices processed):** +- `cc_5mm_volume`: Total CC volume within 5mm slab using voxel counting (mm³) +- `cc_5mm_volume_pv_corrected`: Volume with partial volume correction using CC contours (mm³) + +**Anatomical Landmarks:** +- `ac_center`: Anterior commissure coordinates in original image space +- `pc_center`: Posterior commissure coordinates in original image space +- `ac_center_oriented_volume`: AC coordinates in standardized space (orient_volume.lta) +- `pc_center_oriented_volume`: PC coordinates in standardized space (orient_volume.lta) +- `ac_center_upright`: AC coordinates in upright space (cc_up.lta) +- `pc_center_upright`: PC coordinates in upright space (cc_up.lta) + +### `stats/callosum.CC.all_slices.json` (Multi-Slice Analysis) + +This file contains comprehensive per-slice analysis when using `--slice_selection all`: + +**Global Parameters:** +- `slices_in_segmentation`: Total number of slices in the segmentation volume +- `voxel_size`: Voxel dimensions [x, y, z] in mm +- `subdivision_method`: Method used for anatomical subdivision +- `num_thickness_points`: Number of points used for thickness estimation +- `subdivision_ratios`: Subdivision fractions used for regional analysis +- `contour_smoothing`: Gaussian sigma used for contour smoothing +- `slice_selection`: Slice selection mode used + +**Per-Slice Data (`slices` array):** + +Each slice entry contains the shape measurements, thickness analysis and sub-divisions as described above. diff --git a/doc/overview/modules/index.rst b/doc/overview/modules/index.rst new file mode 100644 index 000000000..17b1cc454 --- /dev/null +++ b/doc/overview/modules/index.rst @@ -0,0 +1,9 @@ +Modules +======= + +FastSurfer includes several specialized deep learning modules that can be run independently or as part of the main pipeline. These modules provide detailed sub-segmentations and analyses for specific brain regions. + +.. toctree:: + :maxdepth: 2 + + CC diff --git a/doc/scripts/advanced.rst b/doc/scripts/advanced.rst index 82551a7ca..d18d755dd 100644 --- a/doc/scripts/advanced.rst +++ b/doc/scripts/advanced.rst @@ -7,6 +7,8 @@ Advanced scripts fastsurfercnn cerebnet hypvinn + fastsurfer_cc + cc_visualization recon_surf segstats long_compat_segmentHA diff --git a/doc/scripts/cc_visualization.rst b/doc/scripts/cc_visualization.rst new file mode 100644 index 000000000..e9e4c136f --- /dev/null +++ b/doc/scripts/cc_visualization.rst @@ -0,0 +1,53 @@ +CorpusCallosum: cc_visualization.py +=================================== + +.. argparse:: + :module: CorpusCallosum.cc_visualization + :func: make_parser + :prog: cc_visualization.py + +Usage Examples +-------------- + +3D Visualization +~~~~~~~~~~~~~~~~ + +To visualize a 3D template generated by ``fastsurfer_cc.py`` (using ``--slice_selection all --save_template_dir ...``), +point the script to the exported template directory: + +.. code-block:: bash + + python3 cc_visualization.py \ + --template_dir /data/templates/sub001/cc_template \ + --output_dir /data/visualizations/sub001 + +2D Visualization +~~~~~~~~~~~~~~~~ + +To visualize a 2D template (using ``--slice_selection middle --save_template_dir ...``): + +.. code-block:: bash + + python3 cc_visualization.py \ + --template_dir /data/templates/sub001/cc_template \ + --output_dir /data/visualizations/sub001 \ + --twoD + +.. note:: + + The ``--template_dir`` is the required way to load the templates + produced by ``fastsurfer_cc.py``. + +Outputs +------- + +3D Mode Outputs (default): + - ``cc_mesh.vtk``: VTK format mesh file for 3D visualization + - ``cc_mesh.fssurf``: FreeSurfer surface format + - ``cc_mesh_overlay.curv``: FreeSurfer overlay file with thickness values + - ``cc_mesh.html``: Interactive 3D mesh visualization + - ``cc_mesh_snap.png``: Snapshot image of the 3D mesh + - ``midslice_2d.png``: 2D visualization of the middle slice + +2D Mode Outputs (when ``--twoD`` is specified): + - ``cc_thickness_2d.png``: 2D contour visualization with thickness colormap diff --git a/doc/scripts/fastsurfer_cc.rst b/doc/scripts/fastsurfer_cc.rst new file mode 100644 index 000000000..d2f5fcbcd --- /dev/null +++ b/doc/scripts/fastsurfer_cc.rst @@ -0,0 +1,21 @@ +CorpusCallosum: fastsurfer_cc.py +================================ + +.. note:: + We recommend to run FastSurfer-CC with the standard `run_fastsurfer.sh` interfaces ! + + +.. + [Note] To tell sphinx where in the documentation CorpusCallosum/README.md can be linked to, it needs to be included somewhere + +.. include:: ../../CorpusCallosum/README.md + :parser: fix_links.parser + :start-line: 1 + +.. argparse:: + :module: CorpusCallosum.fastsurfer_cc + :func: make_parser + :prog: fastsurfer_cc.py + +.. include:: ../overview/modules/CC.md + :parser: fix_links.parser diff --git a/env/fastsurfer.yml b/env/fastsurfer.yml index e3b991bf4..d8bc1323e 100644 --- a/env/fastsurfer.yml +++ b/env/fastsurfer.yml @@ -5,8 +5,9 @@ channels: dependencies: - h5py==3.12.1 -- lapy==1.2.0 +- lapy==1.5.0 - matplotlib==3.10.1 +- monai==1.4.0 - nibabel==5.3.2 - numpy==1.26.4 - pandas==2.2.3 @@ -29,3 +30,6 @@ dependencies: - torch==2.6.0+cu126 - torchio==0.20.4 - torchvision==0.21.0+cu126 + - meshpy>=2025.1.1 + - pyrr>=0.10.3 + - whippersnappy>=1.3.1 diff --git a/pyproject.toml b/pyproject.toml index 727065f8e..52b1576cf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,7 @@ classifiers = [ ] dependencies = [ 'h5py>=3.7', - 'lapy>=1.1.0', + "lapy>=1.5.0", 'matplotlib>=3.7.1', 'nibabel>=5.1.0', 'numpy>=1.25,<2', @@ -50,6 +50,9 @@ dependencies = [ 'torchvision>=0.15.2', 'tqdm>=4.65', 'yacs>=0.1.8', + 'monai>=1.4.0', + 'meshpy>=2025.1.1', + 'pyrr>=0.10.3', 'pip>=25.0', ] diff --git a/recon_surf/align_points.py b/recon_surf/align_points.py index e549466d4..c69446ee4 100755 --- a/recon_surf/align_points.py +++ b/recon_surf/align_points.py @@ -127,8 +127,7 @@ def find_rotation(p_mov: npt.NDArray, p_dst: npt.NDArray) -> np.ndarray: return R - -def find_rigid(p_mov: npt.NDArray, p_dst: npt.NDArray) -> np.ndarray: +def find_rigid(p_mov: npt.NDArray, p_dst: npt.NDArray, verbose: bool = False) -> npt.NDArray[float]: """ Find rigid transformation matrix between two point sets. @@ -138,10 +137,12 @@ def find_rigid(p_mov: npt.NDArray, p_dst: npt.NDArray) -> np.ndarray: Source points. p_dst : npt.NDArray Destination points. + verbose : bool, optional + Whether to print debug information, by default False. Returns ------- - T + np.ndarray Homogeneous transformation matrix. """ if p_mov.shape != p_dst.shape: @@ -159,16 +160,17 @@ def find_rigid(p_mov: npt.NDArray, p_dst: npt.NDArray) -> np.ndarray: t = centroid_dst.T - np.dot(R, centroid_mov.T) # homogeneous transformation m = p_mov.shape[1] - T = np.identity(m + 1) - T[:m, :m] = R - T[:m, m] = t + rigid_transform = np.identity(m + 1, dtype=float) + rigid_transform[:m, :m] = R + rigid_transform[:m, m] = t # compute disteances - dd = p_mov - p_dst - print(f"Initial avg SSD: {np.sum(dd * dd) / p_mov.shape[0]}") - dd = (np.transpose(R @ np.transpose(p_mov)) + t) - p_dst - print(f"Final avg SSD: {np.sum(dd * dd) / p_mov.shape[0]}") + if verbose: + dd = p_mov - p_dst + print(f"Initial avg SSD: {np.sum(dd * dd) / p_mov.shape[0]}") + dd = (np.transpose(R @ np.transpose(p_mov)) + t) - p_dst + print(f"Final avg SSD: {np.sum(dd * dd) / p_mov.shape[0]}") # return T, R, t - return T + return rigid_transform def find_affine(p_mov: npt.NDArray, p_dst: npt.NDArray) -> np.ndarray: """ diff --git a/recon_surf/paint_cc_into_pred.py b/recon_surf/paint_cc_into_pred.py deleted file mode 100644 index ec649a869..000000000 --- a/recon_surf/paint_cc_into_pred.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright 2019 Image Analysis Lab, German Center for Neurodegenerative Diseases (DZNE), Bonn -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# IMPORTS - -import argparse -import sys - -import nibabel as nib -import numpy as np -from numpy import typing as npt - -HELPTEXT = """ -Script to add corpus callosum segmentation (CC, FreeSurfer IDs 251-255) to -deep-learning prediction (e.g. aparc.DKTatlas+aseg.deep.mgz). - - -USAGE: -paint_cc_into_pred -in_cc -in_pred -out - - -Dependencies: - Python 3.8+ - - Nibabel to read and write FreeSurfer data - http://nipy.org/nibabel/ - -Original Author: Leonie Henschel -Date: Jul-10-2020 - -""" - - -def argument_parse(): - """ - Create a command line interface and return command line options. - - Returns - ------- - options : argparse.Namespace - Namespace object holding options. - """ - parser = argparse.ArgumentParser(usage=HELPTEXT) - parser.add_argument( - "--input_cc", - "-in_cc", - dest="input_cc", - help="path to input segmentation with Corpus Callosum (IDs 251-255 in FreeSurfer space)", - ) - parser.add_argument( - "--input_pred", - "-in_pred", - dest="input_pred", - help="path to input segmentation Corpus Callosum should be added to.", - ) - parser.add_argument( - "--output", - "-out", - dest="output", - help="path to output (input segmentation + added CC)", - ) - - args = parser.parse_args() - - if args.input_cc is None or args.input_pred is None or args.output is None: - sys.exit("ERROR: Please specify input and output segmentations") - - return args - - -def paint_in_cc(pred: npt.ArrayLike, aseg_cc: npt.ArrayLike) -> npt.ArrayLike: - """ - Paint corpus callosum segmentation into aseg+dkt segmentation map. - - Note, that this function modifies the original array and does not create a copy. - - Parameters - ---------- - asegdkt : npt.ArrayLike - Deep-learning segmentation map. - aseg_cc : npt.ArrayLike - Aseg segmentation with CC. - - Returns - ------- - asegdkt - Segmentation map with added CC. - """ - cc_mask = (aseg_cc >= 251) & (aseg_cc <= 255) - pred[cc_mask] = aseg_cc[cc_mask] - return pred - - -if __name__ == "__main__": - # Command Line options are error checking done here - options = argument_parse() - - print(f"Reading inputs: {options.input_cc} {options.input_pred}...") - aseg_image = np.asanyarray(nib.load(options.input_cc).dataobj) - prediction = nib.load(options.input_pred) - pred_with_cc = paint_in_cc(np.asanyarray(prediction.dataobj), aseg_image) - - print(f"Writing segmentation with corpus callosum to: {options.output}") - pred_with_cc_fin = nib.MGHImage(pred_with_cc, prediction.affine, prediction.header) - pred_with_cc_fin.to_filename(options.output) - - sys.exit(0) - - -# TODO: Rename the file (paint_cc_into_asegdkt or similar) and functions. diff --git a/recon_surf/recon-surf.sh b/recon_surf/recon-surf.sh index 648064710..5727c8505 100755 --- a/recon_surf/recon-surf.sh +++ b/recon_surf/recon-surf.sh @@ -619,21 +619,52 @@ fi # ============================= CC SEGMENTATION ============================================ -{ - echo " " - echo "============ Creating and adding CC Segmentation ============" - echo " " -} | tee -a "$LF" -# create aseg.auto including corpus callosum segmentation and 46 sec, requires norm.mgz -# Note: if original input segmentation already contains CC, this will exit with ERROR -# in the future maybe check and skip this step (and next) -cmd="mri_cc -aseg $aseg_nocc -o aseg.auto.mgz -lta $mdir/transforms/cc_up.lta $subject" -RunIt "$cmd" "$LF" -# add CC into aparc.DKTatlas+aseg.deep (not sure if this is really needed) -cmd="$python ${binpath}paint_cc_into_pred.py -in_cc $mdir/aseg.auto.mgz -in_pred $asegdkt_segfile -out $mdir/aparc.DKTatlas+aseg.deep.withCC.mgz" -RunIt "$cmd" "$LF" - - +# here, we are only generating the "necessary" files for the pipeline to recon-surf pipeline to +# complete, people should use the seg pipeline to get extended results. +callosum_seg="callosum_seg_aseg_space.mgz" +callosum_seg_manedit="$(add_file_suffix "$callosum_seg" "manedit")" +aseg_auto="aseg.auto.mgz" +CorpusCallosumDir="$FASTSURFER_HOME/CorpusCallosum" +updated_cc_seg=0 +if [[ ! -e "$mdir/$aseg_auto" ]] || [[ ! -e "$mdir/$callosum_seg" ]] || [[ "$edits" == 1 ]] +then + { + echo " " + echo "============ Creating and adding CC Segmentation ============" + echo " " + } | tee -a "$LF" +fi +# here, in edits mode we also check, if the corpus callosum should be updated based on an updated aseg.nocc +if [[ ! -e "$mdir/$callosum_seg" ]] || \ + { [[ "$edits" == 1 ]] && [[ "$(date -r "$mdir/$aseg_nocc" "+%s")" -gt "$(date -r "$mdir/$callosum_seg" "+%s")" ]] ; } +then + { + echo "Segmenting the corpus callosum, so mri/$aseg_nocc exists. If you are interested in detailed" + echo " and extended analysis and statistics of the Corpus Callosum, use the corpus callosum pipeline" + echo " of the segmentation pipeline (in run_fastsurfer.sh, i.e. run without --no_cc)." + } + updated_cc_seg=1 + # create aseg.auto including corpus callosum segmentation and 46 sec, requires norm.mgz + # Note: if original input segmentation already contains CC, this will exit with ERROR + # in the future maybe check and skip this step (and next) + cmda=($python "$CorpusCallosumDir/fastsurfer_cc.py" --sd "$SUBJECTS_DIR" --sid "$subject" + "--aseg_name" "$mdir/$aseg_nocc" "--segmentation_in_orig" "$mdir/$callosum_seg" + --threads "$threads" + # qc_snapshots are only defined by the seg_only pipeline + # limit the processing things to do here + --slice_selection "middle" --cc_measures "none" --cc_mid_measures "none" --surf "none" + --thickness_overlay "none") + run_it "$LF" "${cmda[@]}" +fi +# do not move below statement up, fastsurfer_cc.py uses the $callosum_seg variable +if [[ "$edits" == 1 ]] && [[ -e "$mdir/$callosum_seg_manedit" ]] ; then callosum_seg="$callosum_seg_manedit" ; fi +cmd_paint_cc_into_pred=($python "$CorpusCallosumDir/paint_cc_into_pred.py" -in_cc "$mdir/$callosum_seg" -in_pred) +if [[ ! -e "$mdir/$aseg_auto" ]] || [[ "$updated_cc_seg" == 1 ]] +then + # add CC into aseg.auto.mgz as mri_cc did before. Not sure where this is used. + cmda=("${cmd_paint_cc_into_pred[@]}" "$mdir/$aseg_nocc" "-out" "$mdir/$aseg_auto") + run_it "$LF" "${cmda[@]}" +fi # ============================= FILLED ===================================================== { diff --git a/requirements.mac.txt b/requirements.mac.txt index 95af69a75..fc713c6af 100644 --- a/requirements.mac.txt +++ b/requirements.mac.txt @@ -1,5 +1,5 @@ h5py>=3.7 -lapy>=1.0.1 +lapy>=1.5.0 matplotlib>=3.7.1 nibabel>=5.1.0 numpy>=1.25,<2 @@ -16,4 +16,8 @@ torchio>=0.18.83 torchvision>=0.15.2 tqdm>=4.65 yacs>=0.1.8 - +monai>=1.4.0 +meshpy>=2025.1.1 +pyrr>=0.10.3 +whippersnappy>=1.3.1 +pip>=25.0 \ No newline at end of file diff --git a/run_fastsurfer.sh b/run_fastsurfer.sh index 83505e912..39d5e7ac2 100755 --- a/run_fastsurfer.sh +++ b/run_fastsurfer.sh @@ -32,6 +32,7 @@ fastsurfercnndir="$FASTSURFER_HOME/FastSurferCNN" cerebnetdir="$FASTSURFER_HOME/CerebNet" hypvinndir="$FASTSURFER_HOME/HypVINN" reconsurfdir="$FASTSURFER_HOME/recon_surf" +CorpusCallosumDir="$FASTSURFER_HOME/CorpusCallosum" # Regular flags defaults subject="" @@ -49,6 +50,7 @@ hypo_segfile="" hypo_statsfile="" hypvinn_flags=() hypvinn_regmode="coreg" +cc_flags=() conformed_name="" conformed_name_t2="" norm_name="" @@ -70,6 +72,7 @@ native_image="false" run_asegdkt_module="1" run_cereb_module="1" run_hypvinn_module="1" +run_cc_module="1" threads_seg="1" threads_surf="1" # python3.10 -s excludes user-directory package inclusion @@ -213,6 +216,11 @@ SEGMENTATION PIPELINE: --no_biasfield Deactivate the calculation of partial volume-corrected statistics. + CORPUS CALLOSUM MODULE: + --no_cc Skip the segmentation and analysis of the corpus callosum. + --qc_snap Create QC snapshots in \$SUBJECTS_DIR/\$sid/qc_snapshots + to simplify the QC process. + HYPOTHALAMUS MODULE (HypVINN): --no_hypothal Skip the hypothalamus segmentation. --no_biasfield This option implies --no_hypothal, as the hypothalamus @@ -458,6 +466,12 @@ case $key in --mask_name) mask_name="$1" ; warn_seg_only+=("$key" "$1") ; warn_base+=("$key" "$1") ; shift ;; --merged_segfile) merged_segfile="$1" ; shift ;; + # corupus callosum module options + #============================================================= + --no_cc) run_cc_module="0" ;; + # TODO: remove this dev flag + --upright) cc_flags+=("--upright_volume" "mri/upright.mgz") ;; + # cereb module options #============================================================= --no_cereb) run_cereb_module="0" ;; @@ -480,7 +494,11 @@ case $key in ;; # several options that set a variable - --qc_snap) hypvinn_flags+=(--qc_snap) ;; + --qc_snap) + hypvinn_flags+=(--qc_snap) ; + cc_flags+=(--qc_image "qc_snapshots/callosum.png" --thickness_image "qc_snapshots/callosum.thickness.png" + --cc_html "qc_snapshots/corpus_callosum.html") + ;; ############################################################## # surf-pipeline options @@ -588,6 +606,8 @@ fi if [[ -z "$merged_segfile" ]] ; then merged_segfile="$subject_dir/mri/fastsurfer.merged.mgz" ; fi if [[ -z "$asegdkt_segfile" ]] ; then asegdkt_segfile="$subject_dir/mri/aparc.DKTatlas+aseg.deep.mgz" ; fi if [[ -z "$aseg_segfile" ]] ; then aseg_segfile="$subject_dir/mri/aseg.auto_noCCseg.mgz"; fi +if [[ -z "$aseg_auto_segfile" ]] ; then aseg_auto_segfile="$subject_dir/mri/aseg.auto.mgz"; fi +if [[ -z "$callosum_seg" ]] ; then callosum_seg="$subject_dir/mri/callosum.CC.orig.mgz"; fi if [[ -z "$asegdkt_statsfile" ]] ; then asegdkt_statsfile="$subject_dir/stats/aseg+DKT.stats" ; fi if [[ -z "$asegdkt_vinn_statsfile" ]] ; then asegdkt_vinn_statsfile="$subject_dir/stats/aseg+DKT.VINN.stats" ; fi if [[ -z "$aseg_vinn_statsfile" ]] ; then aseg_vinn_statsfile="$subject_dir/stats/aseg.VINN.stats" ; fi @@ -708,6 +728,18 @@ then fi fi +if [[ "$run_seg_pipeline" == "1" ]] && { [[ "$run_asegdkt_module" == "0" ]] && [[ "$run_cc_module" == "1" ]]; } +then + if [[ ! -f "$asegdkt_segfile" ]] + then + echo "ERROR: To run the corpus callosum module but no asegdkt, the aseg segmentation must already exist." + echo " You passed --no_asegdkt but the asegdkt segmentation ($asegdkt_segfile) could not be found." + echo " If the segmentation is not saved in the default location ($asegdkt_segfile_default)," + echo " specify the absolute path and name via --asegdkt_segfile" + exit 1 + fi +fi + if [[ "$run_surf_pipeline" == "1" ]] && [[ "$native_image" != "false" ]] then echo "ERROR: The surface pipeline is not compatible with the options --native_image or " @@ -1078,6 +1110,88 @@ then fi fi + if [[ "$run_cc_module" ]] + then + # ============================= CC SEGMENTATION ============================================ + + # generate file names of for the analysis + asegdkt_withcc_segfile="$(add_file_suffix "$asegdkt_segfile" "withCC")" + asegdkt_withcc_vinn_statsfile="$(add_file_suffix "$asegdkt_vinn_statsfile" "withCC")" + aseg_auto_statsfile="$(dirname "$aseg_vinn_statsfile")/aseg.auto.mgz" + # note: callosum manedit currently only affects inpainting and not internal FastSurferCC processing (surfaces etc) + callosum_seg_manedit="$(add_file_suffix "$callosum_seg" "manedit")" + # generate callosum segmentation, mesh, shape and downstream measure files + cmd=($python "$CorpusCallosumDir/fastsurfer_cc.py" --sd "$sd" --sid "$subject" --threads "$threads_seg" + "--aseg_name" "$asegdkt_segfile" "--segmentation_in_orig" "$callosum_seg" "${cc_flags[@]}") + { + echo_quoted "${cmd[@]}" + "${cmd[@]}" + if [[ "${PIPESTATUS[0]}" != 0 ]] ; then echo "ERROR: FastSurferCC corpus callosum analysis failed!" ; exit 1 ; fi + if [[ "$edits" == 1 ]] && [[ -f "$callosum_seg_manedit" ]] ; then callosum_seg="$callosum_seg_manedit" ; fi + + # add CC into aparc.DKTatlas+aseg.deep.mgz and aseg.auto.mgz as mri_cc did before. + cmd=($python "$CorpusCallosumDir/paint_cc_into_pred.py" -in_cc "$callosum_seg" -in_pred "$asegdkt_segfile" + "-out" "$asegdkt_withcc_segfile" "-aseg" "$aseg_auto_segfile") + echo_quoted "${cmd[@]}" + "${cmd[@]}" + if [[ "${PIPESTATUS[0]}" != 0 ]] ; then echo "ERROR: asegdkt cc inpainting failed!" ; exit 1 ; fi + + if [[ "$run_biasfield" == 1 ]] + then + cmd=($python "${fastsurfercnndir}/segstats.py" --segfile "$asegdkt_withcc_segfile" --normfile "$norm_name" + --lut "$fastsurfercnndir/config/FreeSurferColorLUT.txt" --sd "${sd}" --sid "${subject}" + --ids 2 4 5 7 8 10 11 12 13 14 15 16 17 18 24 26 28 31 41 43 44 46 47 49 50 51 52 53 + 54 58 60 63 77 251 252 253 254 255 + 1002 1003 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 + 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1034 1035 + 2002 2003 2005 2006 2007 2008 2009 2010 2011 2012 2013 2014 2015 2016 2017 2018 + 2019 2020 2021 2022 2023 2024 2025 2026 2027 2028 2029 2030 2031 2034 2035 + --threads "$threads_seg" --empty --excludeid 0 + --segstatsfile "$asegdkt_withcc_vinn_statsfile" + measures + # the following measures are unaffected by CC and do not need to be recomputed + --import SubCortGray Mask + ) + if [[ "$run_talairach_registration" == "true" ]] + then + cmd+=("EstimatedTotalIntraCranialVol" "BrainSegVol-to-eTIV" "MaskVol-to-eTIV") + fi + cmd+=(--file "$asegdkt_vinn_statsfile" + # recompute the measures changes coming from CC inpainting (only SubCortGray does not change) + --compute BrainSeg BrainSegNotVent SupraTentorial SupraTentorialNotVent + rhCerebralWhiteMatter lhCerebralWhiteMatter CerebralWhiteMatter + ) + echo_quoted "${cmd[@]}" + "${cmd[@]}" + if [[ "${PIPESTATUS[0]}" != 0 ]] ; then + echo "ERROR: asegdkt statsfile ($asegdkt_withcc_segfile) generation failed!" ; exit 1 + # this will only terminate the subshell + fi + fi + } 2>&1 | tee -a "$seg_log" + code="${PIPESTATUS[0]}" + if [[ "$code" != 0 ]]; then exit 1; fi # forward subshell exit to main script + + if [[ "$run_biasfield" == 1 ]] + then + { + cmd=($python "${fastsurfercnndir}/segstats.py" --segfile "$aseg_auto_segfile" --normfile "$norm_name" + --lut "$fastsurfercnndir/config/FreeSurferColorLUT.txt" --sd "${sd}" --sid "${subject}" + --threads "$threads_seg" --empty --excludeid 0 + --ids 2 4 3 5 7 8 10 11 12 13 14 15 16 17 18 24 26 28 31 41 42 43 44 46 47 49 50 51 52 53 54 58 60 63 77 + 251 252 253 254 255 + --segstatsfile "$aseg_auto_statsfile" + measures --import "all" --file "$asegdkt_withcc_vinn_statsfile" + ) + echo_quoted "${cmd[@]}" + "${cmd[@]}" 2>&1 + if [[ "${PIPESTATUS[0]}" != 0 ]] ; then echo "ERROR: aseg statsfile ($aseg_auto_segfile) failed!" ; exit 1 ; fi + } | tee -a "$seg_log" + if [[ "${PIPESTATUS[0]}" != 0 ]] ; then exit 1; fi # forward subshell exit to main script + + fi + fi + if [[ "$run_cereb_module" == "1" ]] then if [[ "$run_biasfield" == "1" ]] diff --git a/tools/export_pip-r.sh b/tools/export_pip-r.sh old mode 100644 new mode 100755