4 changes: 2 additions & 2 deletions src/stamp/encoding/__init__.py
@@ -54,7 +54,7 @@ def init_slide_encoder_(

selected_encoder: Encoder = Gigapath()

-case EncoderName.CHIEF:
+case EncoderName.CHIEF_CTRANSPATH:
from stamp.encoding.encoder.chief import CHIEF

selected_encoder: Encoder = CHIEF()
@@ -140,7 +140,7 @@ def init_patient_encoder_(

selected_encoder: Encoder = Gigapath()

-case EncoderName.CHIEF:
+case EncoderName.CHIEF_CTRANSPATH:
from stamp.encoding.encoder.chief import CHIEF

selected_encoder: Encoder = CHIEF()
2 changes: 1 addition & 1 deletion src/stamp/encoding/config.py
@@ -9,7 +9,7 @@
class EncoderName(StrEnum):
COBRA = "cobra"
EAGLE = "eagle"
-CHIEF = "chief"
+CHIEF_CTRANSPATH = "chief"
TITAN = "titan"
GIGAPATH = "gigapath"
MADELEINE = "madeleine"
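Note that only the enum member name changes; its string value stays "chief", so configuration files that already spell the encoder as "chief" keep resolving to the same entry. A minimal standalone sketch (simplified enum, not the full one from config.py) of why value lookup is unaffected:

from enum import StrEnum

class EncoderName(StrEnum):
    CHIEF_CTRANSPATH = "chief"
    TITAN = "titan"

# Lookup by value and string comparison both still work with the old spelling:
assert EncoderName("chief") is EncoderName.CHIEF_CTRANSPATH
assert EncoderName.CHIEF_CTRANSPATH == "chief"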
35 changes: 34 additions & 1 deletion src/stamp/encoding/encoder/__init__.py
@@ -189,7 +189,8 @@ def _read_h5(
raise ValueError(
f"Feature file does not have extractor's name in the metadata: {os.path.basename(h5_path)}"
)
-return feats, coords, extractor
+
+return feats, coords, _resolve_extractor_name(extractor)

def _save_features_(
self, output_path: Path, feats: np.ndarray, feat_type: str
@@ -215,3 +216,35 @@ def _save_features_(

Path(tmp_h5_file.name).rename(output_path)
_logger.debug(f"saved features to {output_path}")


+def _resolve_extractor_name(raw: str) -> ExtractorName:
+    """
+    Resolve an extractor string to a valid ExtractorName.
+
+    Handles:
+    - exact matches ('gigapath', 'virchow-full')
+    - versioned strings like 'gigapath-ae23d', 'virchow-full-2025abc'
+    Raises ValueError if the base name is not recognized.
+    """
+    if not raw:
+        raise ValueError("Empty extractor string")
+
+    name = str(raw).strip().lower()
+
+    # Exact match
+    for e in ExtractorName:
+        if name == e.value.lower():
+            return e
+
+    # Versioned form: '<enum-value>-something'
+    for e in ExtractorName:
+        if name.startswith(e.value.lower() + "-"):
+            return e
+
+    # Otherwise fail
+    raise ValueError(
+        f"Unknown extractor '{raw}'. "
+        f"Expected one of {[e.value for e in ExtractorName]} "
+        f"or a versioned variant like '<name>-<hash>'."
+    )
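Illustrative calls, assuming ExtractorName contains members with the values "gigapath" and "virchow-full" as the docstring above suggests:

_resolve_extractor_name("gigapath")              # -> ExtractorName.GIGAPATH (exact match)
_resolve_extractor_name("virchow-full-2025abc")  # -> ExtractorName.VIRCHOW_FULL (versioned form)
_resolve_extractor_name("GigaPath-ae23d")        # -> ExtractorName.GIGAPATH (case-insensitive)
_resolve_extractor_name("unknown-model")         # raises ValueError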
13 changes: 11 additions & 2 deletions src/stamp/encoding/encoder/chief.py
@@ -113,7 +113,7 @@ def __init__(self) -> None:
model.load_state_dict(chief, strict=True)
super().__init__(
model=model,
-identifier=EncoderName.CHIEF,
+identifier=EncoderName.CHIEF_CTRANSPATH,
precision=torch.float32,
required_extractors=[
ExtractorName.CHIEF_CTRANSPATH,
@@ -178,7 +178,16 @@ def encode_patients_(
for _, row in group.iterrows():
slide_filename = row[filename_label]
h5_path = os.path.join(feat_dir, slide_filename)
-feats, _ = self._validate_and_read_features(h5_path=h5_path)
+# Skip if not an .h5 file
+if not h5_path.endswith(".h5"):
+    tqdm.write(f"Skipping {slide_filename} (not an .h5 file)")
+    continue
+
+try:
+    feats, coords = self._validate_and_read_features(h5_path=h5_path)
+except (FileNotFoundError, ValueError, OSError) as e:
+    tqdm.write(f"Skipping {slide_filename}: {e}")
+    continue
feats_list.append(feats)

if not feats_list:
11 changes: 10 additions & 1 deletion src/stamp/encoding/encoder/gigapath.py
@@ -129,7 +129,16 @@ def encode_patients_(
slide_filename = row[filename_label]
h5_path = os.path.join(feat_dir, slide_filename)

-feats, coords = self._validate_and_read_features(h5_path=h5_path)
+# Skip if not an .h5 file
+if not h5_path.endswith(".h5"):
+    tqdm.write(f"Skipping {slide_filename} (not an .h5 file)")
+    continue
+
+try:
+    feats, coords = self._validate_and_read_features(h5_path=h5_path)
+except (FileNotFoundError, ValueError, OSError) as e:
+    tqdm.write(f"Skipping {slide_filename}: {e}")
+    continue

# Get the mpp of one slide and check that the rest have the same
if slides_mpp < 0:
11 changes: 10 additions & 1 deletion src/stamp/encoding/encoder/titan.py
@@ -134,7 +134,16 @@ def encode_patients_(
slide_filename = row[filename_label]
h5_path = os.path.join(feat_dir, slide_filename)

-feats, coords = self._validate_and_read_features(h5_path=h5_path)
+# Skip if not an .h5 file
+if not h5_path.endswith(".h5"):
+    tqdm.write(f"Skipping {slide_filename} (not an .h5 file)")
+    continue
+
+try:
+    feats, coords = self._validate_and_read_features(h5_path=h5_path)
+except (FileNotFoundError, ValueError, OSError) as e:
+    tqdm.write(f"Skipping {slide_filename}: {e}")
+    continue

# Get the mpp of one slide and check that the rest have the same
if slides_mpp < 0:
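The same skip-and-continue guard now appears three times (chief.py, gigapath.py, titan.py). Purely to illustrate the pattern, a hypothetical shared helper it could be factored into (the name and location are made up, not part of this PR):

from typing import Callable

import numpy as np
from tqdm import tqdm


def _try_read_features(
    h5_path: str,
    slide_filename: str,
    reader: Callable[..., tuple[np.ndarray, np.ndarray]],
) -> tuple[np.ndarray, np.ndarray] | None:
    """Return (feats, coords), or None if the file should be skipped."""
    if not h5_path.endswith(".h5"):
        tqdm.write(f"Skipping {slide_filename} (not an .h5 file)")
        return None
    try:
        return reader(h5_path=h5_path)
    except (FileNotFoundError, ValueError, OSError) as e:
        tqdm.write(f"Skipping {slide_filename}: {e}")
        return None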
9 changes: 5 additions & 4 deletions src/stamp/preprocessing/__init__.py
@@ -240,15 +240,16 @@ def extract_(

extractor_id = extractor.identifier

-if generate_hash:
-    extractor_id += f"-{code_hash}"

_logger.info(f"Using extractor {extractor.identifier}")

if cache_dir:
cache_dir.mkdir(parents=True, exist_ok=True)

-feat_output_dir = output_dir / extractor_id
+feat_output_dir = (
+    output_dir / f"{extractor_id}-{code_hash}"
+    if generate_hash
+    else output_dir / extractor_id
+)

# Collect slides for preprocessing
if wsi_list is not None:
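With this change the hash only affects the output directory name, while the log line keeps reporting the plain extractor identifier. A small sketch with made-up values:

from pathlib import Path

output_dir = Path("/data/features")  # example path, not from the PR
extractor_id = "gigapath"
code_hash = "ae23d"                  # example hash, not from the PR
generate_hash = True

feat_output_dir = (
    output_dir / f"{extractor_id}-{code_hash}"
    if generate_hash
    else output_dir / extractor_id
)
# generate_hash=True  -> /data/features/gigapath-ae23d
# generate_hash=False -> /data/features/gigapath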
26 changes: 24 additions & 2 deletions tests/test_encoders.py
@@ -33,7 +33,7 @@

# They are not all, just one case that is accepted for each encoder
used_extractor = {
-EncoderName.CHIEF: ExtractorName.CHIEF_CTRANSPATH,
+EncoderName.CHIEF_CTRANSPATH: ExtractorName.CHIEF_CTRANSPATH,
EncoderName.COBRA: ExtractorName.CONCH,
EncoderName.EAGLE: ExtractorName.CTRANSPATH,
EncoderName.GIGAPATH: ExtractorName.GIGAPATH,
@@ -73,7 +73,7 @@ def test_if_encoding_crashes(*, tmp_path: Path, encoder: EncoderName):
)

cuda_required = [
-EncoderName.CHIEF,
+EncoderName.CHIEF_CTRANSPATH,
EncoderName.COBRA,
EncoderName.GIGAPATH,
EncoderName.MADELEINE,
@@ -106,6 +106,28 @@ def test_if_encoding_crashes(*, tmp_path: Path, encoder: EncoderName):
feat_filename=feat_filename,
coords=coords,
)
+elif encoder == EncoderName.PRISM:
+    # PRISM requires virchow-full features, so we generate new ones
+    # with the same names and coordinates as the ctranspath feats.
+    agg_feat_dir = tmp_path / "agg_output"
+    agg_feat_dir.mkdir()
+    slide_df = pd.read_csv(slide_path)
+    feature_filenames = [Path(path).stem for path in slide_df["slide_path"]]
+
+    for feat_filename in feature_filenames:
+        # Read the coordinates from the ctranspath feature file
+        ctranspath_file = feature_dir / f"{feat_filename}.h5"
+        with h5py.File(ctranspath_file, "r") as h5_file:
+            coords: np.ndarray = h5_file["coords"][:]  # type: ignore
+        create_random_feature_file(
+            tmp_path=agg_feat_dir,
+            min_tiles=32,
+            max_tiles=32,
+            feat_dim=input_dims[ExtractorName.VIRCHOW_FULL],
+            extractor_name="virchow-full",
+            feat_filename=feat_filename,
+            coords=coords,
+        )
elif encoder == EncoderName.TITAN:
# A random conch1_5 feature does not work with titan so we just download
# a real one
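For orientation, a hedged sketch of the kind of .h5 feature file these tests fabricate; the dataset names ("feats", "coords") and the "extractor" attribute are assumptions inferred from how _read_h5 and the coords read above access the files:

import h5py
import numpy as np

with h5py.File("slide_001.h5", "w") as f:
    f["feats"] = np.random.rand(32, 2560).astype(np.float32)   # 32 tiles, assumed feature dim
    f["coords"] = np.random.randint(0, 100_000, size=(32, 2))  # tile coordinates
    f.attrs["extractor"] = "virchow-full-2025abc"               # versioned names now resolve via _resolve_extractor_name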