diff --git a/src/stamp/encoding/__init__.py b/src/stamp/encoding/__init__.py
index 3148f635..9cb873bb 100644
--- a/src/stamp/encoding/__init__.py
+++ b/src/stamp/encoding/__init__.py
@@ -54,7 +54,7 @@ def init_slide_encoder_(

             selected_encoder: Encoder = Gigapath()

-        case EncoderName.CHIEF:
+        case EncoderName.CHIEF_CTRANSPATH:
             from stamp.encoding.encoder.chief import CHIEF

             selected_encoder: Encoder = CHIEF()
@@ -140,7 +140,7 @@ def init_patient_encoder_(

             selected_encoder: Encoder = Gigapath()

-        case EncoderName.CHIEF:
+        case EncoderName.CHIEF_CTRANSPATH:
             from stamp.encoding.encoder.chief import CHIEF

             selected_encoder: Encoder = CHIEF()
diff --git a/src/stamp/encoding/config.py b/src/stamp/encoding/config.py
index e743fcfd..1a2bcba7 100644
--- a/src/stamp/encoding/config.py
+++ b/src/stamp/encoding/config.py
@@ -9,7 +9,7 @@ class EncoderName(StrEnum):
     COBRA = "cobra"
     EAGLE = "eagle"
-    CHIEF = "chief"
+    CHIEF_CTRANSPATH = "chief"
     TITAN = "titan"
     GIGAPATH = "gigapath"
     MADELEINE = "madeleine"
diff --git a/src/stamp/encoding/encoder/__init__.py b/src/stamp/encoding/encoder/__init__.py
index d9035f7a..0a3c7c68 100644
--- a/src/stamp/encoding/encoder/__init__.py
+++ b/src/stamp/encoding/encoder/__init__.py
@@ -189,7 +189,8 @@ def _read_h5(
             raise ValueError(
                 f"Feature file does not have extractor's name in the metadata: {os.path.basename(h5_path)}"
             )
-        return feats, coords, extractor
+
+        return feats, coords, _resolve_extractor_name(extractor)

     def _save_features_(
         self, output_path: Path, feats: np.ndarray, feat_type: str
@@ -215,3 +216,35 @@ def _save_features_(

         Path(tmp_h5_file.name).rename(output_path)
         _logger.debug(f"saved features to {output_path}")
+
+
+def _resolve_extractor_name(raw: str) -> ExtractorName:
+    """
+    Resolve an extractor string to a valid ExtractorName.
+
+    Handles:
+    - exact matches ('gigapath', 'virchow-full')
+    - versioned strings like 'gigapath-ae23d', 'virchow-full-2025abc'
+    Raises ValueError if the base name is not recognized.
+    """
+    if not raw:
+        raise ValueError("Empty extractor string")
+
+    name = str(raw).strip().lower()
+
+    # Exact match
+    for e in ExtractorName:
+        if name == e.value.lower():
+            return e
+
+    # Versioned form: '<base>-something'
+    for e in ExtractorName:
+        if name.startswith(e.value.lower() + "-"):
+            return e
+
+    # Otherwise fail
+    raise ValueError(
+        f"Unknown extractor '{raw}'. "
+        f"Expected one of {[e.value for e in ExtractorName]} "
+        f"or a versioned variant like '<name>-<version>'."
+    )
diff --git a/src/stamp/encoding/encoder/chief.py b/src/stamp/encoding/encoder/chief.py
index eaab9750..2ad4b91b 100644
--- a/src/stamp/encoding/encoder/chief.py
+++ b/src/stamp/encoding/encoder/chief.py
@@ -113,7 +113,7 @@ def __init__(self) -> None:
         model.load_state_dict(chief, strict=True)
         super().__init__(
             model=model,
-            identifier=EncoderName.CHIEF,
+            identifier=EncoderName.CHIEF_CTRANSPATH,
             precision=torch.float32,
             required_extractors=[
                 ExtractorName.CHIEF_CTRANSPATH,
@@ -178,7 +178,16 @@ def encode_patients_(
             for _, row in group.iterrows():
                 slide_filename = row[filename_label]
                 h5_path = os.path.join(feat_dir, slide_filename)
-                feats, _ = self._validate_and_read_features(h5_path=h5_path)
+                # Skip if not an .h5 file
+                if not h5_path.endswith(".h5"):
+                    tqdm.write(f"Skipping {slide_filename} (not an .h5 file)")
+                    continue
+
+                try:
+                    feats, coords = self._validate_and_read_features(h5_path=h5_path)
+                except (FileNotFoundError, ValueError, OSError) as e:
+                    tqdm.write(f"Skipping {slide_filename}: {e}")
+                    continue
                 feats_list.append(feats)

             if not feats_list:
diff --git a/src/stamp/encoding/encoder/gigapath.py b/src/stamp/encoding/encoder/gigapath.py
index e2fb0ebb..9cb3f6f5 100644
--- a/src/stamp/encoding/encoder/gigapath.py
+++ b/src/stamp/encoding/encoder/gigapath.py
@@ -129,7 +129,16 @@ def encode_patients_(
                 slide_filename = row[filename_label]
                 h5_path = os.path.join(feat_dir, slide_filename)

-                feats, coords = self._validate_and_read_features(h5_path=h5_path)
+                # Skip if not an .h5 file
+                if not h5_path.endswith(".h5"):
+                    tqdm.write(f"Skipping {slide_filename} (not an .h5 file)")
+                    continue
+
+                try:
+                    feats, coords = self._validate_and_read_features(h5_path=h5_path)
+                except (FileNotFoundError, ValueError, OSError) as e:
+                    tqdm.write(f"Skipping {slide_filename}: {e}")
+                    continue

                 # Get the mpp of one slide and check that the rest have the same
                 if slides_mpp < 0:
diff --git a/src/stamp/encoding/encoder/titan.py b/src/stamp/encoding/encoder/titan.py
index 41dd19f1..1012d98f 100644
--- a/src/stamp/encoding/encoder/titan.py
+++ b/src/stamp/encoding/encoder/titan.py
@@ -134,7 +134,16 @@ def encode_patients_(
                 slide_filename = row[filename_label]
                 h5_path = os.path.join(feat_dir, slide_filename)

-                feats, coords = self._validate_and_read_features(h5_path=h5_path)
+                # Skip if not an .h5 file
+                if not h5_path.endswith(".h5"):
+                    tqdm.write(f"Skipping {slide_filename} (not an .h5 file)")
+                    continue
+
+                try:
+                    feats, coords = self._validate_and_read_features(h5_path=h5_path)
+                except (FileNotFoundError, ValueError, OSError) as e:
+                    tqdm.write(f"Skipping {slide_filename}: {e}")
+                    continue

                 # Get the mpp of one slide and check that the rest have the same
                 if slides_mpp < 0:
diff --git a/src/stamp/preprocessing/__init__.py b/src/stamp/preprocessing/__init__.py
index 23a8e0c3..f20c87ae 100755
--- a/src/stamp/preprocessing/__init__.py
+++ b/src/stamp/preprocessing/__init__.py
@@ -240,15 +240,16 @@ def extract_(

     extractor_id = extractor.identifier

-    if generate_hash:
-        extractor_id += f"-{code_hash}"
-
     _logger.info(f"Using extractor {extractor.identifier}")

     if cache_dir:
         cache_dir.mkdir(parents=True, exist_ok=True)

-    feat_output_dir = output_dir / extractor_id
+    feat_output_dir = (
+        output_dir / f"{extractor_id}-{code_hash}"
+        if generate_hash
+        else output_dir / extractor_id
+    )

     # Collect slides for preprocessing
     if wsi_list is not None:
diff --git a/tests/test_encoders.py b/tests/test_encoders.py
index ddce5c5a..3edef575 100644
--- a/tests/test_encoders.py
+++ b/tests/test_encoders.py
@@ -33,7 +33,7 @@

 # They are not all, just one case that is accepted for each encoder
 used_extractor = {
-    EncoderName.CHIEF: ExtractorName.CHIEF_CTRANSPATH,
+    EncoderName.CHIEF_CTRANSPATH: ExtractorName.CHIEF_CTRANSPATH,
     EncoderName.COBRA: ExtractorName.CONCH,
     EncoderName.EAGLE: ExtractorName.CTRANSPATH,
     EncoderName.GIGAPATH: ExtractorName.GIGAPATH,
@@ -73,7 +73,7 @@ def test_if_encoding_crashes(*, tmp_path: Path, encoder: EncoderName):
         )

     cuda_required = [
-        EncoderName.CHIEF,
+        EncoderName.CHIEF_CTRANSPATH,
         EncoderName.COBRA,
         EncoderName.GIGAPATH,
         EncoderName.MADELEINE,
@@ -106,6 +106,28 @@ def test_if_encoding_crashes(*, tmp_path: Path, encoder: EncoderName):
                 feat_filename=feat_filename,
                 coords=coords,
             )
+    elif encoder == EncoderName.PRISM:
+        # PRISM requires virchow-full features, so we generate new ones
+        # with the same names and coordinates as the other ctranspath feats.
+        agg_feat_dir = tmp_path / "agg_output"
+        agg_feat_dir.mkdir()
+        slide_df = pd.read_csv(slide_path)
+        feature_filenames = [Path(path).stem for path in slide_df["slide_path"]]
+
+        for feat_filename in feature_filenames:
+            # Read the coordinates from the ctranspath feature file
+            ctranspath_file = feature_dir / f"{feat_filename}.h5"
+            with h5py.File(ctranspath_file, "r") as h5_file:
+                coords: np.ndarray = h5_file["coords"][:]  # type: ignore
+            create_random_feature_file(
+                tmp_path=agg_feat_dir,
+                min_tiles=32,
+                max_tiles=32,
+                feat_dim=input_dims[ExtractorName.VIRCHOW_FULL],
+                extractor_name="virchow-full",
+                feat_filename=feat_filename,
+                coords=coords,
+            )
     elif encoder == EncoderName.TITAN:
         # A random conch1_5 feature does not work with titan so we just download
         # a real one