From 47105a3f48e90823e3b51aa7c55aefc3310623c2 Mon Sep 17 00:00:00 2001 From: DEschweiler Date: Fri, 9 May 2025 18:33:57 +0200 Subject: [PATCH 01/10] added attention UI --- pyproject.toml | 2 + src/stamp/__main__.py | 22 + src/stamp/config.py | 3 +- src/stamp/config.yaml | 20 + src/stamp/heatmaps/__init__.py | 42 + src/stamp/heatmaps/attention_ui.py | 981 +++++++++++++++++++++++ src/stamp/heatmaps/config.py | 20 + src/stamp/modeling/vision_transformer.py | 156 +++- 8 files changed, 1230 insertions(+), 16 deletions(-) create mode 100644 src/stamp/heatmaps/attention_ui.py diff --git a/pyproject.toml b/pyproject.toml index e02f55ce..e80e9ec0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ dependencies = [ "jaxtyping>=0.2.36", "lightning>=2.4.0", "matplotlib>=3.9.2", + "napari>=0.6.0", "numpy>=2.2.2", "opencv-python>=4.10.0.84", "openpyxl>=3.1.5", @@ -35,6 +36,7 @@ dependencies = [ "pandas>=2.2.3", "pillow>=11.1.0", "pydantic>=2.10.3", + "pyqt5>=5.15.11", "pyyaml>=6.0.2", "scikit-learn>=1.5.2", "scipy>=1.15.1", diff --git a/src/stamp/__main__.py b/src/stamp/__main__.py index 940db944..70eaef97 100755 --- a/src/stamp/__main__.py +++ b/src/stamp/__main__.py @@ -210,6 +210,27 @@ def _run_cli(args: argparse.Namespace) -> None: default_slide_mpp=config.heatmaps.default_slide_mpp, ) + case "attention_ui": + from stamp.heatmaps import attention_ui_ + + if config.attention_ui is None: + raise ValueError("no attention configuration supplied") + + _add_file_handle_(_logger, output_dir=config.attention_ui.output_dir) + _logger.info( + "using the following configuration:\n" + f"{yaml.dump(config.attention_ui.model_dump(mode='json'))}" + ) + attention_ui_( + feature_dir=config.attention_ui.feature_dir, + wsi_dir=config.attention_ui.wsi_dir, + checkpoint_path=config.attention_ui.checkpoint_path, + output_dir=config.attention_ui.output_dir, + slide_paths=config.attention_ui.slide_paths, + device=config.attention_ui.device, + default_slide_mpp=config.attention_ui.default_slide_mpp, + ) + case _: raise RuntimeError( "unreachable: the argparser should only allow valid commands" @@ -261,6 +282,7 @@ def main() -> None: ) commands.add_parser("config", help="Print the loaded configuration") commands.add_parser("heatmaps", help="Generate heatmaps for a trained model") + commands.add_parser("attention_ui", help="Provides an interactive UI for exploring attention maps") args = parser.parse_args() diff --git a/src/stamp/config.py b/src/stamp/config.py index 7e154d7a..902e5d36 100644 --- a/src/stamp/config.py +++ b/src/stamp/config.py @@ -1,6 +1,6 @@ from pydantic import BaseModel, ConfigDict -from stamp.heatmaps.config import HeatmapConfig +from stamp.heatmaps.config import HeatmapConfig, AttentionUIConfig from stamp.modeling.config import CrossvalConfig, DeploymentConfig, TrainConfig from stamp.preprocessing.config import PreprocessingConfig from stamp.statistics import StatsConfig @@ -18,3 +18,4 @@ class StampConfig(BaseModel): statistics: StatsConfig | None = None heatmaps: HeatmapConfig | None = None + attention_ui: AttentionUIConfig | None = None \ No newline at end of file diff --git a/src/stamp/config.yaml b/src/stamp/config.yaml index 21ac06b3..d6eadfd7 100644 --- a/src/stamp/config.yaml +++ b/src/stamp/config.yaml @@ -196,6 +196,7 @@ statistics: #- "/path/to/split-3/patient-preds.csv" #- "/path/to/split-4/patient-preds.csv" + heatmaps: output_dir: "/path/to/save/files/to" @@ -217,3 +218,22 @@ heatmaps: # The number of top- and bottom-scoring tiles to extract #topk: 5 #bottomk: 5 + + 
+attention_ui: + output_dir: "/path/to/save/files/to" + + # Directory the extracted features are saved in. + feature_dir: "/path/your/extracted/features/are/stored/in" + + wsi_dir: "/path/containing/whole/slide/images/to/extract/features/from" + + # Path of the model to generate the attention maps with. + checkpoint_path: "/path/to/model.ckpt" + + # Slides to generate the attention maps for. + # The slide paths have to be specified relative to `wsi_dir`. + # If not specified, stamp will allow processing for all slides in `wsi_dir`. + #slide_paths: + #- slide1.svs + #- slide2.mrxs \ No newline at end of file diff --git a/src/stamp/heatmaps/__init__.py b/src/stamp/heatmaps/__init__.py index fbe72648..44ecacb8 100644 --- a/src/stamp/heatmaps/__init__.py +++ b/src/stamp/heatmaps/__init__.py @@ -8,6 +8,7 @@ import numpy as np import openslide import torch +import napari from jaxtyping import Float, Integer from matplotlib.axes import Axes from matplotlib.patches import Patch @@ -21,6 +22,7 @@ from stamp.modeling.vision_transformer import VisionTransformer from stamp.preprocessing import supported_extensions from stamp.preprocessing.tiling import Microns, SlideMPP, TilePixels, get_slide_mpp_ +from stamp.heatmaps.attention_ui import show_attention_ui _logger = logging.getLogger("stamp") @@ -316,3 +318,43 @@ def heatmaps_( fig.savefig(slide_output_dir / f"overview-{h5_path.stem}.png") plt.close(fig) + + + + +def attention_ui_( + *, + feature_dir: Path, + wsi_dir: Path, + checkpoint_path: Path, + output_dir: Path, + slide_paths: Iterable[Path] | None, + device: DeviceLikeType, + default_slide_mpp: SlideMPP | None +) -> None: + + with torch.no_grad(): + + # Collect slides to generate attention maps for + if slide_paths is not None: + wsis_to_process_all = (wsi_dir / slide for slide in slide_paths) + else: + wsis_to_process_all = ( + p for ext in supported_extensions for p in wsi_dir.glob(f"**/*{ext}") + ) + + # Check of a corresponding feature file exists + wsis_to_process = [] + for wsi_path in wsis_to_process_all: + h5_path = feature_dir / wsi_path.with_suffix(".h5").name + + if not h5_path.exists(): + _logger.info(f"could not find matching h5 file at {h5_path}. 
Skipping...") + continue + + wsis_to_process.append(str(wsi_path)) + + + # Launch the UI + viewer = show_attention_ui(feature_dir, wsis_to_process, checkpoint_path, output_dir, slide_paths, device, default_slide_mpp) + napari.run() diff --git a/src/stamp/heatmaps/attention_ui.py b/src/stamp/heatmaps/attention_ui.py new file mode 100644 index 00000000..1b8e6e2f --- /dev/null +++ b/src/stamp/heatmaps/attention_ui.py @@ -0,0 +1,981 @@ +import napari +import torch +import time +from torch import Tensor +from torch._prims_common import DeviceLikeType +from jaxtyping import Float, Integer +from typing import cast +from collections.abc import Iterable +from pathlib import Path +import numpy as np +import matplotlib.cm as cm +import openslide +import h5py +from scipy.spatial.distance import cdist +from typing import Union +from qtpy.QtGui import QPixmap, QImage +from qtpy.QtWidgets import ( + QComboBox, QPushButton, QVBoxLayout, QWidget, QLabel, QScrollArea, + QListWidget, QApplication, QSlider, QHBoxLayout, QFrame +) +from qtpy.QtCore import Qt +from stamp.modeling.lightning_model import LitVisionTransformer +from stamp.preprocessing.tiling import Microns, SlideMPP, SlidePixels, get_slide_mpp_ +from stamp.modeling.data import get_coords, get_stride + +# Define a generic integer type that can be either Python int or numpy int64 +IntType = Union[int, np.int64, np.int32] + + +def _vals_to_im( + scores: Float[Tensor, "tile feat"], + coords_norm: Integer[Tensor, "tile coord"], +) -> Float[Tensor, "width height category"]: + """Arranges scores in a 2d grid according to coordinates""" + size = coords_norm.max(0).values.flip(0) + 1 + im = torch.ones((*size.tolist(), *scores.shape[1:])).type_as(scores)*-1e-8 + + flattened_im = im.flatten(end_dim=1) + flattened_coords = coords_norm[:, 1] * im.shape[1] + coords_norm[:, 0] + flattened_im[flattened_coords] = scores + + im = flattened_im.reshape_as(im) + + return im + + +def _get_thumb(slide, slide_mpp: SlideMPP) -> np.ndarray: + """Get thumbnail of the slide at the specified MPP and tile size""" + # Get the thumbnail image from the slide + dims_um = np.array(slide.dimensions) * slide_mpp + thumb = slide.get_thumbnail(np.round(dims_um * 8 / 256).astype(int)) + thumb = np.array(thumb) + return thumb/255 + + +def _patch_to_pixmap(patch_image): + """Convert patch image to QPixmap for display in QLabel using NumPy""" + + # Resize for better visualization if needed + display_size = 200 # Adjust this value as needed + aspect_ratio = patch_image.width / patch_image.height + + if aspect_ratio > 1: + # Wider than tall + new_width = display_size + new_height = int(display_size / aspect_ratio) + else: + # Taller than wide + new_height = display_size + new_width = int(display_size * aspect_ratio) + + patch_image = patch_image.resize((new_width, new_height)) + + # Convert PIL image to numpy array + img_array = np.array(patch_image) + + # Create QImage from NumPy array + height, width, channels = img_array.shape + bytes_per_line = channels * width + + # Create QImage (RGB format) + q_image = QImage( + img_array.data, + width, + height, + bytes_per_line, + QImage.Format_RGB888 + ) + + return QPixmap.fromImage(q_image) + + +class AttentionViewer: + def __init__( + self, + feature_dir: Path, + wsis_to_process: Iterable[str], + checkpoint_path: Path, + output_dir: Path, + slide_paths: Iterable[Path] | None, + device: DeviceLikeType, + default_slide_mpp: SlideMPP | None + ): + """ + Interactive viewer for images with click-based heatmap generation. 
+ + Parameters: + ----------- + image : np.ndarray + The base image to display and interact with + _heatmap_generator : callable, optional + A function that takes coordinates (y, x) and returns a heatmap + If None, a simple gaussian will be used + """ + self.feature_dir = feature_dir + self.wsis_to_process = wsis_to_process + self.checkpoint_path = checkpoint_path + self.output_dir = output_dir + self.slide_paths = slide_paths + self.device = device + self.default_slide_mpp = default_slide_mpp + + # Initialize model + self.model = LitVisionTransformer.load_from_checkpoint(checkpoint_path).to(device).eval() + + # Initialize napari viewer + self.viewer = napari.Viewer(title=' Histopathology Attention Explorer') + + # Create dummy image layer + self.image = np.zeros((100,100,3), dtype=np.float32) # Placeholder for the image + self.image_layer = self.viewer.add_image( + self.image, + name='Image' + ) + if len(self.image.shape) == 3 and self.image.shape[2] in [3, 4]: # RGB or RGBA + self.height, self.width = self.image.shape[0], self.image.shape[1] + else: # Grayscale + self.height, self.width = self.image.shape + + # Initialize other attributes + self.slide = None + self.attention_map = np.zeros((100,100), dtype=np.float32) # Placeholder for the attention map + self.attention_weights = None + self.token_attn = None + self._attention_update_debounce = 500 + self._last_attention_update_time = 0 + self.map_coords = None + self.coords_tile_slide_px = None + self.tile_size_slide_px = None + self.selected_token_idx = None + self.selected_filename = None + self.num_layer = 0 + self.num_head = 0 + + + # Initialize empty heatmap + self.heatmap = np.zeros((self.height, self.width, 4), dtype=float) + self.heatmap_layer = self.viewer.add_image( + self.heatmap, + name='Attention', + rgb=True, + visible=True, + opacity=0.7 + ) + + # Initialize empty highlight map + self.highlight_mask = np.zeros((self.height, self.width, 4), dtype=float) + self.highlight_layer = self.viewer.add_image( + self.highlight_mask, + name='Top-k Highlight', + rgb=True, + visible=True, + opacity=1.0 + ) + + # Initialize points layer + self.clicked_points = [] + self.points_layer = self.viewer.add_points( + name='Selected Point', + size=10, + face_color='green', + symbol='x', + n_dimensional=True + ) + self._last_processed_point_count = 0 + self._updating_points = False + self.points_layer.events.data.connect(self._on_add_point) + + # Add other UI elements + self._add_file_selection_ui() + self._add_config_selection_ui() + self._add_patch_display_widget() + + # Disable UI elements until a file is selected + self._set_ui_enabled(False) + + # Print instructions + print("Click on the image to select points and generate attention heatmaps") + + + + #### ADDING UI ELEMENTS #### + + def _add_file_selection_ui(self): + """Add file selection dropdown and process button""" + # Create a widget container + file_selection_widget = QWidget() + layout = QVBoxLayout() + file_selection_widget.setLayout(layout) + + # Add a label + label = QLabel("Available files:") + layout.addWidget(label) + + # Create a list widget for file selection (scrollable) + self.file_list = QListWidget() + self.file_list.addItems(self.wsis_to_process) + + # Set a reasonable height to show multiple items + self.file_list.setMinimumHeight(150) + + # Make it scrollable if many items + self.file_list.setVerticalScrollBarPolicy(1) # 1 = Always show scrollbar + + # Add the list to the layout + layout.addWidget(self.file_list) + + # Create a process button + self.process_button = 
QPushButton("Process Selected File") + self.process_button.clicked.connect(self._on_process_file) + layout.addWidget(self.process_button) + + # Add the widget to napari viewer as a dock widget + self.viewer.window.add_dock_widget( + file_selection_widget, + name="File Selection", + area="right" + ) + + + + def _add_config_selection_ui(self): + """Add UI controls for selecting attention layer and head""" + # Create a widget container + selection_widget = QWidget() + layout = QVBoxLayout() + selection_widget.setLayout(layout) + + # Replace the direction slider section in _add_layer_head_selection_ui with: + + # === DIRECTION SELECTION === + direction_label = QLabel("Attention Handling:") + layout.addWidget(direction_label) + + # Create direction selection dropdown + direction_layout = QHBoxLayout() + + # Direction dropdown + self.attention_handling = QComboBox() + self.attention_handling.addItem("From selected tile to others", 0) + self.attention_handling.addItem("From other tiles to selected", 1) + self.attention_handling.addItem("Deviation of overall given attention", 2) + self.attention_handling.addItem("Deviation of overall received attention", 3) + self.attention_handling.addItem("Mean of overall given attention", 4) + self.attention_handling.addItem("Mean of overall received attention", 5) + self.attention_handling.addItem("Class token attention", 6) + + # Connect the dropdown to the update function + self.attention_handling.currentIndexChanged.connect(self._on_update_attention_map) + + direction_layout.addWidget(self.attention_handling) + layout.addLayout(direction_layout) + + + # === LAYER SELECTION === + layer_label = QLabel("Number of Network Layer\n(first->last):") + layout.addWidget(layer_label) + + # Create layer selection controls with arrows and slider + layer_layout = QHBoxLayout() + + # Layer slider + self.layer_slider = QSlider(Qt.Horizontal) + self.layer_slider.setMinimum(0) + self.layer_slider.setMaximum(2) # Will be updated with actual layer count + self.layer_slider.setValue(0) + self.layer_slider.valueChanged.connect( + lambda value: ( + self.layer_value_label.setText(str(value)), + self._on_update_attention_map() + ) + ) + layer_layout.addWidget(self.layer_slider) + + # Left arrow button for layer + self.layer_left_btn = QPushButton("←") + self.layer_left_btn.setMaximumWidth(30) + self.layer_left_btn.clicked.connect(lambda: self._adjust_slider(value=-1, ui_element=self.layer_slider)) + layer_layout.addWidget(self.layer_left_btn) + + # Right arrow button for layer + self.layer_right_btn = QPushButton("→") + self.layer_right_btn.setMaximumWidth(30) + self.layer_right_btn.clicked.connect(lambda: self._adjust_slider(value=1, ui_element=self.layer_slider)) + layer_layout.addWidget(self.layer_right_btn) + + # Layer value display + self.layer_value_label = QLabel("0") + self.layer_value_label.setMinimumWidth(25) + self.layer_value_label.setAlignment(Qt.AlignCenter) + layer_layout.addWidget(self.layer_value_label) + + layout.addLayout(layer_layout) + + # === HEAD SELECTION === + head_label = QLabel("Number of Prediction Head\n(-1 for average):") + layout.addWidget(head_label) + + # Create head selection controls with arrows and slider + head_layout = QHBoxLayout() + + # Head slider + self.head_slider = QSlider(Qt.Horizontal) + self.head_slider.setMinimum(-1) + self.head_slider.setMaximum(8) # Will be updated with actual head count + self.head_slider.setValue(0) + self.head_slider.valueChanged.connect( + lambda value: ( + self.head_value_label.setText(str(value)), + 
self._on_update_attention_map() + ) + ) + head_layout.addWidget(self.head_slider) + + # Left arrow button for head + self.head_left_btn = QPushButton("←") + self.head_left_btn.setMaximumWidth(30) + self.head_left_btn.clicked.connect(lambda: self._adjust_slider(-1, self.head_slider)) + head_layout.addWidget(self.head_left_btn) + + # Right arrow button for head + self.head_right_btn = QPushButton("→") + self.head_right_btn.setMaximumWidth(30) + self.head_right_btn.clicked.connect(lambda: self._adjust_slider(1, self.head_slider)) + head_layout.addWidget(self.head_right_btn) + + # Head value display + self.head_value_label = QLabel("0") + self.head_value_label.setMinimumWidth(25) + self.head_value_label.setAlignment(Qt.AlignCenter) + head_layout.addWidget(self.head_value_label) + + layout.addLayout(head_layout) + + + # === Top-k SELECTION === + topk_label = QLabel("Top-k tiles to highlight:") + layout.addWidget(topk_label) + + # Create top-k selection controls with arrows and slider + topk_layout = QHBoxLayout() + + # Top-k slider + self.topk_slider = QSlider(Qt.Horizontal) + self.topk_slider.setMinimum(0) + self.topk_slider.setMaximum(50) + self.topk_slider.setValue(5) + self.topk_slider.valueChanged.connect( + lambda value: ( + self.topk_value_label.setText(str(value)), + self._on_update_attention_map() + ) + ) + topk_layout.addWidget(self.topk_slider) + + # Left arrow button for top-k + self.topk_left_btn = QPushButton("←") + self.topk_left_btn.setMaximumWidth(30) + self.topk_left_btn.clicked.connect(lambda: self._adjust_slider(-1, self.topk_slider)) + topk_layout.addWidget(self.topk_left_btn) + + # Right arrow button for top-k + self.topk_right_btn = QPushButton("→") + self.topk_right_btn.setMaximumWidth(30) + self.topk_right_btn.clicked.connect(lambda: self._adjust_slider(1, self.topk_slider)) + topk_layout.addWidget(self.topk_right_btn) + + # Top-k value display + self.topk_value_label = QLabel("5") + self.topk_value_label.setMinimumWidth(25) + self.topk_value_label.setAlignment(Qt.AlignCenter) + topk_layout.addWidget(self.topk_value_label) + + layout.addLayout(topk_layout) + + # Add a horizontal separator + separator = QFrame() + separator.setFrameShape(QFrame.HLine) + separator.setFrameShadow(QFrame.Sunken) + layout.addWidget(separator) + + # === LOAD PATCHES BUTTON === + self.load_patches_btn = QPushButton("Load Selected & Top-k Patches") + self.load_patches_btn.clicked.connect(self._load_tile_patches) + layout.addWidget(self.load_patches_btn) + + # === Add the widget to napari viewer as a dock widget === + self.viewer.window.add_dock_widget( + selection_widget, + name="Attention Parameters", + area="right" + ) + + + def _add_patch_display_widget(self): + """Create a widget to display the selected patch and top-k patches""" + # Create main widget + self.patch_display_widget = QWidget() + layout = QHBoxLayout() + self.patch_display_widget.setLayout(layout) + + # Create a scroll area to contain patches + scroll_area = QScrollArea() + scroll_area.setWidgetResizable(True) + scroll_area.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOn) + scroll_area.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff) + + # Create container widget for patches + self.patches_container = QWidget() + self.patches_layout = QHBoxLayout() + self.patches_container.setLayout(self.patches_layout) + + # Add to scroll area + scroll_area.setWidget(self.patches_container) + layout.addWidget(scroll_area) + + # Add the widget to napari viewer as a dock widget at the bottom + self.viewer.window.add_dock_widget( + 
self.patch_display_widget, + name="Tile Patches", + area="bottom" + ) + + + + ### UI HANDLING ### + + def _set_ui_enabled(self, enabled: bool): + """Enable or disable all UI elements""" + + # Enable points layer + if hasattr(self, 'points_layer'): + self.points_layer.editable = enabled + + # Enable patch loading button + if hasattr(self, 'load_patches_btn'): + self.load_patches_btn.setEnabled(enabled) + + # Enable attention handling dropdown + if hasattr(self, 'attention_handling'): + self.attention_handling.setEnabled(enabled) + + # Enable all sliders + for slider_name in ['layer_slider', 'head_slider', 'topk_slider']: + if hasattr(self, slider_name): + getattr(self, slider_name).setEnabled(enabled) + + # Enable all arrow buttons + for btn_name in ['layer_left_btn', 'layer_right_btn', 'head_left_btn', + 'head_right_btn', 'topk_left_btn', 'topk_right_btn']: + if hasattr(self, btn_name): + getattr(self, btn_name).setEnabled(enabled) + + # Process UI events to update display + QApplication.processEvents() + + + def _adjust_slider(self, value=1, ui_element=None): + """Adjust the slider value by a given amount""" + current = ui_element.value() + if (value>0 and current < ui_element.maximum()) or (value<0 and current > ui_element.minimum()): + ui_element.setValue(current + value) + + + def _update_viewer_image(self, new_image: np.ndarray): + """ + Update the viewer image and reset heatmap after loading a new file + + Parameters: + ----------- + new_image : np.ndarray + The new image to display in the viewer + """ + # Update the image layer + self.image = new_image + self.image_layer.data = self.image + + # Update dimensions based on the new image + if len(self.image.shape) == 3 and self.image.shape[2] in [3, 4]: # RGB or RGBA + self.height, self.width = self.image.shape[0], self.image.shape[1] + else: # Grayscale + self.height, self.width = self.image.shape + + # Reset the heatmap + self.heatmap = np.zeros((self.height, self.width, 4), dtype=float) + self.heatmap_layer.data = self.heatmap + + # Clear any existing points + self.points_layer.data = np.empty((0, 2)) + self.clicked_points = [] + self._last_processed_point_count = 0 + + # Reset selected token + self.selected_token_idx = None + + # Reset viewer scale and position to fit the new image + self.viewer.reset_view() + + # Set active layer to points layer and mode to add + self.viewer.layers.selection.active = self.points_layer + self.points_layer.mode = 'add' + + print(f"Viewer updated with new image") + + + + + + ### FILE PROCESSING ### + + def _on_process_file(self): + """Handle the Process button click""" + selected_items = self.file_list.selectedItems() + + if selected_items: + # Get the selected filename + self.selected_filename = selected_items[0].text() + print(f"Selected file: {self.selected_filename}") + + self.process_selected_file(self.selected_filename) + self.load_selected_attention_map() + + else: + print("No file selected! 
Please select a file first.") + + + + def process_selected_file(self, wsi_path): + """Load the selected file and the corresponding attention map""" + + # Disable UI controls + self._set_ui_enabled(False) + + # Use QApplication to process events and update the UI + QApplication.processEvents() + + try: + + print(f"Processing file: {wsi_path}") + + with torch.no_grad(): + + # Load WSI + wsi_path = Path(wsi_path) + h5_path = self.feature_dir / wsi_path.with_suffix(".h5").name + print(f"Creating attention map for {wsi_path.name}") + + self.slide = openslide.open_slide(wsi_path) + slide_mpp = get_slide_mpp_(self.slide, default_mpp=self.default_slide_mpp) + assert slide_mpp is not None, "could not determine slide MPP" + + with h5py.File(h5_path) as h5: + feats = ( + torch.tensor( + h5["feats"][:] # pyright: ignore[reportIndexIssue] + ) + .float() + .to(self.device) + ) + coords_um = get_coords(h5).coords_um + stride_um = Microns(get_stride(coords_um)) + + if h5.attrs.get("unit") == "um": + self.tile_size_slide_px = SlidePixels( + int(round(cast(float, h5.attrs["tile_size_um"]) / slide_mpp)) + ) + else: + self.tile_size_slide_px = SlidePixels(int(round(256 / slide_mpp))) + + # grid coordinates, i.e. the top-left most tile is (0, 0), the one to its right (0, 1) etc. + self.map_coords = (coords_um / stride_um).round().long() + + # coordinates as used by OpenSlide (used to extract top/bottom k tiles) + self.coords_tile_slide_px = torch.round(coords_um / slide_mpp).long() + + # Score for the entire slide + self.attention_weights = self.model.vision_transformer.get_attention_maps( + bags=feats.unsqueeze(0), + coords=coords_um.unsqueeze(0), + mask=torch.zeros(1, len(feats), dtype=torch.bool, device=self.device), + ) + + # Determine number of heads and layers and update UI elements + num_layers = len(self.attention_weights) + num_heads = self.attention_weights[0].shape[1] + self.layer_slider.setMaximum(num_layers - 1) + self.head_slider.setMaximum(num_heads - 1) + + # Get thumbnail of the slide + self.image = _get_thumb(self.slide, slide_mpp) + + # Update the viewer with the new image + self._update_viewer_image(self.image) + + + finally: + # Re-enable UI controls + self._set_ui_enabled(True) + + + + def load_selected_attention_map(self): + + # Get attention weights + # Choose layer + self.attention_map = self.attention_weights[self.num_layer] # Shape: [batch, heads, tokens, tokens]) + # Choose head (or average) + if self.num_head == -1: + # Average over heads + self.attention_map = self.attention_map.mean(dim=1) + else: + self.attention_map = self.attention_map[:,self.num_head,...] # Shape: [batch, tokens, tokens] + # Cut out batch dimension + self.attention_map = self.attention_map[0,...] # Shape: [tokens, tokens] + + # Normalize attention map to [0, 1] by using percentiles (not considering cls token) + percentile_low = np.percentile(self.attention_map[1:,1:], 0.5) + percentile_high = np.percentile(self.attention_map[1:,1:], 99.5) + self.attention_map = (self.attention_map - percentile_low) / (percentile_high - percentile_low + 1e-8) + + + + def highlight_top_k_tiles(self): + + if self.selected_token_idx is None: + print("No token selected. 
Click on the image first.") + return + + # Create a new highlight mask + k = min(self.topk_slider.value(), len(self.token_attn)) + highlight_mask = np.zeros((self.height, self.width, 4), dtype=float) + + if k > 0: + # Get top k indices with highest attention + top_k_values, top_k_indices = torch.topk(self.token_attn, k) + + # For each top tile, add a colored rectangle to the mask + for i, (score, idx) in enumerate(zip(top_k_values.cpu().numpy(), top_k_indices.cpu().numpy())): + # Get tile coordinates + x, y = self.map_coords[idx].cpu().numpy() + + # Convert to image coordinates (scaled by 8) + x_img, y_img = x * 8, y * 8 + tile_size = 8 # Assuming 8x8 pixels per tile + + # Create rectangular highlight for this tile + # Use a different color intensity based on rank (1st is most intense) + min_opacity = 0.5 + intensity = 1.0 - (i * min_opacity / k) # Decreasing intensity for lower ranks + + # Define rectangle in the highlight mask + y_start, y_end = max(0, y_img), min(self.height, y_img + tile_size) + x_start, x_end = max(0, x_img), min(self.width, x_img + tile_size) + + # Red with alpha based on score + highlight_mask[y_start:y_end, x_start:x_end] = [0.0, 0.6, 0.0, min(min_opacity, intensity + score * 0.3)] + + self.highlight_mask = highlight_mask + self.highlight_layer.data = self.highlight_mask + + + + def _load_tile_patches(self): + """Load and display the selected patch and top-k patches""" + if self.selected_token_idx is None: + print("No token selected. Click on the image first.") + return + + # Clear previous patches + while self.patches_layout.count(): + item = self.patches_layout.takeAt(0) + widget = item.widget() + if widget: + widget.deleteLater() + + # Get selected token patch + selected_patch = self.slide.read_region( + tuple(self.coords_tile_slide_px[self.selected_token_idx].tolist()), + 0, + (self.tile_size_slide_px, self.tile_size_slide_px), + ).convert("RGB") + + # Add selected patch with label + selected_frame = QFrame() + selected_layout = QVBoxLayout() + selected_frame.setLayout(selected_layout) + + # Create QLabel for image + selected_label = QLabel() + selected_pixmap = _patch_to_pixmap(selected_patch) + selected_label.setPixmap(selected_pixmap) + + # Create label text + text_label = QLabel(f"Selected-ID:{self.selected_token_idx}") + text_label.setAlignment(Qt.AlignCenter) + + selected_layout.addWidget(selected_label) + selected_layout.addWidget(text_label) + self.patches_layout.addWidget(selected_frame) + + # Add separator + separator = QFrame() + separator.setFrameShape(QFrame.VLine) + separator.setFrameShadow(QFrame.Sunken) + separator.setLineWidth(2) + separator.setMinimumWidth(5) + separator.setStyleSheet("background-color: #888888;") + self.patches_layout.addWidget(separator) + + # Get top-k patches + topk = min(self.topk_slider.value(), len(self.token_attn)) + if topk > 0: + for n, (score, index) in enumerate(zip(*self.token_attn.topk(topk))): + # Get patch + patch = self.slide.read_region( + tuple(self.coords_tile_slide_px[index].tolist()), + 0, + (self.tile_size_slide_px, self.tile_size_slide_px), + ).convert("RGB") + + # Create frame with layout + patch_frame = QFrame() + patch_layout = QVBoxLayout() + patch_frame.setLayout(patch_layout) + + # Create QLabel for image + patch_label = QLabel() + patch_pixmap = _patch_to_pixmap(patch) + patch_label.setPixmap(patch_pixmap) + + # Create label text + text_label = QLabel(f"Top-{n+1}-ID:{index} (Score:{score:.2f})") + text_label.setAlignment(Qt.AlignCenter) + + patch_layout.addWidget(patch_label) + 
patch_layout.addWidget(text_label) + self.patches_layout.addWidget(patch_frame) + + # Force update of the layout + self.patches_container.adjustSize() + QApplication.processEvents() + + + def _on_update_attention_map(self): + # Check if we have data to display + if self.attention_weights is None: + return + + # Simple debounce to avoid too frequent updates + current_time = time.time() * 1000 # Convert to milliseconds + if current_time - self._last_attention_update_time < self._attention_update_debounce: + return + self._last_attention_update_time = current_time + + # Get head and layer values from the sliders + self.num_layer = self.layer_slider.value() + self.num_head = self.head_slider.value() + + # Update attention map + self.load_selected_attention_map() + self._last_processed_point_count = 0 # Reset last processed point count + self._on_add_point() + + + + def _on_add_point(self): + """Handle points being added to the points layer""" + if not self.map_coords is None: + # Prevent recursive calls + if self._updating_points: + return + + # Check if points have been added + if len(self.points_layer.data) > self._last_processed_point_count: # If there's any data + # Keep only the last added point + last_point = self.points_layer.data[-1] + + # Convert to proper types + y, x = int(last_point[0]), int(last_point[1]) + + # Set the flag before updating to prevent recursion + self._updating_points = True + + try: + # Update heatmap based on the new point + self.update_heatmap(y-4, x-4) # 4 to center the point + + # Snap to selected token position + x_snapped, y_snapped = self.map_coords[self.selected_token_idx,:].tolist() + self.points_layer.data = np.array([[y_snapped*8+4, x_snapped*8+4]]) + self._last_processed_point_count = 1 # We now have 1 point + + # Update top-k tiles + self.highlight_top_k_tiles() + + # Print clicked coordinates + print(f"Clicked at coordinates: ({y},{x}). Selected token index: {self.selected_token_idx} at ({y_snapped*8+4},{x_snapped*8+4})") + finally: + # Reset the flag after updating + self._updating_points = False + + else: + print("No map coordinates available. 
Please load a file first.") + + + + def update_heatmap(self, y: IntType, x: IntType): + """Update the heatmap based on clicked position""" + # Generate new heatmap using the provided or default function + self.heatmap, self.selected_token_idx = self._heatmap_generator(y, x) + + # Update the heatmap layer + self.heatmap_layer.data = self.heatmap + + + def get_token_attention(self, selected_token_idx: IntType): + + # Get selected direction + selected_direction = self.attention_handling.currentData() + + # Get attention for selected token + + # Attention from selected to others + if selected_direction == 0: + token_attn = self.attention_map[selected_token_idx+1, 1:] # +1 to skip the cls token + + # Attention from others to selected + elif selected_direction == 1: + token_attn = self.attention_map[1:, selected_token_idx+1] # +1 to skip the cls token + + # Deviation of overall given attention + elif selected_direction == 2: + token_attn = torch.std(self.attention_map[1:, 1:], dim=0) + percentile_low = np.percentile(token_attn, 0.5) + percentile_high = np.percentile(token_attn, 99.5) + token_attn = (token_attn - percentile_low) / (percentile_high - percentile_low + 1e-8) + + # Deviation of overall received attention + elif selected_direction == 3: + token_attn = torch.std(self.attention_map[1:, 1:], dim=1) + percentile_low = np.percentile(token_attn, 0.5) + percentile_high = np.percentile(token_attn, 99.5) + token_attn = (token_attn - percentile_low) / (percentile_high - percentile_low + 1e-8) + + # Mean of overall given attention + elif selected_direction == 4: + token_attn = torch.mean(self.attention_map[1:, 1:], dim=0) + + # Mean of overall received attention + elif selected_direction == 5: + token_attn = torch.mean(self.attention_map[1:, 1:], dim=1) + + # Class token attention + elif selected_direction == 6: + token_attn = self.attention_map[0, 1:] # from cls to others + percentile_low = np.percentile(token_attn, 0.5) + percentile_high = np.percentile(token_attn, 99.5) + token_attn = (token_attn - percentile_low) / (percentile_high - percentile_low + 1e-8) + + else: + raise ValueError(f"Invalid direction selected: {selected_direction}") + + token_attn = np.clip(token_attn, 0, 1) + + return token_attn + + + + def _heatmap_generator(self, y: IntType, x: IntType): + """Heatmap generator - determines closest token to clicked position and extract inter-token attention""" + # Get the closest token to the clicked position + token_distances = cdist([(x, y)], self.map_coords.numpy(force=True)*8) # Upscale by 8 to match thumbnail size + selected_token_idx = np.argmin(token_distances) + + # Get attention for selected token + self.token_attn = self.get_token_attention(selected_token_idx) + + # Generate heatmap + cls_attn_map = _vals_to_im( + self.token_attn.unsqueeze(-1), # Add feature dimension + self.map_coords, + ).squeeze(-1) # Shape: [width, height] + + # Upscale by 8 to match the thumbnail size + cls_attn_map = cls_attn_map.repeat_interleave(8, dim=0).repeat_interleave(8, dim=1) + + # Normalize the heatmap to [0, 1] + # cls_attn_map = (cls_attn_map - cls_attn_map.min()) / (cls_attn_map.max() - cls_attn_map.min() + 1e-8) + + # Convert to numpy array + heatmap_values = cls_attn_map.numpy(force=True) + + # Get the colormap + colormap = cm.get_cmap('inferno') + + # Apply colormap to the values to get RGB + heatmap_rgba = colormap(heatmap_values) + + # Create a mask for zero and near-zero values (make these transparent) + threshold = 0.0 + zero_mask = heatmap_values < threshold + + # Set alpha channel 
to make zero-value regions fully transparent + heatmap_rgba[zero_mask, 3] = 0.0 + + # Scale the alpha for non-zero values by the desired opacity + heatmap_rgba[~zero_mask, 3] *= 1.0 + + return heatmap_rgba, selected_token_idx + + + + def show(self): + """Display the viewer and start the event loop""" + napari.run() + + +def show_attention_ui( + feature_dir: Path, + wsis_to_process: Iterable[str], + checkpoint_path: Path, + output_dir: Path, + slide_paths: Iterable[Path] | None, + device: DeviceLikeType, + default_slide_mpp: SlideMPP | None + ) -> AttentionViewer: + """ + Launch the attention UI. + + Parameters: + ----------- + feature_dir : Path + Directory containing feature files + wsis_to_process : Iterable[str] + List of WSI files to present for process + checkpoint_path : Path + Path to model checkpoint + output_dir : Path + Directory to save output files + slide_paths : Iterable[Path] | None + Paths to specific slide files + device : DeviceLikeType + Device to run model on + default_slide_mpp : SlideMPP | None + Default slide microns per pixel + + Returns: + -------- + AttentionViewer + The viewer instance for attention exploration + """ + viewer = AttentionViewer( + feature_dir, + wsis_to_process, + checkpoint_path, + output_dir, + slide_paths, + device, + default_slide_mpp + ) + return viewer \ No newline at end of file diff --git a/src/stamp/heatmaps/config.py b/src/stamp/heatmaps/config.py index f6bda6a4..0a3b75c0 100644 --- a/src/stamp/heatmaps/config.py +++ b/src/stamp/heatmaps/config.py @@ -24,3 +24,23 @@ class HeatmapConfig(BaseModel): default_slide_mpp: SlideMPP | None = None """MPP of the slide to use if none can be inferred from the WSI""" + + +class AttentionUIConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + + output_dir: Path + + feature_dir: Path + wsi_dir: Path + checkpoint_path: Path + + slide_paths: list[Path] | None = None + + device: str = "cuda" if torch.cuda.is_available() else "cpu" + + topk: int = 0 + bottomk: int = 0 + + default_slide_mpp: SlideMPP | None = None + """MPP of the slide to use if none can be inferred from the WSI""" diff --git a/src/stamp/modeling/vision_transformer.py b/src/stamp/modeling/vision_transformer.py index 88225f9d..bd6d87f1 100755 --- a/src/stamp/modeling/vision_transformer.py +++ b/src/stamp/modeling/vision_transformer.py @@ -41,6 +41,7 @@ def __init__( super().__init__() self.heads = num_heads self.norm = nn.LayerNorm(dim) + self.last_attn_weights = None # Store attention weights if use_alibi: self.mhsa = MultiHeadALiBi( @@ -59,7 +60,8 @@ def forward( attn_mask: Bool[Tensor, "batch sequence sequence"] | None, # Help, my abstractions are leaking! alibi_mask: Bool[Tensor, "batch sequence sequence"], - ) -> Float[Tensor, "batch sequence proj_feature"]: + return_attention: bool = False, + ) -> Float[Tensor, "batch sequence proj_feature"] | tuple[Float[Tensor, "batch sequence proj_feature"], Float[Tensor, "batch heads sequence sequence"]]: """ Args: attn_mask: @@ -71,34 +73,89 @@ def forward( Which query-key pairs to apply ALiBi to. If this module was constructed using `use_alibi=False`, this has no effect. + return_attention: + If True, returns the attention weights alongside the output. 
""" x = self.norm(x) + + # Initialize attention weights with default shape + self.last_attn_weights = None + match self.mhsa: case nn.MultiheadAttention(): - attn_output, _ = self.mhsa( + attn_output, attn_weights = self.mhsa( x, x, x, - need_weights=False, + need_weights=True, + average_attn_weights=False, attn_mask=( attn_mask.repeat(self.mhsa.num_heads, 1, 1) if attn_mask is not None else None ), ) + self.last_attn_weights = attn_weights + case MultiHeadALiBi(): - attn_output = self.mhsa( - q=x, - k=x, - v=x, - coords_q=coords, - coords_k=coords, - attn_mask=attn_mask, - alibi_mask=alibi_mask, - ) + # Modified MultiHeadALiBi to return attention weights + if hasattr(self.mhsa, "return_attention_weights"): + try: + attn_output, attn_weights = self.mhsa( + q=x, + k=x, + v=x, + coords_q=coords, + coords_k=coords, + attn_mask=attn_mask, + alibi_mask=alibi_mask, + return_attention=True, + ) + self.last_attn_weights = attn_weights + except: + # If the return_attention param exists but fails, fall back + attn_output = self.mhsa( + q=x, + k=x, + v=x, + coords_q=coords, + coords_k=coords, + attn_mask=attn_mask, + alibi_mask=alibi_mask, + ) + # Create dummy attention weights to satisfy type checking + if return_attention: + print("Warning: Failed to return attention weights. Creating dummy weights.") + batch_size, seq_len, _ = x.shape + self.last_attn_weights = torch.zeros( + batch_size, self.heads, seq_len, seq_len, + device=x.device, dtype=x.dtype + ) + else: + attn_output = self.mhsa( + q=x, + k=x, + v=x, + coords_q=coords, + coords_k=coords, + attn_mask=attn_mask, + alibi_mask=alibi_mask, + ) + self.last_attn_weights = None case _ as unreachable: assert_never(unreachable) + if return_attention: + # Ensure we always return valid tensor for attention weights + if self.last_attn_weights is None: + # Create default attention weights if none were produced + batch_size, seq_len, _ = x.shape + self.last_attn_weights = torch.zeros( + batch_size, self.heads if hasattr(self, 'heads') else 1, + seq_len, seq_len, device=x.device, dtype=x.dtype + ) + return attn_output, self.last_attn_weights + return attn_output @@ -145,13 +202,30 @@ def forward( coords: Float[Tensor, "batch sequence 2"], attn_mask: Bool[Tensor, "batch sequence sequence"] | None, alibi_mask: Bool[Tensor, "batch sequence sequence"], - ) -> Float[Tensor, "batch sequence proj_feature"]: - for attn, ff in cast(Iterable[tuple[nn.Module, nn.Module]], self.layers): - x_attn = attn(x, coords=coords, attn_mask=attn_mask, alibi_mask=alibi_mask) + return_attention: bool = False, + ) -> Float[Tensor, "batch sequence proj_feature"] | tuple[Float[Tensor, "batch sequence proj_feature"], list[Float[Tensor, "batch heads sequence sequence"]]]: + attention_weights = [] + + for attn, ff in cast(Iterable[tuple[SelfAttention, nn.Module]], self.layers): + if return_attention: + x_attn, attn_weights = attn( + x, + coords=coords, + attn_mask=attn_mask, + alibi_mask=alibi_mask, + return_attention=True + ) + attention_weights.append(attn_weights) + else: + x_attn = attn(x, coords=coords, attn_mask=attn_mask, alibi_mask=alibi_mask) + x = x_attn + x x = ff(x) + x x = self.norm(x) + + if return_attention: + return x, attention_weights return x @@ -240,3 +314,55 @@ def forward( bags = bags[:, 0] return self.mlp_head(bags) + + + def get_attention_maps( + self, + bags: Float[Tensor, "batch tile feature"], + *, + coords: Float[Tensor, "batch tile 2"], + mask: Bool[Tensor, "batch tile"] | None, + ) -> Iterable[Float[Tensor, "batch heads sequence sequence"]]: + """Extract 
the attention maps from the last layer of the transformer.""" + batch_size, _n_tiles, _n_features = bags.shape + + # Map input sequence to latent space of TransMIL + bags = self.project_features(bags) + + # Prepend a class token to every bag + cls_tokens = repeat(self.class_token, "d -> b 1 d", b=batch_size) + bags = torch.cat([cls_tokens, bags], dim=1) + coords = torch.cat( + [torch.zeros(batch_size, 1, 2).type_as(coords), coords], dim=1 + ) + + # Create necessary masks + if mask is None: + bags, attention_weights = self.transformer( + bags, coords=coords, attn_mask=None, alibi_mask=None, return_attention=True + ) + else: + mask_with_class_token = torch.cat( + [torch.zeros(mask.shape[0], 1).type_as(mask), mask], dim=1 + ) + square_attn_mask = torch.einsum( + "bq,bk->bqk", mask_with_class_token, mask_with_class_token + ) + # Don't allow other tiles to reference the class token + square_attn_mask[:, 1:, 0] = True + + # Don't apply ALiBi to the query, as the coordinates don't make sense here + alibi_mask = torch.zeros_like(square_attn_mask) + alibi_mask[:, 0, :] = True + alibi_mask[:, :, 0] = True + + bags, attention_weights = self.transformer( + bags, + coords=coords, + attn_mask=square_attn_mask, + alibi_mask=alibi_mask, + return_attention=True + ) + + # Return the attention weights + return attention_weights From 485b75e1407075a1c4319156d227ff2ebc618e46 Mon Sep 17 00:00:00 2001 From: DEschweiler Date: Fri, 9 May 2025 18:56:15 +0200 Subject: [PATCH 02/10] added attention UI --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 16e06ffc..ba33f4b9 100644 --- a/README.md +++ b/README.md @@ -57,6 +57,7 @@ positional arguments: statistics Generate AUROCs and AUPRCs with 95%CI for a trained Vision Transformer model config Print the loaded configuration heatmaps Generate heatmaps for a trained model + attention_ui Provides an interactive UI for exploring attention maps options: -h, --help show this help message and exit From 335199e28f726dd0307c30aeddbf9ce321419b51 Mon Sep 17 00:00:00 2001 From: DEschweiler Date: Tue, 13 May 2025 15:39:16 +0200 Subject: [PATCH 03/10] PR update: made attentionui optional, unified attentionui naming, added missing dependency response, added smoke test --- README.md | 1 - pyproject.toml | 6 +- src/stamp/__main__.py | 24 ++--- src/stamp/config.py | 2 +- src/stamp/config.yaml | 2 +- src/stamp/heatmaps/__init__.py | 13 ++- src/stamp/heatmaps/attention_ui.py | 139 ++++++++++++------------- tests/test_attention_ui_integration.py | 45 ++++++++ 8 files changed, 140 insertions(+), 92 deletions(-) create mode 100644 tests/test_attention_ui_integration.py diff --git a/README.md b/README.md index ba33f4b9..16e06ffc 100644 --- a/README.md +++ b/README.md @@ -57,7 +57,6 @@ positional arguments: statistics Generate AUROCs and AUPRCs with 95%CI for a trained Vision Transformer model config Print the loaded configuration heatmaps Generate heatmaps for a trained model - attention_ui Provides an interactive UI for exploring attention maps options: -h, --help show this help message and exit diff --git a/pyproject.toml b/pyproject.toml index e80e9ec0..e8271a6e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,6 @@ dependencies = [ "jaxtyping>=0.2.36", "lightning>=2.4.0", "matplotlib>=3.9.2", - "napari>=0.6.0", "numpy>=2.2.2", "opencv-python>=4.10.0.84", "openpyxl>=3.1.5", @@ -36,7 +35,6 @@ dependencies = [ "pandas>=2.2.3", "pillow>=11.1.0", "pydantic>=2.10.3", - "pyqt5>=5.15.11", "pyyaml>=6.0.2", "scikit-learn>=1.5.2", "scipy>=1.15.1", 
@@ -95,6 +93,10 @@ virchow2 = [ "timm>=0.9.11", "torch>=2.0.0", ] +attentionui =[ + "napari>=0.6.0", + "pyqt5>=5.15.11", +] # Blanket target all = ["stamp[dinobloom,conch,ctranspath,uni,virchow2]"] diff --git a/src/stamp/__main__.py b/src/stamp/__main__.py index 70eaef97..073e99bc 100755 --- a/src/stamp/__main__.py +++ b/src/stamp/__main__.py @@ -210,25 +210,25 @@ def _run_cli(args: argparse.Namespace) -> None: default_slide_mpp=config.heatmaps.default_slide_mpp, ) - case "attention_ui": + case "attentionui": from stamp.heatmaps import attention_ui_ - if config.attention_ui is None: + if config.attentionui is None: raise ValueError("no attention configuration supplied") - _add_file_handle_(_logger, output_dir=config.attention_ui.output_dir) + _add_file_handle_(_logger, output_dir=config.attentionui.output_dir) _logger.info( "using the following configuration:\n" - f"{yaml.dump(config.attention_ui.model_dump(mode='json'))}" + f"{yaml.dump(config.attentionui.model_dump(mode='json'))}" ) attention_ui_( - feature_dir=config.attention_ui.feature_dir, - wsi_dir=config.attention_ui.wsi_dir, - checkpoint_path=config.attention_ui.checkpoint_path, - output_dir=config.attention_ui.output_dir, - slide_paths=config.attention_ui.slide_paths, - device=config.attention_ui.device, - default_slide_mpp=config.attention_ui.default_slide_mpp, + feature_dir=config.attentionui.feature_dir, + wsi_dir=config.attentionui.wsi_dir, + checkpoint_path=config.attentionui.checkpoint_path, + output_dir=config.attentionui.output_dir, + slide_paths=config.attentionui.slide_paths, + device=config.attentionui.device, + default_slide_mpp=config.attentionui.default_slide_mpp, ) case _: @@ -282,7 +282,7 @@ def main() -> None: ) commands.add_parser("config", help="Print the loaded configuration") commands.add_parser("heatmaps", help="Generate heatmaps for a trained model") - commands.add_parser("attention_ui", help="Provides an interactive UI for exploring attention maps") + commands.add_parser("attentionui", help="Provides an interactive UI for exploring attention maps") args = parser.parse_args() diff --git a/src/stamp/config.py b/src/stamp/config.py index 902e5d36..688f861b 100644 --- a/src/stamp/config.py +++ b/src/stamp/config.py @@ -18,4 +18,4 @@ class StampConfig(BaseModel): statistics: StatsConfig | None = None heatmaps: HeatmapConfig | None = None - attention_ui: AttentionUIConfig | None = None \ No newline at end of file + attentionui: AttentionUIConfig | None = None \ No newline at end of file diff --git a/src/stamp/config.yaml b/src/stamp/config.yaml index d6eadfd7..f45428a7 100644 --- a/src/stamp/config.yaml +++ b/src/stamp/config.yaml @@ -220,7 +220,7 @@ heatmaps: #bottomk: 5 -attention_ui: +attentionui: output_dir: "/path/to/save/files/to" # Directory the extracted features are saved in. 
diff --git a/src/stamp/heatmaps/__init__.py b/src/stamp/heatmaps/__init__.py index 44ecacb8..f1b569a3 100644 --- a/src/stamp/heatmaps/__init__.py +++ b/src/stamp/heatmaps/__init__.py @@ -8,7 +8,6 @@ import numpy as np import openslide import torch -import napari from jaxtyping import Float, Integer from matplotlib.axes import Axes from matplotlib.patches import Patch @@ -22,7 +21,6 @@ from stamp.modeling.vision_transformer import VisionTransformer from stamp.preprocessing import supported_extensions from stamp.preprocessing.tiling import Microns, SlideMPP, TilePixels, get_slide_mpp_ -from stamp.heatmaps.attention_ui import show_attention_ui _logger = logging.getLogger("stamp") @@ -332,6 +330,14 @@ def attention_ui_( device: DeviceLikeType, default_slide_mpp: SlideMPP | None ) -> None: + + try: + from stamp.heatmaps.attention_ui import show_attention_ui + except ImportError as e: + raise ImportError( + "Attention UI dependencies not installed. " + "Please reinstall stamp using `pip install 'stamp[attentionui]'`" + ) from e with torch.no_grad(): @@ -356,5 +362,4 @@ def attention_ui_( # Launch the UI - viewer = show_attention_ui(feature_dir, wsis_to_process, checkpoint_path, output_dir, slide_paths, device, default_slide_mpp) - napari.run() + show_attention_ui(feature_dir, wsis_to_process, checkpoint_path, output_dir, slide_paths, device, default_slide_mpp) diff --git a/src/stamp/heatmaps/attention_ui.py b/src/stamp/heatmaps/attention_ui.py index 1b8e6e2f..76938b85 100644 --- a/src/stamp/heatmaps/attention_ui.py +++ b/src/stamp/heatmaps/attention_ui.py @@ -1,4 +1,11 @@ import napari +from qtpy.QtGui import QPixmap, QImage +from qtpy.QtWidgets import ( + QComboBox, QPushButton, QVBoxLayout, QWidget, QLabel, QScrollArea, + QListWidget, QApplication, QSlider, QHBoxLayout, QFrame +) +from qtpy.QtCore import Qt + import torch import time from torch import Tensor @@ -13,12 +20,6 @@ import h5py from scipy.spatial.distance import cdist from typing import Union -from qtpy.QtGui import QPixmap, QImage -from qtpy.QtWidgets import ( - QComboBox, QPushButton, QVBoxLayout, QWidget, QLabel, QScrollArea, - QListWidget, QApplication, QSlider, QHBoxLayout, QFrame -) -from qtpy.QtCore import Qt from stamp.modeling.lightning_model import LitVisionTransformer from stamp.preprocessing.tiling import Microns, SlideMPP, SlidePixels, get_slide_mpp_ from stamp.modeling.data import get_coords, get_stride @@ -26,6 +27,9 @@ # Define a generic integer type that can be either Python int or numpy int64 IntType = Union[int, np.int64, np.int32] +__author__ = "Dennis Eschweiler" +__copyright__ = "Copyright (C) 2025 Dennis Eschweiler" +__license__ = "MIT" def _vals_to_im( scores: Float[Tensor, "tile feat"], @@ -132,10 +136,7 @@ def __init__( self.image, name='Image' ) - if len(self.image.shape) == 3 and self.image.shape[2] in [3, 4]: # RGB or RGBA - self.height, self.width = self.image.shape[0], self.image.shape[1] - else: # Grayscale - self.height, self.width = self.image.shape + self.height, self.width = self.image.shape[0], self.image.shape[1] # Initialize other attributes self.slide = None @@ -477,9 +478,10 @@ def _set_ui_enabled(self, enabled: bool): def _adjust_slider(self, value=1, ui_element=None): """Adjust the slider value by a given amount""" - current = ui_element.value() - if (value>0 and current < ui_element.maximum()) or (value<0 and current > ui_element.minimum()): - ui_element.setValue(current + value) + if not ui_element is None: + current = ui_element.value() + if (value>0 and current < ui_element.maximum()) 
or (value<0 and current > ui_element.minimum()): + ui_element.setValue(current + value) def _update_viewer_image(self, new_image: np.ndarray): @@ -523,7 +525,6 @@ def _update_viewer_image(self, new_image: np.ndarray): print(f"Viewer updated with new image") - ### FILE PROCESSING ### @@ -582,7 +583,7 @@ def process_selected_file(self, wsi_path): if h5.attrs.get("unit") == "um": self.tile_size_slide_px = SlidePixels( - int(round(cast(float, h5.attrs["tile_size_um"]) / slide_mpp)) + int(round(cast(float, h5.attrs["tile_size"]) / slide_mpp)) ) else: self.tile_size_slide_px = SlidePixels(int(round(256 / slide_mpp))) @@ -621,62 +622,63 @@ def process_selected_file(self, wsi_path): def load_selected_attention_map(self): - # Get attention weights - # Choose layer - self.attention_map = self.attention_weights[self.num_layer] # Shape: [batch, heads, tokens, tokens]) - # Choose head (or average) - if self.num_head == -1: - # Average over heads - self.attention_map = self.attention_map.mean(dim=1) - else: - self.attention_map = self.attention_map[:,self.num_head,...] # Shape: [batch, tokens, tokens] - # Cut out batch dimension - self.attention_map = self.attention_map[0,...] # Shape: [tokens, tokens] + if not self.attention_weights is None: + # Get attention weights + # Choose layer + self.attention_map = self.attention_weights[self.num_layer] # Shape: [batch, heads, tokens, tokens]) + # Choose head (or average) + if self.num_head == -1: + # Average over heads + self.attention_map = self.attention_map.mean(dim=1) + else: + self.attention_map = self.attention_map[:,self.num_head,...] # Shape: [batch, tokens, tokens] + # Cut out batch dimension + self.attention_map = self.attention_map[0,...] # Shape: [tokens, tokens] - # Normalize attention map to [0, 1] by using percentiles (not considering cls token) - percentile_low = np.percentile(self.attention_map[1:,1:], 0.5) - percentile_high = np.percentile(self.attention_map[1:,1:], 99.5) - self.attention_map = (self.attention_map - percentile_low) / (percentile_high - percentile_low + 1e-8) + # Normalize attention map to [0, 1] by using percentiles (not considering cls token) + percentile_low = np.percentile(self.attention_map[1:,1:], 0.5) + percentile_high = np.percentile(self.attention_map[1:,1:], 99.5) + self.attention_map = (self.attention_map - percentile_low) / (percentile_high - percentile_low + 1e-8) def highlight_top_k_tiles(self): - if self.selected_token_idx is None: - print("No token selected. 
Click on the image first.") - return - - # Create a new highlight mask - k = min(self.topk_slider.value(), len(self.token_attn)) - highlight_mask = np.zeros((self.height, self.width, 4), dtype=float) + if not self.selected_token_idx is None and\ + not self.token_attn is None and\ + not self.map_coords is None: + + # Create a new highlight mask + k = min(self.topk_slider.value(), len(self.token_attn)) + highlight_mask = np.zeros((self.height, self.width, 4), dtype=float) - if k > 0: - # Get top k indices with highest attention - top_k_values, top_k_indices = torch.topk(self.token_attn, k) - - # For each top tile, add a colored rectangle to the mask - for i, (score, idx) in enumerate(zip(top_k_values.cpu().numpy(), top_k_indices.cpu().numpy())): - # Get tile coordinates - x, y = self.map_coords[idx].cpu().numpy() - - # Convert to image coordinates (scaled by 8) - x_img, y_img = x * 8, y * 8 - tile_size = 8 # Assuming 8x8 pixels per tile - - # Create rectangular highlight for this tile - # Use a different color intensity based on rank (1st is most intense) - min_opacity = 0.5 - intensity = 1.0 - (i * min_opacity / k) # Decreasing intensity for lower ranks + if k > 0: + # Get top k indices with highest attention + top_k_values, top_k_indices = torch.topk(self.token_attn, k) - # Define rectangle in the highlight mask - y_start, y_end = max(0, y_img), min(self.height, y_img + tile_size) - x_start, x_end = max(0, x_img), min(self.width, x_img + tile_size) - - # Red with alpha based on score - highlight_mask[y_start:y_end, x_start:x_end] = [0.0, 0.6, 0.0, min(min_opacity, intensity + score * 0.3)] - - self.highlight_mask = highlight_mask - self.highlight_layer.data = self.highlight_mask + # For each top tile, add a colored rectangle to the mask + for i, (score, idx) in enumerate(zip(top_k_values.cpu().numpy(), top_k_indices.cpu().numpy())): + # Get tile coordinates + x, y = self.map_coords[idx].cpu().numpy() + + # Convert to image coordinates (scaled by 8) + x_img, y_img = x * 8, y * 8 + tile_size = 8 # Assuming 8x8 pixels per tile + + # Create rectangular highlight for this tile + # Use a different color intensity based on rank (1st is most intense) + min_opacity = 0.5 + intensity = 1.0 - (i * min_opacity / k) # Decreasing intensity for lower ranks + + # Define rectangle in the highlight mask + y_start, y_end = max(0, y_img), min(self.height, y_img + tile_size) + x_start, x_end = max(0, x_img), min(self.width, x_img + tile_size) + + # Red with alpha based on score + highlight_mask[y_start:y_end, x_start:x_end] = [0.0, 0.6, 0.0, min(min_opacity, intensity + score * 0.3)] + + self.highlight_mask = highlight_mask + self.highlight_layer.data = self.highlight_mask @@ -943,7 +945,7 @@ def show_attention_ui( slide_paths: Iterable[Path] | None, device: DeviceLikeType, default_slide_mpp: SlideMPP | None - ) -> AttentionViewer: + ): """ Launch the attention UI. 
@@ -963,13 +965,8 @@ def show_attention_ui( Device to run model on default_slide_mpp : SlideMPP | None Default slide microns per pixel - - Returns: - -------- - AttentionViewer - The viewer instance for attention exploration """ - viewer = AttentionViewer( + AttentionViewer( feature_dir, wsis_to_process, checkpoint_path, @@ -978,4 +975,4 @@ def show_attention_ui( device, default_slide_mpp ) - return viewer \ No newline at end of file + napari.run() \ No newline at end of file diff --git a/tests/test_attention_ui_integration.py b/tests/test_attention_ui_integration.py new file mode 100644 index 00000000..0bfc01ff --- /dev/null +++ b/tests/test_attention_ui_integration.py @@ -0,0 +1,45 @@ +from pathlib import Path + +import pytest +import torch + +from stamp.cache import download_file +from stamp.heatmaps import attention_ui_ + + +@pytest.mark.filterwarnings("ignore:There is a performance drop") +def test_heatmap_integration(tmp_path: Path) -> None: + example_chekpoint_path = download_file( + url="https://github.com/KatherLab/STAMP/releases/download/2.0.0-dev8/example-model.ckpt", + file_name="example-model.ckpt", + sha256sum="a71dffd4b5fdb82acd5f84064880efd3382e200b07e5a008cb53e03197b6de56", + ) + example_slide_path = download_file( + url="https://github.com/KatherLab/STAMP/releases/download/2.0.0.dev14/TCGA-G4-6625-01Z-00-DX1.0fa26667-2581-4f96-a891-d78dbc3299b4.svs", + file_name="TCGA-G4-6625-01Z-00-DX1.0fa26667-2581-4f96-a891-d78dbc3299b4.svs", + sha256sum="9b7d2b0294524351bf29229c656cc886af028cb9e7463882289fac43c1347525", + ) + example_feature_path = download_file( + url="https://github.com/KatherLab/STAMP/releases/download/2.0.0.dev14/TCGA-G4-6625-01Z-00-DX1.0fa26667-2581-4f96-a891-d78dbc3299b4-mahmood-uni.h5", + file_name="TCGA-G4-6625-01Z-00-DX1.0fa26667-2581-4f96-a891-d78dbc3299b4-mahmood-uni.h5", + sha256sum="13b1390241e73a3969915d3d01c5c64f1b7c68318a685d8e3bf851067162f0bc", + ) + + wsi_dir = tmp_path / "wsis" + wsi_dir.mkdir() + (wsi_dir / "slide.svs").symlink_to(example_slide_path) + feature_dir = tmp_path / "feats" + feature_dir.mkdir() + (feature_dir / "slide.h5").symlink_to(example_feature_path) + + attention_ui_( + feature_dir=feature_dir, + wsi_dir=wsi_dir, + checkpoint_path=example_chekpoint_path, + output_dir=tmp_path / "output", + slide_paths=None, + device="cuda" if torch.cuda.is_available() else "cpu", + default_slide_mpp=None, + ) + + # For now this remains a simple smoke test, as no output is generated \ No newline at end of file From d98b7d30ce12456d17435ea936b9e46fdb4e4974 Mon Sep 17 00:00:00 2001 From: DEschweiler Date: Tue, 13 May 2025 17:33:01 +0200 Subject: [PATCH 04/10] removed smoke test --- tests/test_attention_ui_integration.py | 45 -------------------------- 1 file changed, 45 deletions(-) delete mode 100644 tests/test_attention_ui_integration.py diff --git a/tests/test_attention_ui_integration.py b/tests/test_attention_ui_integration.py deleted file mode 100644 index 0bfc01ff..00000000 --- a/tests/test_attention_ui_integration.py +++ /dev/null @@ -1,45 +0,0 @@ -from pathlib import Path - -import pytest -import torch - -from stamp.cache import download_file -from stamp.heatmaps import attention_ui_ - - -@pytest.mark.filterwarnings("ignore:There is a performance drop") -def test_heatmap_integration(tmp_path: Path) -> None: - example_chekpoint_path = download_file( - url="https://github.com/KatherLab/STAMP/releases/download/2.0.0-dev8/example-model.ckpt", - file_name="example-model.ckpt", - 
sha256sum="a71dffd4b5fdb82acd5f84064880efd3382e200b07e5a008cb53e03197b6de56", - ) - example_slide_path = download_file( - url="https://github.com/KatherLab/STAMP/releases/download/2.0.0.dev14/TCGA-G4-6625-01Z-00-DX1.0fa26667-2581-4f96-a891-d78dbc3299b4.svs", - file_name="TCGA-G4-6625-01Z-00-DX1.0fa26667-2581-4f96-a891-d78dbc3299b4.svs", - sha256sum="9b7d2b0294524351bf29229c656cc886af028cb9e7463882289fac43c1347525", - ) - example_feature_path = download_file( - url="https://github.com/KatherLab/STAMP/releases/download/2.0.0.dev14/TCGA-G4-6625-01Z-00-DX1.0fa26667-2581-4f96-a891-d78dbc3299b4-mahmood-uni.h5", - file_name="TCGA-G4-6625-01Z-00-DX1.0fa26667-2581-4f96-a891-d78dbc3299b4-mahmood-uni.h5", - sha256sum="13b1390241e73a3969915d3d01c5c64f1b7c68318a685d8e3bf851067162f0bc", - ) - - wsi_dir = tmp_path / "wsis" - wsi_dir.mkdir() - (wsi_dir / "slide.svs").symlink_to(example_slide_path) - feature_dir = tmp_path / "feats" - feature_dir.mkdir() - (feature_dir / "slide.h5").symlink_to(example_feature_path) - - attention_ui_( - feature_dir=feature_dir, - wsi_dir=wsi_dir, - checkpoint_path=example_chekpoint_path, - output_dir=tmp_path / "output", - slide_paths=None, - device="cuda" if torch.cuda.is_available() else "cpu", - default_slide_mpp=None, - ) - - # For now this remains a simple smoke test, as no output is generated \ No newline at end of file From 9b8e7fbe7b5aea0c61cc7cf5f0c03323ce60d868 Mon Sep 17 00:00:00 2001 From: DEschweiler Date: Wed, 28 May 2025 14:07:17 +0200 Subject: [PATCH 05/10] fixed code style issues --- src/stamp/heatmaps/attention_ui.py | 14 +++++++------- src/stamp/modeling/vision_transformer.py | 4 ++-- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/stamp/heatmaps/attention_ui.py b/src/stamp/heatmaps/attention_ui.py index 76938b85..ad0415ae 100644 --- a/src/stamp/heatmaps/attention_ui.py +++ b/src/stamp/heatmaps/attention_ui.py @@ -478,7 +478,7 @@ def _set_ui_enabled(self, enabled: bool): def _adjust_slider(self, value=1, ui_element=None): """Adjust the slider value by a given amount""" - if not ui_element is None: + if ui_element is not None: current = ui_element.value() if (value>0 and current < ui_element.maximum()) or (value<0 and current > ui_element.minimum()): ui_element.setValue(current + value) @@ -522,7 +522,7 @@ def _update_viewer_image(self, new_image: np.ndarray): self.viewer.layers.selection.active = self.points_layer self.points_layer.mode = 'add' - print(f"Viewer updated with new image") + print("Viewer updated with new image") @@ -622,7 +622,7 @@ def process_selected_file(self, wsi_path): def load_selected_attention_map(self): - if not self.attention_weights is None: + if self.attention_weights is not None: # Get attention weights # Choose layer self.attention_map = self.attention_weights[self.num_layer] # Shape: [batch, heads, tokens, tokens]) @@ -644,9 +644,9 @@ def load_selected_attention_map(self): def highlight_top_k_tiles(self): - if not self.selected_token_idx is None and\ - not self.token_attn is None and\ - not self.map_coords is None: + if self.selected_token_idx is not None and\ + self.token_attn is not None and\ + self.map_coords is not None: # Create a new highlight mask k = min(self.topk_slider.value(), len(self.token_attn)) @@ -787,7 +787,7 @@ def _on_update_attention_map(self): def _on_add_point(self): """Handle points being added to the points layer""" - if not self.map_coords is None: + if self.map_coords is not None: # Prevent recursive calls if self._updating_points: return diff --git 
a/src/stamp/modeling/vision_transformer.py b/src/stamp/modeling/vision_transformer.py index bd6d87f1..5b035b8a 100755 --- a/src/stamp/modeling/vision_transformer.py +++ b/src/stamp/modeling/vision_transformer.py @@ -112,7 +112,7 @@ def forward( return_attention=True, ) self.last_attn_weights = attn_weights - except: + except (TypeError, ValueError, RuntimeError) as e: # If the return_attention param exists but fails, fall back attn_output = self.mhsa( q=x, @@ -125,7 +125,7 @@ def forward( ) # Create dummy attention weights to satisfy type checking if return_attention: - print("Warning: Failed to return attention weights. Creating dummy weights.") + print(f"Warning: Failed to return attention weights ({type(e).__name__}: {e}). Creating dummy weights.") batch_size, seq_len, _ = x.shape self.last_attn_weights = torch.zeros( batch_size, self.heads, seq_len, seq_len, From fab78c491db9608d29a2b0ce3729924ec6f90521 Mon Sep 17 00:00:00 2001 From: DEschweiler Date: Mon, 2 Jun 2025 09:39:19 +0200 Subject: [PATCH 06/10] ruff formatting --- src/stamp/__main__.py | 4 +- src/stamp/config.py | 2 +- src/stamp/heatmaps/__init__.py | 23 +- src/stamp/heatmaps/attention_ui.py | 527 ++++++++++++----------- src/stamp/modeling/vision_transformer.py | 69 ++- 5 files changed, 345 insertions(+), 280 deletions(-) diff --git a/src/stamp/__main__.py b/src/stamp/__main__.py index 073e99bc..952f3809 100755 --- a/src/stamp/__main__.py +++ b/src/stamp/__main__.py @@ -282,7 +282,9 @@ def main() -> None: ) commands.add_parser("config", help="Print the loaded configuration") commands.add_parser("heatmaps", help="Generate heatmaps for a trained model") - commands.add_parser("attentionui", help="Provides an interactive UI for exploring attention maps") + commands.add_parser( + "attentionui", help="Provides an interactive UI for exploring attention maps" + ) args = parser.parse_args() diff --git a/src/stamp/config.py b/src/stamp/config.py index 688f861b..dce00de5 100644 --- a/src/stamp/config.py +++ b/src/stamp/config.py @@ -18,4 +18,4 @@ class StampConfig(BaseModel): statistics: StatsConfig | None = None heatmaps: HeatmapConfig | None = None - attentionui: AttentionUIConfig | None = None \ No newline at end of file + attentionui: AttentionUIConfig | None = None diff --git a/src/stamp/heatmaps/__init__.py b/src/stamp/heatmaps/__init__.py index f1b569a3..3db88d23 100644 --- a/src/stamp/heatmaps/__init__.py +++ b/src/stamp/heatmaps/__init__.py @@ -318,8 +318,6 @@ def heatmaps_( plt.close(fig) - - def attention_ui_( *, feature_dir: Path, @@ -328,9 +326,8 @@ def attention_ui_( output_dir: Path, slide_paths: Iterable[Path] | None, device: DeviceLikeType, - default_slide_mpp: SlideMPP | None + default_slide_mpp: SlideMPP | None, ) -> None: - try: from stamp.heatmaps.attention_ui import show_attention_ui except ImportError as e: @@ -340,14 +337,13 @@ def attention_ui_( ) from e with torch.no_grad(): - # Collect slides to generate attention maps for if slide_paths is not None: wsis_to_process_all = (wsi_dir / slide for slide in slide_paths) else: wsis_to_process_all = ( p for ext in supported_extensions for p in wsi_dir.glob(f"**/*{ext}") - ) + ) # Check of a corresponding feature file exists wsis_to_process = [] @@ -355,11 +351,20 @@ def attention_ui_( h5_path = feature_dir / wsi_path.with_suffix(".h5").name if not h5_path.exists(): - _logger.info(f"could not find matching h5 file at {h5_path}. Skipping...") + _logger.info( + f"could not find matching h5 file at {h5_path}. Skipping..." 
+ ) continue wsis_to_process.append(str(wsi_path)) - # Launch the UI - show_attention_ui(feature_dir, wsis_to_process, checkpoint_path, output_dir, slide_paths, device, default_slide_mpp) + show_attention_ui( + feature_dir, + wsis_to_process, + checkpoint_path, + output_dir, + slide_paths, + device, + default_slide_mpp, + ) diff --git a/src/stamp/heatmaps/attention_ui.py b/src/stamp/heatmaps/attention_ui.py index ad0415ae..99e4db73 100644 --- a/src/stamp/heatmaps/attention_ui.py +++ b/src/stamp/heatmaps/attention_ui.py @@ -1,8 +1,17 @@ import napari from qtpy.QtGui import QPixmap, QImage from qtpy.QtWidgets import ( - QComboBox, QPushButton, QVBoxLayout, QWidget, QLabel, QScrollArea, - QListWidget, QApplication, QSlider, QHBoxLayout, QFrame + QComboBox, + QPushButton, + QVBoxLayout, + QWidget, + QLabel, + QScrollArea, + QListWidget, + QApplication, + QSlider, + QHBoxLayout, + QFrame, ) from qtpy.QtCore import Qt @@ -31,13 +40,14 @@ __copyright__ = "Copyright (C) 2025 Dennis Eschweiler" __license__ = "MIT" + def _vals_to_im( scores: Float[Tensor, "tile feat"], coords_norm: Integer[Tensor, "tile coord"], ) -> Float[Tensor, "width height category"]: """Arranges scores in a 2d grid according to coordinates""" size = coords_norm.max(0).values.flip(0) + 1 - im = torch.ones((*size.tolist(), *scores.shape[1:])).type_as(scores)*-1e-8 + im = torch.ones((*size.tolist(), *scores.shape[1:])).type_as(scores) * -1e-8 flattened_im = im.flatten(end_dim=1) flattened_coords = coords_norm[:, 1] * im.shape[1] + coords_norm[:, 0] @@ -54,16 +64,16 @@ def _get_thumb(slide, slide_mpp: SlideMPP) -> np.ndarray: dims_um = np.array(slide.dimensions) * slide_mpp thumb = slide.get_thumbnail(np.round(dims_um * 8 / 256).astype(int)) thumb = np.array(thumb) - return thumb/255 + return thumb / 255 def _patch_to_pixmap(patch_image): """Convert patch image to QPixmap for display in QLabel using NumPy""" - + # Resize for better visualization if needed display_size = 200 # Adjust this value as needed aspect_ratio = patch_image.width / patch_image.height - + if aspect_ratio > 1: # Wider than tall new_width = display_size @@ -72,42 +82,38 @@ def _patch_to_pixmap(patch_image): # Taller than wide new_height = display_size new_width = int(display_size * aspect_ratio) - + patch_image = patch_image.resize((new_width, new_height)) - + # Convert PIL image to numpy array img_array = np.array(patch_image) - + # Create QImage from NumPy array height, width, channels = img_array.shape bytes_per_line = channels * width - + # Create QImage (RGB format) q_image = QImage( - img_array.data, - width, - height, - bytes_per_line, - QImage.Format_RGB888 + img_array.data, width, height, bytes_per_line, QImage.Format_RGB888 ) - + return QPixmap.fromImage(q_image) class AttentionViewer: def __init__( - self, + self, feature_dir: Path, wsis_to_process: Iterable[str], checkpoint_path: Path, output_dir: Path, slide_paths: Iterable[Path] | None, device: DeviceLikeType, - default_slide_mpp: SlideMPP | None + default_slide_mpp: SlideMPP | None, ): """ Interactive viewer for images with click-based heatmap generation. 
- + Parameters: ----------- image : np.ndarray @@ -125,22 +131,25 @@ def __init__( self.default_slide_mpp = default_slide_mpp # Initialize model - self.model = LitVisionTransformer.load_from_checkpoint(checkpoint_path).to(device).eval() + self.model = ( + LitVisionTransformer.load_from_checkpoint(checkpoint_path).to(device).eval() + ) # Initialize napari viewer - self.viewer = napari.Viewer(title=' Histopathology Attention Explorer') + self.viewer = napari.Viewer(title=" Histopathology Attention Explorer") # Create dummy image layer - self.image = np.zeros((100,100,3), dtype=np.float32) # Placeholder for the image - self.image_layer = self.viewer.add_image( - self.image, - name='Image' - ) + self.image = np.zeros( + (100, 100, 3), dtype=np.float32 + ) # Placeholder for the image + self.image_layer = self.viewer.add_image(self.image, name="Image") self.height, self.width = self.image.shape[0], self.image.shape[1] # Initialize other attributes self.slide = None - self.attention_map = np.zeros((100,100), dtype=np.float32) # Placeholder for the attention map + self.attention_map = np.zeros( + (100, 100), dtype=np.float32 + ) # Placeholder for the attention map self.attention_weights = None self.token_attn = None self._attention_update_debounce = 500 @@ -152,36 +161,31 @@ def __init__( self.selected_filename = None self.num_layer = 0 self.num_head = 0 - - + # Initialize empty heatmap self.heatmap = np.zeros((self.height, self.width, 4), dtype=float) self.heatmap_layer = self.viewer.add_image( - self.heatmap, - name='Attention', - rgb=True, - visible=True, - opacity=0.7 - ) + self.heatmap, name="Attention", rgb=True, visible=True, opacity=0.7 + ) # Initialize empty highlight map self.highlight_mask = np.zeros((self.height, self.width, 4), dtype=float) self.highlight_layer = self.viewer.add_image( self.highlight_mask, - name='Top-k Highlight', + name="Top-k Highlight", rgb=True, visible=True, - opacity=1.0 - ) + opacity=1.0, + ) - # Initialize points layer + # Initialize points layer self.clicked_points = [] self.points_layer = self.viewer.add_points( - name='Selected Point', + name="Selected Point", size=10, - face_color='green', - symbol='x', - n_dimensional=True + face_color="green", + symbol="x", + n_dimensional=True, ) self._last_processed_point_count = 0 self._updating_points = False @@ -194,12 +198,10 @@ def __init__( # Disable UI elements until a file is selected self._set_ui_enabled(False) - + # Print instructions print("Click on the image to select points and generate attention heatmaps") - - #### ADDING UI ELEMENTS #### def _add_file_selection_ui(self): @@ -208,38 +210,34 @@ def _add_file_selection_ui(self): file_selection_widget = QWidget() layout = QVBoxLayout() file_selection_widget.setLayout(layout) - + # Add a label label = QLabel("Available files:") layout.addWidget(label) - - # Create a list widget for file selection (scrollable) + + # Create a list widget for file selection (scrollable) self.file_list = QListWidget() self.file_list.addItems(self.wsis_to_process) - + # Set a reasonable height to show multiple items self.file_list.setMinimumHeight(150) - + # Make it scrollable if many items self.file_list.setVerticalScrollBarPolicy(1) # 1 = Always show scrollbar - + # Add the list to the layout layout.addWidget(self.file_list) - + # Create a process button self.process_button = QPushButton("Process Selected File") self.process_button.clicked.connect(self._on_process_file) layout.addWidget(self.process_button) - + # Add the widget to napari viewer as a dock widget 
self.viewer.window.add_dock_widget( - file_selection_widget, - name="File Selection", - area="right" + file_selection_widget, name="File Selection", area="right" ) - - def _add_config_selection_ui(self): """Add UI controls for selecting attention layer and head""" # Create a widget container @@ -267,19 +265,20 @@ def _add_config_selection_ui(self): self.attention_handling.addItem("Class token attention", 6) # Connect the dropdown to the update function - self.attention_handling.currentIndexChanged.connect(self._on_update_attention_map) + self.attention_handling.currentIndexChanged.connect( + self._on_update_attention_map + ) direction_layout.addWidget(self.attention_handling) layout.addLayout(direction_layout) - # === LAYER SELECTION === layer_label = QLabel("Number of Network Layer\n(first->last):") layout.addWidget(layer_label) - + # Create layer selection controls with arrows and slider layer_layout = QHBoxLayout() - + # Layer slider self.layer_slider = QSlider(Qt.Horizontal) self.layer_slider.setMinimum(0) @@ -288,38 +287,42 @@ def _add_config_selection_ui(self): self.layer_slider.valueChanged.connect( lambda value: ( self.layer_value_label.setText(str(value)), - self._on_update_attention_map() + self._on_update_attention_map(), ) ) layer_layout.addWidget(self.layer_slider) - + # Left arrow button for layer self.layer_left_btn = QPushButton("←") self.layer_left_btn.setMaximumWidth(30) - self.layer_left_btn.clicked.connect(lambda: self._adjust_slider(value=-1, ui_element=self.layer_slider)) + self.layer_left_btn.clicked.connect( + lambda: self._adjust_slider(value=-1, ui_element=self.layer_slider) + ) layer_layout.addWidget(self.layer_left_btn) - + # Right arrow button for layer self.layer_right_btn = QPushButton("→") self.layer_right_btn.setMaximumWidth(30) - self.layer_right_btn.clicked.connect(lambda: self._adjust_slider(value=1, ui_element=self.layer_slider)) + self.layer_right_btn.clicked.connect( + lambda: self._adjust_slider(value=1, ui_element=self.layer_slider) + ) layer_layout.addWidget(self.layer_right_btn) - + # Layer value display self.layer_value_label = QLabel("0") self.layer_value_label.setMinimumWidth(25) self.layer_value_label.setAlignment(Qt.AlignCenter) layer_layout.addWidget(self.layer_value_label) - + layout.addLayout(layer_layout) - + # === HEAD SELECTION === head_label = QLabel("Number of Prediction Head\n(-1 for average):") layout.addWidget(head_label) - + # Create head selection controls with arrows and slider head_layout = QHBoxLayout() - + # Head slider self.head_slider = QSlider(Qt.Horizontal) self.head_slider.setMinimum(-1) @@ -328,70 +331,77 @@ def _add_config_selection_ui(self): self.head_slider.valueChanged.connect( lambda value: ( self.head_value_label.setText(str(value)), - self._on_update_attention_map() + self._on_update_attention_map(), ) ) head_layout.addWidget(self.head_slider) - + # Left arrow button for head self.head_left_btn = QPushButton("←") self.head_left_btn.setMaximumWidth(30) - self.head_left_btn.clicked.connect(lambda: self._adjust_slider(-1, self.head_slider)) + self.head_left_btn.clicked.connect( + lambda: self._adjust_slider(-1, self.head_slider) + ) head_layout.addWidget(self.head_left_btn) - + # Right arrow button for head self.head_right_btn = QPushButton("→") self.head_right_btn.setMaximumWidth(30) - self.head_right_btn.clicked.connect(lambda: self._adjust_slider(1, self.head_slider)) + self.head_right_btn.clicked.connect( + lambda: self._adjust_slider(1, self.head_slider) + ) head_layout.addWidget(self.head_right_btn) - + # Head 
value display self.head_value_label = QLabel("0") self.head_value_label.setMinimumWidth(25) self.head_value_label.setAlignment(Qt.AlignCenter) head_layout.addWidget(self.head_value_label) - - layout.addLayout(head_layout) + layout.addLayout(head_layout) # === Top-k SELECTION === topk_label = QLabel("Top-k tiles to highlight:") layout.addWidget(topk_label) - + # Create top-k selection controls with arrows and slider topk_layout = QHBoxLayout() - + # Top-k slider self.topk_slider = QSlider(Qt.Horizontal) self.topk_slider.setMinimum(0) - self.topk_slider.setMaximum(50) + self.topk_slider.setMaximum(50) self.topk_slider.setValue(5) self.topk_slider.valueChanged.connect( lambda value: ( self.topk_value_label.setText(str(value)), - self._on_update_attention_map() + self._on_update_attention_map(), ) ) topk_layout.addWidget(self.topk_slider) - + # Left arrow button for top-k self.topk_left_btn = QPushButton("←") self.topk_left_btn.setMaximumWidth(30) - self.topk_left_btn.clicked.connect(lambda: self._adjust_slider(-1, self.topk_slider)) + self.topk_left_btn.clicked.connect( + lambda: self._adjust_slider(-1, self.topk_slider) + ) topk_layout.addWidget(self.topk_left_btn) - + # Right arrow button for top-k self.topk_right_btn = QPushButton("→") self.topk_right_btn.setMaximumWidth(30) - self.topk_right_btn.clicked.connect(lambda: self._adjust_slider(1, self.topk_slider)) + self.topk_right_btn.clicked.connect( + lambda: self._adjust_slider(1, self.topk_slider) + ) topk_layout.addWidget(self.topk_right_btn) - + # Top-k value display self.topk_value_label = QLabel("5") self.topk_value_label.setMinimumWidth(25) self.topk_value_label.setAlignment(Qt.AlignCenter) topk_layout.addWidget(self.topk_value_label) - + layout.addLayout(topk_layout) # Add a horizontal separator @@ -399,95 +409,94 @@ def _add_config_selection_ui(self): separator.setFrameShape(QFrame.HLine) separator.setFrameShadow(QFrame.Sunken) layout.addWidget(separator) - + # === LOAD PATCHES BUTTON === self.load_patches_btn = QPushButton("Load Selected & Top-k Patches") self.load_patches_btn.clicked.connect(self._load_tile_patches) - layout.addWidget(self.load_patches_btn) + layout.addWidget(self.load_patches_btn) # === Add the widget to napari viewer as a dock widget === self.viewer.window.add_dock_widget( - selection_widget, - name="Attention Parameters", - area="right" + selection_widget, name="Attention Parameters", area="right" ) - def _add_patch_display_widget(self): """Create a widget to display the selected patch and top-k patches""" # Create main widget self.patch_display_widget = QWidget() layout = QHBoxLayout() self.patch_display_widget.setLayout(layout) - + # Create a scroll area to contain patches scroll_area = QScrollArea() scroll_area.setWidgetResizable(True) scroll_area.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOn) scroll_area.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOff) - + # Create container widget for patches self.patches_container = QWidget() self.patches_layout = QHBoxLayout() self.patches_container.setLayout(self.patches_layout) - + # Add to scroll area scroll_area.setWidget(self.patches_container) layout.addWidget(scroll_area) - + # Add the widget to napari viewer as a dock widget at the bottom self.viewer.window.add_dock_widget( - self.patch_display_widget, - name="Tile Patches", - area="bottom" + self.patch_display_widget, name="Tile Patches", area="bottom" ) - - ### UI HANDLING ### - + def _set_ui_enabled(self, enabled: bool): """Enable or disable all UI elements""" # Enable points layer - if hasattr(self, 
'points_layer'): + if hasattr(self, "points_layer"): self.points_layer.editable = enabled - + # Enable patch loading button - if hasattr(self, 'load_patches_btn'): + if hasattr(self, "load_patches_btn"): self.load_patches_btn.setEnabled(enabled) # Enable attention handling dropdown - if hasattr(self, 'attention_handling'): + if hasattr(self, "attention_handling"): self.attention_handling.setEnabled(enabled) # Enable all sliders - for slider_name in ['layer_slider', 'head_slider', 'topk_slider']: + for slider_name in ["layer_slider", "head_slider", "topk_slider"]: if hasattr(self, slider_name): getattr(self, slider_name).setEnabled(enabled) - + # Enable all arrow buttons - for btn_name in ['layer_left_btn', 'layer_right_btn', 'head_left_btn', - 'head_right_btn', 'topk_left_btn', 'topk_right_btn']: + for btn_name in [ + "layer_left_btn", + "layer_right_btn", + "head_left_btn", + "head_right_btn", + "topk_left_btn", + "topk_right_btn", + ]: if hasattr(self, btn_name): getattr(self, btn_name).setEnabled(enabled) - + # Process UI events to update display QApplication.processEvents() - def _adjust_slider(self, value=1, ui_element=None): """Adjust the slider value by a given amount""" if ui_element is not None: current = ui_element.value() - if (value>0 and current < ui_element.maximum()) or (value<0 and current > ui_element.minimum()): + if (value > 0 and current < ui_element.maximum()) or ( + value < 0 and current > ui_element.minimum() + ): ui_element.setValue(current + value) - def _update_viewer_image(self, new_image: np.ndarray): """ Update the viewer image and reset heatmap after loading a new file - + Parameters: ----------- new_image : np.ndarray @@ -496,78 +505,73 @@ def _update_viewer_image(self, new_image: np.ndarray): # Update the image layer self.image = new_image self.image_layer.data = self.image - + # Update dimensions based on the new image if len(self.image.shape) == 3 and self.image.shape[2] in [3, 4]: # RGB or RGBA self.height, self.width = self.image.shape[0], self.image.shape[1] else: # Grayscale self.height, self.width = self.image.shape - + # Reset the heatmap self.heatmap = np.zeros((self.height, self.width, 4), dtype=float) self.heatmap_layer.data = self.heatmap - + # Clear any existing points self.points_layer.data = np.empty((0, 2)) self.clicked_points = [] self._last_processed_point_count = 0 - + # Reset selected token self.selected_token_idx = None - + # Reset viewer scale and position to fit the new image self.viewer.reset_view() - + # Set active layer to points layer and mode to add self.viewer.layers.selection.active = self.points_layer - self.points_layer.mode = 'add' - - print("Viewer updated with new image") - - + self.points_layer.mode = "add" + print("Viewer updated with new image") ### FILE PROCESSING ### def _on_process_file(self): """Handle the Process button click""" selected_items = self.file_list.selectedItems() - + if selected_items: # Get the selected filename self.selected_filename = selected_items[0].text() print(f"Selected file: {self.selected_filename}") - + self.process_selected_file(self.selected_filename) self.load_selected_attention_map() - + else: print("No file selected! 
Please select a file first.") - - def process_selected_file(self, wsi_path): """Load the selected file and the corresponding attention map""" - + # Disable UI controls self._set_ui_enabled(False) - + # Use QApplication to process events and update the UI QApplication.processEvents() - try: - + try: print(f"Processing file: {wsi_path}") with torch.no_grad(): - - # Load WSI + # Load WSI wsi_path = Path(wsi_path) h5_path = self.feature_dir / wsi_path.with_suffix(".h5").name print(f"Creating attention map for {wsi_path.name}") self.slide = openslide.open_slide(wsi_path) - slide_mpp = get_slide_mpp_(self.slide, default_mpp=self.default_slide_mpp) + slide_mpp = get_slide_mpp_( + self.slide, default_mpp=self.default_slide_mpp + ) assert slide_mpp is not None, "could not determine slide MPP" with h5py.File(h5_path) as h5: @@ -586,7 +590,9 @@ def process_selected_file(self, wsi_path): int(round(cast(float, h5.attrs["tile_size"]) / slide_mpp)) ) else: - self.tile_size_slide_px = SlidePixels(int(round(256 / slide_mpp))) + self.tile_size_slide_px = SlidePixels( + int(round(256 / slide_mpp)) + ) # grid coordinates, i.e. the top-left most tile is (0, 0), the one to its right (0, 1) etc. self.map_coords = (coords_um / stride_um).round().long() @@ -595,93 +601,104 @@ def process_selected_file(self, wsi_path): self.coords_tile_slide_px = torch.round(coords_um / slide_mpp).long() # Score for the entire slide - self.attention_weights = self.model.vision_transformer.get_attention_maps( - bags=feats.unsqueeze(0), - coords=coords_um.unsqueeze(0), - mask=torch.zeros(1, len(feats), dtype=torch.bool, device=self.device), - ) - + self.attention_weights = ( + self.model.vision_transformer.get_attention_maps( + bags=feats.unsqueeze(0), + coords=coords_um.unsqueeze(0), + mask=torch.zeros( + 1, len(feats), dtype=torch.bool, device=self.device + ), + ) + ) + # Determine number of heads and layers and update UI elements num_layers = len(self.attention_weights) num_heads = self.attention_weights[0].shape[1] self.layer_slider.setMaximum(num_layers - 1) self.head_slider.setMaximum(num_heads - 1) - + # Get thumbnail of the slide self.image = _get_thumb(self.slide, slide_mpp) # Update the viewer with the new image self._update_viewer_image(self.image) - finally: # Re-enable UI controls self._set_ui_enabled(True) - - def load_selected_attention_map(self): - if self.attention_weights is not None: # Get attention weights # Choose layer - self.attention_map = self.attention_weights[self.num_layer] # Shape: [batch, heads, tokens, tokens]) + self.attention_map = self.attention_weights[ + self.num_layer + ] # Shape: [batch, heads, tokens, tokens]) # Choose head (or average) if self.num_head == -1: # Average over heads self.attention_map = self.attention_map.mean(dim=1) else: - self.attention_map = self.attention_map[:,self.num_head,...] # Shape: [batch, tokens, tokens] + self.attention_map = self.attention_map[ + :, self.num_head, ... + ] # Shape: [batch, tokens, tokens] # Cut out batch dimension - self.attention_map = self.attention_map[0,...] # Shape: [tokens, tokens] + self.attention_map = self.attention_map[0, ...] 
# Shape: [tokens, tokens] # Normalize attention map to [0, 1] by using percentiles (not considering cls token) - percentile_low = np.percentile(self.attention_map[1:,1:], 0.5) - percentile_high = np.percentile(self.attention_map[1:,1:], 99.5) - self.attention_map = (self.attention_map - percentile_low) / (percentile_high - percentile_low + 1e-8) - - + percentile_low = np.percentile(self.attention_map[1:, 1:], 0.5) + percentile_high = np.percentile(self.attention_map[1:, 1:], 99.5) + self.attention_map = (self.attention_map - percentile_low) / ( + percentile_high - percentile_low + 1e-8 + ) def highlight_top_k_tiles(self): - - if self.selected_token_idx is not None and\ - self.token_attn is not None and\ - self.map_coords is not None: - + if ( + self.selected_token_idx is not None + and self.token_attn is not None + and self.map_coords is not None + ): # Create a new highlight mask k = min(self.topk_slider.value(), len(self.token_attn)) highlight_mask = np.zeros((self.height, self.width, 4), dtype=float) - if k > 0: + if k > 0: # Get top k indices with highest attention top_k_values, top_k_indices = torch.topk(self.token_attn, k) - + # For each top tile, add a colored rectangle to the mask - for i, (score, idx) in enumerate(zip(top_k_values.cpu().numpy(), top_k_indices.cpu().numpy())): + for i, (score, idx) in enumerate( + zip(top_k_values.cpu().numpy(), top_k_indices.cpu().numpy()) + ): # Get tile coordinates x, y = self.map_coords[idx].cpu().numpy() - + # Convert to image coordinates (scaled by 8) x_img, y_img = x * 8, y * 8 tile_size = 8 # Assuming 8x8 pixels per tile - + # Create rectangular highlight for this tile # Use a different color intensity based on rank (1st is most intense) min_opacity = 0.5 - intensity = 1.0 - (i * min_opacity / k) # Decreasing intensity for lower ranks - + intensity = 1.0 - ( + i * min_opacity / k + ) # Decreasing intensity for lower ranks + # Define rectangle in the highlight mask y_start, y_end = max(0, y_img), min(self.height, y_img + tile_size) x_start, x_end = max(0, x_img), min(self.width, x_img + tile_size) - + # Red with alpha based on score - highlight_mask[y_start:y_end, x_start:x_end] = [0.0, 0.6, 0.0, min(min_opacity, intensity + score * 0.3)] - + highlight_mask[y_start:y_end, x_start:x_end] = [ + 0.0, + 0.6, + 0.0, + min(min_opacity, intensity + score * 0.3), + ] + self.highlight_mask = highlight_mask self.highlight_layer.data = self.highlight_mask - - def _load_tile_patches(self): """Load and display the selected patch and top-k patches""" if self.selected_token_idx is None: @@ -694,39 +711,39 @@ def _load_tile_patches(self): widget = item.widget() if widget: widget.deleteLater() - + # Get selected token patch selected_patch = self.slide.read_region( tuple(self.coords_tile_slide_px[self.selected_token_idx].tolist()), 0, (self.tile_size_slide_px, self.tile_size_slide_px), ).convert("RGB") - + # Add selected patch with label selected_frame = QFrame() selected_layout = QVBoxLayout() selected_frame.setLayout(selected_layout) - + # Create QLabel for image selected_label = QLabel() selected_pixmap = _patch_to_pixmap(selected_patch) selected_label.setPixmap(selected_pixmap) - + # Create label text text_label = QLabel(f"Selected-ID:{self.selected_token_idx}") text_label.setAlignment(Qt.AlignCenter) - + selected_layout.addWidget(selected_label) selected_layout.addWidget(text_label) self.patches_layout.addWidget(selected_frame) - + # Add separator separator = QFrame() separator.setFrameShape(QFrame.VLine) separator.setFrameShadow(QFrame.Sunken) 
separator.setLineWidth(2) - separator.setMinimumWidth(5) - separator.setStyleSheet("background-color: #888888;") + separator.setMinimumWidth(5) + separator.setStyleSheet("background-color: #888888;") self.patches_layout.addWidget(separator) # Get top-k patches @@ -739,38 +756,40 @@ def _load_tile_patches(self): 0, (self.tile_size_slide_px, self.tile_size_slide_px), ).convert("RGB") - + # Create frame with layout patch_frame = QFrame() patch_layout = QVBoxLayout() patch_frame.setLayout(patch_layout) - + # Create QLabel for image patch_label = QLabel() patch_pixmap = _patch_to_pixmap(patch) patch_label.setPixmap(patch_pixmap) - + # Create label text - text_label = QLabel(f"Top-{n+1}-ID:{index} (Score:{score:.2f})") + text_label = QLabel(f"Top-{n + 1}-ID:{index} (Score:{score:.2f})") text_label.setAlignment(Qt.AlignCenter) - + patch_layout.addWidget(patch_label) patch_layout.addWidget(text_label) self.patches_layout.addWidget(patch_frame) - + # Force update of the layout self.patches_container.adjustSize() QApplication.processEvents() - def _on_update_attention_map(self): # Check if we have data to display if self.attention_weights is None: return - + # Simple debounce to avoid too frequent updates current_time = time.time() * 1000 # Convert to milliseconds - if current_time - self._last_attention_update_time < self._attention_update_debounce: + if ( + current_time - self._last_attention_update_time + < self._attention_update_debounce + ): return self._last_attention_update_time = current_time @@ -783,40 +802,46 @@ def _on_update_attention_map(self): self._last_processed_point_count = 0 # Reset last processed point count self._on_add_point() - - def _on_add_point(self): """Handle points being added to the points layer""" if self.map_coords is not None: # Prevent recursive calls if self._updating_points: return - + # Check if points have been added - if len(self.points_layer.data) > self._last_processed_point_count: # If there's any data + if ( + len(self.points_layer.data) > self._last_processed_point_count + ): # If there's any data # Keep only the last added point last_point = self.points_layer.data[-1] - + # Convert to proper types y, x = int(last_point[0]), int(last_point[1]) - + # Set the flag before updating to prevent recursion self._updating_points = True - - try: + + try: # Update heatmap based on the new point - self.update_heatmap(y-4, x-4) # 4 to center the point + self.update_heatmap(y - 4, x - 4) # 4 to center the point # Snap to selected token position - x_snapped, y_snapped = self.map_coords[self.selected_token_idx,:].tolist() - self.points_layer.data = np.array([[y_snapped*8+4, x_snapped*8+4]]) + x_snapped, y_snapped = self.map_coords[ + self.selected_token_idx, : + ].tolist() + self.points_layer.data = np.array( + [[y_snapped * 8 + 4, x_snapped * 8 + 4]] + ) self._last_processed_point_count = 1 # We now have 1 point # Update top-k tiles self.highlight_top_k_tiles() - + # Print clicked coordinates - print(f"Clicked at coordinates: ({y},{x}). Selected token index: {self.selected_token_idx} at ({y_snapped*8+4},{x_snapped*8+4})") + print( + f"Clicked at coordinates: ({y},{x}). Selected token index: {self.selected_token_idx} at ({y_snapped * 8 + 4},{x_snapped * 8 + 4})" + ) finally: # Reset the flag after updating self._updating_points = False @@ -824,19 +849,15 @@ def _on_add_point(self): else: print("No map coordinates available. 
Please load a file first.") - - def update_heatmap(self, y: IntType, x: IntType): """Update the heatmap based on clicked position""" # Generate new heatmap using the provided or default function self.heatmap, self.selected_token_idx = self._heatmap_generator(y, x) - + # Update the heatmap layer self.heatmap_layer.data = self.heatmap - def get_token_attention(self, selected_token_idx: IntType): - # Get selected direction selected_direction = self.attention_handling.currentData() @@ -844,54 +865,64 @@ def get_token_attention(self, selected_token_idx: IntType): # Attention from selected to others if selected_direction == 0: - token_attn = self.attention_map[selected_token_idx+1, 1:] # +1 to skip the cls token + token_attn = self.attention_map[ + selected_token_idx + 1, 1: + ] # +1 to skip the cls token # Attention from others to selected elif selected_direction == 1: - token_attn = self.attention_map[1:, selected_token_idx+1] # +1 to skip the cls token + token_attn = self.attention_map[ + 1:, selected_token_idx + 1 + ] # +1 to skip the cls token # Deviation of overall given attention elif selected_direction == 2: token_attn = torch.std(self.attention_map[1:, 1:], dim=0) percentile_low = np.percentile(token_attn, 0.5) percentile_high = np.percentile(token_attn, 99.5) - token_attn = (token_attn - percentile_low) / (percentile_high - percentile_low + 1e-8) + token_attn = (token_attn - percentile_low) / ( + percentile_high - percentile_low + 1e-8 + ) # Deviation of overall received attention elif selected_direction == 3: token_attn = torch.std(self.attention_map[1:, 1:], dim=1) percentile_low = np.percentile(token_attn, 0.5) percentile_high = np.percentile(token_attn, 99.5) - token_attn = (token_attn - percentile_low) / (percentile_high - percentile_low + 1e-8) - + token_attn = (token_attn - percentile_low) / ( + percentile_high - percentile_low + 1e-8 + ) + # Mean of overall given attention elif selected_direction == 4: token_attn = torch.mean(self.attention_map[1:, 1:], dim=0) - + # Mean of overall received attention elif selected_direction == 5: token_attn = torch.mean(self.attention_map[1:, 1:], dim=1) # Class token attention elif selected_direction == 6: - token_attn = self.attention_map[0, 1:] # from cls to others + token_attn = self.attention_map[0, 1:] # from cls to others percentile_low = np.percentile(token_attn, 0.5) percentile_high = np.percentile(token_attn, 99.5) - token_attn = (token_attn - percentile_low) / (percentile_high - percentile_low + 1e-8) + token_attn = (token_attn - percentile_low) / ( + percentile_high - percentile_low + 1e-8 + ) else: raise ValueError(f"Invalid direction selected: {selected_direction}") - + token_attn = np.clip(token_attn, 0, 1) return token_attn - - def _heatmap_generator(self, y: IntType, x: IntType): """Heatmap generator - determines closest token to clicked position and extract inter-token attention""" # Get the closest token to the clicked position - token_distances = cdist([(x, y)], self.map_coords.numpy(force=True)*8) # Upscale by 8 to match thumbnail size + token_distances = cdist( + [(x, y)], self.map_coords.numpy(force=True) * 8 + ) # Upscale by 8 to match thumbnail size selected_token_idx = np.argmin(token_distances) # Get attention for selected token @@ -904,7 +935,9 @@ def _heatmap_generator(self, y: IntType, x: IntType): ).squeeze(-1) # Shape: [width, height] # Upscale by 8 to match the thumbnail size - cls_attn_map = cls_attn_map.repeat_interleave(8, dim=0).repeat_interleave(8, dim=1) + cls_attn_map = cls_attn_map.repeat_interleave(8, 
dim=0).repeat_interleave( + 8, dim=1 + ) # Normalize the heatmap to [0, 1] # cls_attn_map = (cls_attn_map - cls_attn_map.min()) / (cls_attn_map.max() - cls_attn_map.min() + 1e-8) @@ -913,24 +946,22 @@ def _heatmap_generator(self, y: IntType, x: IntType): heatmap_values = cls_attn_map.numpy(force=True) # Get the colormap - colormap = cm.get_cmap('inferno') - + colormap = cm.get_cmap("inferno") + # Apply colormap to the values to get RGB heatmap_rgba = colormap(heatmap_values) - + # Create a mask for zero and near-zero values (make these transparent) threshold = 0.0 zero_mask = heatmap_values < threshold - + # Set alpha channel to make zero-value regions fully transparent heatmap_rgba[zero_mask, 3] = 0.0 - - # Scale the alpha for non-zero values by the desired opacity + + # Scale the alpha for non-zero values by the desired opacity heatmap_rgba[~zero_mask, 3] *= 1.0 - + return heatmap_rgba, selected_token_idx - - def show(self): """Display the viewer and start the event loop""" @@ -938,17 +969,17 @@ def show(self): def show_attention_ui( - feature_dir: Path, - wsis_to_process: Iterable[str], - checkpoint_path: Path, - output_dir: Path, - slide_paths: Iterable[Path] | None, - device: DeviceLikeType, - default_slide_mpp: SlideMPP | None - ): + feature_dir: Path, + wsis_to_process: Iterable[str], + checkpoint_path: Path, + output_dir: Path, + slide_paths: Iterable[Path] | None, + device: DeviceLikeType, + default_slide_mpp: SlideMPP | None, +): """ Launch the attention UI. - + Parameters: ----------- feature_dir : Path @@ -973,6 +1004,6 @@ def show_attention_ui( output_dir, slide_paths, device, - default_slide_mpp + default_slide_mpp, ) - napari.run() \ No newline at end of file + napari.run() diff --git a/src/stamp/modeling/vision_transformer.py b/src/stamp/modeling/vision_transformer.py index 5b035b8a..8647c98b 100755 --- a/src/stamp/modeling/vision_transformer.py +++ b/src/stamp/modeling/vision_transformer.py @@ -61,7 +61,13 @@ def forward( # Help, my abstractions are leaking! alibi_mask: Bool[Tensor, "batch sequence sequence"], return_attention: bool = False, - ) -> Float[Tensor, "batch sequence proj_feature"] | tuple[Float[Tensor, "batch sequence proj_feature"], Float[Tensor, "batch heads sequence sequence"]]: + ) -> ( + Float[Tensor, "batch sequence proj_feature"] + | tuple[ + Float[Tensor, "batch sequence proj_feature"], + Float[Tensor, "batch heads sequence sequence"], + ] + ): """ Args: attn_mask: @@ -125,11 +131,17 @@ def forward( ) # Create dummy attention weights to satisfy type checking if return_attention: - print(f"Warning: Failed to return attention weights ({type(e).__name__}: {e}). Creating dummy weights.") + print( + f"Warning: Failed to return attention weights ({type(e).__name__}: {e}). Creating dummy weights." 
+ ) batch_size, seq_len, _ = x.shape self.last_attn_weights = torch.zeros( - batch_size, self.heads, seq_len, seq_len, - device=x.device, dtype=x.dtype + batch_size, + self.heads, + seq_len, + seq_len, + device=x.device, + dtype=x.dtype, ) else: attn_output = self.mhsa( @@ -151,11 +163,15 @@ def forward( # Create default attention weights if none were produced batch_size, seq_len, _ = x.shape self.last_attn_weights = torch.zeros( - batch_size, self.heads if hasattr(self, 'heads') else 1, - seq_len, seq_len, device=x.device, dtype=x.dtype + batch_size, + self.heads if hasattr(self, "heads") else 1, + seq_len, + seq_len, + device=x.device, + dtype=x.dtype, ) return attn_output, self.last_attn_weights - + return attn_output @@ -203,27 +219,35 @@ def forward( attn_mask: Bool[Tensor, "batch sequence sequence"] | None, alibi_mask: Bool[Tensor, "batch sequence sequence"], return_attention: bool = False, - ) -> Float[Tensor, "batch sequence proj_feature"] | tuple[Float[Tensor, "batch sequence proj_feature"], list[Float[Tensor, "batch heads sequence sequence"]]]: + ) -> ( + Float[Tensor, "batch sequence proj_feature"] + | tuple[ + Float[Tensor, "batch sequence proj_feature"], + list[Float[Tensor, "batch heads sequence sequence"]], + ] + ): attention_weights = [] - + for attn, ff in cast(Iterable[tuple[SelfAttention, nn.Module]], self.layers): if return_attention: x_attn, attn_weights = attn( - x, - coords=coords, - attn_mask=attn_mask, + x, + coords=coords, + attn_mask=attn_mask, alibi_mask=alibi_mask, - return_attention=True + return_attention=True, ) attention_weights.append(attn_weights) else: - x_attn = attn(x, coords=coords, attn_mask=attn_mask, alibi_mask=alibi_mask) - + x_attn = attn( + x, coords=coords, attn_mask=attn_mask, alibi_mask=alibi_mask + ) + x = x_attn + x x = ff(x) + x x = self.norm(x) - + if return_attention: return x, attention_weights return x @@ -314,8 +338,7 @@ def forward( bags = bags[:, 0] return self.mlp_head(bags) - - + def get_attention_maps( self, bags: Float[Tensor, "batch tile feature"], @@ -339,7 +362,11 @@ def get_attention_maps( # Create necessary masks if mask is None: bags, attention_weights = self.transformer( - bags, coords=coords, attn_mask=None, alibi_mask=None, return_attention=True + bags, + coords=coords, + attn_mask=None, + alibi_mask=None, + return_attention=True, ) else: mask_with_class_token = torch.cat( @@ -361,8 +388,8 @@ def get_attention_maps( coords=coords, attn_mask=square_attn_mask, alibi_mask=alibi_mask, - return_attention=True + return_attention=True, ) - + # Return the attention weights return attention_weights From 44386d9691391fb3562d670f8fdcc754d35a7ca6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Georg=20W=C3=B6lflein?= Date: Wed, 11 Jun 2025 23:18:13 +0200 Subject: [PATCH 07/10] Appease mypy (fix some typing issues) --- src/stamp/modeling/vision_transformer.py | 71 ++++++++++++++++++++---- 1 file changed, 60 insertions(+), 11 deletions(-) diff --git a/src/stamp/modeling/vision_transformer.py b/src/stamp/modeling/vision_transformer.py index 8647c98b..0f81f366 100755 --- a/src/stamp/modeling/vision_transformer.py +++ b/src/stamp/modeling/vision_transformer.py @@ -3,7 +3,7 @@ """ from collections.abc import Iterable -from typing import assert_never, cast +from typing import Literal, assert_never, cast, overload import torch from beartype import beartype @@ -41,16 +41,40 @@ def __init__( super().__init__() self.heads = num_heads self.norm = nn.LayerNorm(dim) - self.last_attn_weights = None # Store attention weights if use_alibi: - self.mhsa = 
MultiHeadALiBi( + self.mhsa: MultiHeadALiBi | nn.MultiheadAttention = MultiHeadALiBi( embed_dim=dim, num_heads=num_heads, ) else: self.mhsa = nn.MultiheadAttention(dim, num_heads, dropout, batch_first=True) + @overload + def forward( + self, + x: Float[Tensor, "batch sequence proj_feature"], + *, + coords: Float[Tensor, "batch sequence xy"], + attn_mask: Bool[Tensor, "batch sequence sequence"] | None, + alibi_mask: Bool[Tensor, "batch sequence sequence"], + return_attention: Literal[False] = False, + ) -> Float[Tensor, "batch sequence proj_feature"]: ... + + @overload + def forward( + self, + x: Float[Tensor, "batch sequence proj_feature"], + *, + coords: Float[Tensor, "batch sequence xy"], + attn_mask: Bool[Tensor, "batch sequence sequence"] | None, + alibi_mask: Bool[Tensor, "batch sequence sequence"], + return_attention: Literal[True], + ) -> tuple[ # if return_attention is True, return the attention weights as well + Float[Tensor, "batch sequence proj_feature"], + Float[Tensor, "batch heads sequence sequence"], + ]: ... + @jaxtyped(typechecker=beartype) def forward( self, @@ -85,7 +109,7 @@ def forward( x = self.norm(x) # Initialize attention weights with default shape - self.last_attn_weights = None + last_attn_weights: Float[Tensor, "batch heads sequence sequence"] | None = None match self.mhsa: case nn.MultiheadAttention(): @@ -101,7 +125,7 @@ def forward( else None ), ) - self.last_attn_weights = attn_weights + last_attn_weights = attn_weights case MultiHeadALiBi(): # Modified MultiHeadALiBi to return attention weights @@ -117,7 +141,7 @@ def forward( alibi_mask=alibi_mask, return_attention=True, ) - self.last_attn_weights = attn_weights + last_attn_weights = attn_weights except (TypeError, ValueError, RuntimeError) as e: # If the return_attention param exists but fails, fall back attn_output = self.mhsa( @@ -135,7 +159,7 @@ def forward( f"Warning: Failed to return attention weights ({type(e).__name__}: {e}). Creating dummy weights." ) batch_size, seq_len, _ = x.shape - self.last_attn_weights = torch.zeros( + last_attn_weights = torch.zeros( batch_size, self.heads, seq_len, @@ -153,16 +177,16 @@ def forward( attn_mask=attn_mask, alibi_mask=alibi_mask, ) - self.last_attn_weights = None + last_attn_weights = None case _ as unreachable: assert_never(unreachable) if return_attention: # Ensure we always return valid tensor for attention weights - if self.last_attn_weights is None: + if last_attn_weights is None: # Create default attention weights if none were produced batch_size, seq_len, _ = x.shape - self.last_attn_weights = torch.zeros( + last_attn_weights = torch.zeros( batch_size, self.heads if hasattr(self, "heads") else 1, seq_len, @@ -170,7 +194,7 @@ def forward( device=x.device, dtype=x.dtype, ) - return attn_output, self.last_attn_weights + return attn_output, last_attn_weights return attn_output @@ -210,6 +234,31 @@ def __init__( self.norm = nn.LayerNorm(dim) + @overload + def forward( + self, + x: Float[Tensor, "batch sequence proj_feature"], + *, + coords: Float[Tensor, "batch sequence 2"], + attn_mask: Bool[Tensor, "batch sequence sequence"] | None, + alibi_mask: Bool[Tensor, "batch sequence sequence"], + return_attention: Literal[False] = False, + ) -> Float[Tensor, "batch sequence proj_feature"]: ... 
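+    # NOTE: the two @overload stubs only refine the static return type for type
+    # checkers: with `return_attention=True` callers also receive the per-head
+    # attention weights; runtime behaviour is unchanged.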
+ + @overload + def forward( + self, + x: Float[Tensor, "batch sequence proj_feature"], + *, + coords: Float[Tensor, "batch sequence 2"], + attn_mask: Bool[Tensor, "batch sequence sequence"] | None, + alibi_mask: Bool[Tensor, "batch sequence sequence"], + return_attention: Literal[True], + ) -> tuple[ + Float[Tensor, "batch sequence proj_feature"], + list[Float[Tensor, "batch heads sequence sequence"]], + ]: ... + @jaxtyped(typechecker=beartype) def forward( self, From feded3bbf55b0e5757e0fd4474c0a5af28debf41 Mon Sep 17 00:00:00 2001 From: DEschweiler Date: Wed, 6 Aug 2025 10:36:48 +0200 Subject: [PATCH 08/10] fixed selected tile visualization when choosing cls token --- src/stamp/heatmaps/attention_ui.py | 47 ++++++++++++++++++++---------- 1 file changed, 32 insertions(+), 15 deletions(-) diff --git a/src/stamp/heatmaps/attention_ui.py b/src/stamp/heatmaps/attention_ui.py index 99e4db73..ea74dcb4 100644 --- a/src/stamp/heatmaps/attention_ui.py +++ b/src/stamp/heatmaps/attention_ui.py @@ -582,17 +582,27 @@ def process_selected_file(self, wsi_path): .float() .to(self.device) ) + coords_um = get_coords(h5).coords_um + if not isinstance(coords_um, torch.Tensor): + coords_um = torch.tensor(coords_um, dtype=torch.float32) + stride_um = Microns(get_stride(coords_um)) - if h5.attrs.get("unit") == "um": - self.tile_size_slide_px = SlidePixels( - int(round(cast(float, h5.attrs["tile_size"]) / slide_mpp)) - ) - else: - self.tile_size_slide_px = SlidePixels( + # list all h5 attrs + for key in h5.attrs.keys(): + print(f"h5 attribute '{key}': {h5.attrs[key]}") + + self.tile_size_slide_px = SlidePixels( int(round(256 / slide_mpp)) ) + if h5.attrs.get("unit") == "um": + for attr_name in ["tile_size_um", "tile_size"]: + if attr_name in h5.attrs: + self.tile_size_slide_px = SlidePixels( + int(round(cast(float, h5.attrs[attr_name]) / slide_mpp)) + ) + break # grid coordinates, i.e. the top-left most tile is (0, 0), the one to its right (0, 1) etc. 
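                    # map_coords are integer tile indices (micron coordinates divided by the
                    # tile stride); assuming the default 256 um tiles, the thumbnail is drawn
                    # at ~8 px per tile, so these indices are scaled by 8 when rendering
                    # points and highlight rectangles.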
self.map_coords = (coords_um / stride_um).round().long() @@ -719,21 +729,28 @@ def _load_tile_patches(self): (self.tile_size_slide_px, self.tile_size_slide_px), ).convert("RGB") - # Add selected patch with label + # Add layout for selected patch selected_frame = QFrame() selected_layout = QVBoxLayout() selected_frame.setLayout(selected_layout) - - # Create QLabel for image selected_label = QLabel() - selected_pixmap = _patch_to_pixmap(selected_patch) - selected_label.setPixmap(selected_pixmap) - - # Create label text - text_label = QLabel(f"Selected-ID:{self.selected_token_idx}") - text_label.setAlignment(Qt.AlignCenter) + + # Create label text and selected image + if self.attention_handling.currentData() == 6: # Class token attention (no reference image) + # Create QLabel for Class token + selected_pixmap = QPixmap(200, 200) # blank image + selected_pixmap.fill(Qt.transparent) + text_label = QLabel(f"Selected: Class Token Attention") + else: + # Create QLabel for image + selected_pixmap = _patch_to_pixmap(selected_patch) + text_label = QLabel(f"Selected-ID:{self.selected_token_idx}") + + # Add selected patch and label to layout + selected_label.setPixmap(selected_pixmap) selected_layout.addWidget(selected_label) + text_label.setAlignment(Qt.AlignCenter) selected_layout.addWidget(text_label) self.patches_layout.addWidget(selected_frame) From c9fca6a81e6a219a637bd82eeb3d92810611df9f Mon Sep 17 00:00:00 2001 From: DEschweiler Date: Wed, 6 Aug 2025 12:05:06 +0200 Subject: [PATCH 09/10] added save functionality for top-k tiles --- src/stamp/heatmaps/attention_ui.py | 130 ++++++++++++++++++++++++++--- 1 file changed, 118 insertions(+), 12 deletions(-) diff --git a/src/stamp/heatmaps/attention_ui.py b/src/stamp/heatmaps/attention_ui.py index ea74dcb4..3c9df693 100644 --- a/src/stamp/heatmaps/attention_ui.py +++ b/src/stamp/heatmaps/attention_ui.py @@ -12,6 +12,8 @@ QSlider, QHBoxLayout, QFrame, + QLineEdit, + QFileDialog, ) from qtpy.QtCore import Qt @@ -194,6 +196,7 @@ def __init__( # Add other UI elements self._add_file_selection_ui() self._add_config_selection_ui() + self._add_topk_controls_ui() self._add_patch_display_widget() # Disable UI elements until a file is selected @@ -360,6 +363,18 @@ def _add_config_selection_ui(self): layout.addLayout(head_layout) + # === Add the widget to napari viewer as a dock widget === + self.viewer.window.add_dock_widget( + selection_widget, name="Attention Parameters", area="right" + ) + + def _add_topk_controls_ui(self): + """Add UI controls for top-k tile selection and patch operations""" + # Create a widget container + topk_widget = QWidget() + layout = QVBoxLayout() + topk_widget.setLayout(layout) + # === Top-k SELECTION === topk_label = QLabel("Top-k tiles to highlight:") layout.addWidget(topk_label) @@ -404,20 +419,37 @@ def _add_config_selection_ui(self): layout.addLayout(topk_layout) - # Add a horizontal separator - separator = QFrame() - separator.setFrameShape(QFrame.HLine) - separator.setFrameShadow(QFrame.Sunken) - layout.addWidget(separator) - # === LOAD PATCHES BUTTON === - self.load_patches_btn = QPushButton("Load Selected & Top-k Patches") + self.load_patches_btn = QPushButton("Visualize Top-k Patches") self.load_patches_btn.clicked.connect(self._load_tile_patches) layout.addWidget(self.load_patches_btn) + + save_path_layout = QHBoxLayout() + + # Path entry field + self.save_path_entry = QLineEdit() + self.save_path_entry.setPlaceholderText("Enter path to save patches...") + # Set default save path to output directory if available + if 
hasattr(self, 'output_dir') and self.output_dir: + self.save_path_entry.setText(str(Path(self.output_dir) / 'topK_patches')) + save_path_layout.addWidget(self.save_path_entry) + + # Browse button + self.browse_btn = QPushButton("Browse Save Path") + self.browse_btn.setMaximumWidth(120) + self.browse_btn.clicked.connect(self._browse_save_path) + save_path_layout.addWidget(self.browse_btn) + + layout.addLayout(save_path_layout) + + # === SAVE TOP-K BUTTON === + self.save_topk_btn = QPushButton("Save Top-k Patches") + self.save_topk_btn.clicked.connect(self._save_topk_patches) + layout.addWidget(self.save_topk_btn) # === Add the widget to napari viewer as a dock widget === self.viewer.window.add_dock_widget( - selection_widget, name="Attention Parameters", area="right" + topk_widget, name="Top-k Tile Controls", area="right" ) def _add_patch_display_widget(self): @@ -460,6 +492,16 @@ def _set_ui_enabled(self, enabled: bool): if hasattr(self, "load_patches_btn"): self.load_patches_btn.setEnabled(enabled) + # Enable save top-k button + if hasattr(self, "save_topk_btn"): + self.save_topk_btn.setEnabled(enabled) + + # Enable save path entry and browse button + if hasattr(self, "save_path_entry"): + self.save_path_entry.setEnabled(enabled) + if hasattr(self, "browse_btn"): + self.browse_btn.setEnabled(enabled) + # Enable attention handling dropdown if hasattr(self, "attention_handling"): self.attention_handling.setEnabled(enabled) @@ -589,10 +631,6 @@ def process_selected_file(self, wsi_path): stride_um = Microns(get_stride(coords_um)) - # list all h5 attrs - for key in h5.attrs.keys(): - print(f"h5 attribute '{key}': {h5.attrs[key]}") - self.tile_size_slide_px = SlidePixels( int(round(256 / slide_mpp)) ) @@ -796,6 +834,74 @@ def _load_tile_patches(self): self.patches_container.adjustSize() QApplication.processEvents() + def _browse_save_path(self): + """Open file dialog to select save directory""" + current_path = self.save_path_entry.text() or str(Path(self.output_dir) / 'topK_patches') if hasattr(self, 'output_dir') and self.output_dir else "" + + directory = QFileDialog.getExistingDirectory( + None, + "Select Directory to Save Patches", + current_path + ) + + if directory: + self.save_path_entry.setText(directory) + + def _save_topk_patches(self): + """Save the selected patch and top-k patches to the specified directory""" + if self.selected_token_idx is None: + print("No token selected. 
+            return
+
+        save_path = self.save_path_entry.text().strip()
+        if not save_path:
+            print("Please specify a save path.")
+            return
+
+        save_dir = Path(save_path)
+        save_dir.mkdir(parents=True, exist_ok=True)
+
+        try:
+            # Save selected patch
+            selected_patch = self.slide.read_region(
+                tuple(self.coords_tile_slide_px[self.selected_token_idx].tolist()),
+                0,
+                (self.tile_size_slide_px, self.tile_size_slide_px),
+            ).convert("RGB")
+
+            # Create filename prefix based on slide and attention parameters
+            filename_prefix = f"{Path(self.selected_filename).stem}_layer{self.layer_slider.value()}_head{self.head_slider.value()}_mode{self.attention_handling.currentData()}"
+
+            if self.attention_handling.currentData() == 6: # Class token attention
+                # For the class token there is no selected tile, so skip the token index in the filename
+                filename_prefix = f"{filename_prefix}_cls"
+            else:
+                filename_prefix = f"{filename_prefix}_token{self.selected_token_idx}"
+            selected_filename = save_dir / f"{filename_prefix}.png"
+            selected_patch.save(selected_filename)
+
+            # Save top-k patches
+            topk = min(self.topk_slider.value(), len(self.token_attn))
+            saved_count = 0
+            if topk > 0:
+                for n, (score, index) in enumerate(zip(*self.token_attn.topk(topk))):
+                    # Get patch
+                    patch = self.slide.read_region(
+                        tuple(self.coords_tile_slide_px[index].tolist()),
+                        0,
+                        (self.tile_size_slide_px, self.tile_size_slide_px),
+                    ).convert("RGB")
+
+                    # Save patch
+                    patch_filename = save_dir / f"{filename_prefix}_top{n+1}_token{index}_score{score:.3f}.png"
+                    patch.save(patch_filename)
+                    saved_count += 1
+
+            print(f"Saved {saved_count} top-k patches to: {save_dir}")
+
+        except Exception as e:
+            print(f"Error saving patches: {e}")
+
     def _on_update_attention_map(self):
         # Check if we have data to display
         if self.attention_weights is None:

From 64c255e7f1ccc62302b955e34ec16a23e94f5ef7 Mon Sep 17 00:00:00 2001
From: DEschweiler
Date: Wed, 13 Aug 2025 16:00:36 +0200
Subject: [PATCH 10/10] slider improvements

---
 src/stamp/heatmaps/attention_ui.py | 46 +++++++++++++++++++++---------
 1 file changed, 33 insertions(+), 13 deletions(-)

diff --git a/src/stamp/heatmaps/attention_ui.py b/src/stamp/heatmaps/attention_ui.py
index 3c9df693..63242e1c 100644
--- a/src/stamp/heatmaps/attention_ui.py
+++ b/src/stamp/heatmaps/attention_ui.py
@@ -162,7 +162,9 @@ def __init__(
         self.selected_token_idx = None
         self.selected_filename = None
         self.num_layer = 0
+        self.n_layers = 2
         self.num_head = 0
+        self.n_heads = 8
 
         # Initialize empty heatmap
         self.heatmap = np.zeros((self.height, self.width, 4), dtype=float)
@@ -266,6 +268,8 @@ def _add_config_selection_ui(self):
         self.attention_handling.addItem("Mean of overall given attention", 4)
         self.attention_handling.addItem("Mean of overall received attention", 5)
         self.attention_handling.addItem("Class token attention", 6)
+        self.attention_handling.addItem("Mutual attention", 7)
+        self.attention_handling.addItem("Max mutual attention", 8)
 
         # Connect the dropdown to the update function
         self.attention_handling.currentIndexChanged.connect(
@@ -285,11 +289,11 @@ def _add_config_selection_ui(self):
         # Layer slider
         self.layer_slider = QSlider(Qt.Horizontal)
         self.layer_slider.setMinimum(0)
-        self.layer_slider.setMaximum(2) # Will be updated with actual layer count
+        self.layer_slider.setMaximum(self.n_layers-1) # Will be updated with actual layer count
         self.layer_slider.setValue(0)
         self.layer_slider.valueChanged.connect(
             lambda value: (
-                self.layer_value_label.setText(str(value)),
+                self.layer_value_label.setText(str(value+1) + f"/{self.n_layers}"),
                 self._on_update_attention_map(),
             )
         )
@@ -312,7 +316,7 @@ def _add_config_selection_ui(self):
         layer_layout.addWidget(self.layer_right_btn)
 
         # Layer value display
-        self.layer_value_label = QLabel("0")
+        self.layer_value_label = QLabel("1/2")
         self.layer_value_label.setMinimumWidth(25)
         self.layer_value_label.setAlignment(Qt.AlignCenter)
         layer_layout.addWidget(self.layer_value_label)
@@ -320,7 +324,7 @@ def _add_config_selection_ui(self):
         layout.addLayout(layer_layout)
 
         # === HEAD SELECTION ===
-        head_label = QLabel("Number of Prediction Head\n(-1 for average):")
+        head_label = QLabel("Number of Prediction Head\n(0 for average):")
         layout.addWidget(head_label)
 
         # Create head selection controls with arrows and slider
@@ -329,11 +333,11 @@ def _add_config_selection_ui(self):
         # Head slider
         self.head_slider = QSlider(Qt.Horizontal)
         self.head_slider.setMinimum(-1)
-        self.head_slider.setMaximum(8) # Will be updated with actual head count
-        self.head_slider.setValue(0)
+        self.head_slider.setMaximum(self.n_heads-1) # Will be updated with actual head count
+        self.head_slider.setValue(-1)
         self.head_slider.valueChanged.connect(
             lambda value: (
-                self.head_value_label.setText(str(value)),
+                self.head_value_label.setText(str(value+1) + f"/{self.n_heads}"),
                 self._on_update_attention_map(),
             )
         )
@@ -356,7 +360,7 @@ def _add_config_selection_ui(self):
         head_layout.addWidget(self.head_right_btn)
 
         # Head value display
-        self.head_value_label = QLabel("0")
+        self.head_value_label = QLabel("1/8")
         self.head_value_label.setMinimumWidth(25)
         self.head_value_label.setAlignment(Qt.AlignCenter)
         head_layout.addWidget(self.head_value_label)
@@ -660,10 +664,13 @@ def process_selected_file(self, wsi_path):
             )
 
             # Determine number of heads and layers and update UI elements
-            num_layers = len(self.attention_weights)
-            num_heads = self.attention_weights[0].shape[1]
-            self.layer_slider.setMaximum(num_layers - 1)
-            self.head_slider.setMaximum(num_heads - 1)
+            self.n_layers = len(self.attention_weights)
+            self.n_heads = self.attention_weights[0].shape[1]
+            self.layer_slider.setMaximum(self.n_layers-1)
+            self.head_slider.setMaximum(self.n_heads-1)
+            self.layer_value_label.setText(f"{self.layer_slider.value()+1}/{self.n_layers}")
+            self.head_value_label.setText(f"{self.head_slider.value()+1}/{self.n_heads}")
+
 
             # Get thumbnail of the slide
             self.image = _get_thumb(self.slide, slide_mpp)
@@ -693,6 +700,9 @@ def load_selected_attention_map(self):
         # Cut out batch dimension
         self.attention_map = self.attention_map[0, ...]  # Shape: [tokens, tokens]
 
+        # Take absolute values to treat positive and negative attention alike
+        self.attention_map = self.attention_map.abs()
+
         # Normalize attention map to [0, 1] by using percentiles (not considering cls token)
         percentile_low = np.percentile(self.attention_map[1:, 1:], 0.5)
         percentile_high = np.percentile(self.attention_map[1:, 1:], 99.5)
@@ -774,7 +784,7 @@ def _load_tile_patches(self):
             selected_label = QLabel()
 
             # Create label text and selected image
-            if self.attention_handling.currentData() == 6: # Class token attention (no reference image)
+            if self.attention_handling.currentData() in (2, 3, 4, 5, 6, 8): # If heatmap type is agnostic to token selection
                 # Create QLabel for Class token
                 selected_pixmap = QPixmap(200, 200) # blank image
                 selected_pixmap.fill(Qt.transparent)
@@ -1033,6 +1043,16 @@ def get_token_attention(self, selected_token_idx: IntType):
                 percentile_high - percentile_low + 1e-8
             )
 
+        # Mutual attention
+        elif selected_direction == 7:
+            mutual_attn_matrix = self.attention_map[1:, 1:] * self.attention_map[1:, 1:].T
+            token_attn = mutual_attn_matrix[:, selected_token_idx]
+
+        # Max mutual attention
+        elif selected_direction == 8:
+            mutual_attn_matrix = self.attention_map[1:, 1:] * self.attention_map[1:, 1:].T
+            token_attn = torch.max(mutual_attn_matrix, dim=1)[0]
+
         else:
             raise ValueError(f"Invalid direction selected: {selected_direction}")
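
Note (editor's addition, not part of the patch series): the two attention modes introduced in PATCH 10/10 reduce the patch-token attention matrix (class token stripped) as sketched below. The snippet is a minimal standalone illustration on a dummy matrix; the names attn, n_tokens, and selected_token_idx are illustrative and not STAMP identifiers.

    import torch

    # Dummy stand-in for attention_map[1:, 1:] after .abs() and percentile normalization.
    n_tokens = 6
    attn = torch.rand(n_tokens, n_tokens)

    # Mutual attention weights attention i -> j by the reverse direction j -> i.
    mutual = attn * attn.T

    # Mode 7 ("Mutual attention"): mutual attention between every token and the selected one.
    selected_token_idx = 2
    mode7 = mutual[:, selected_token_idx]  # shape: (n_tokens,)

    # Mode 8 ("Max mutual attention"): each token's strongest mutual partner, independent of selection.
    mode8 = torch.max(mutual, dim=1)[0]    # shape: (n_tokens,)

    print(mode7.shape, mode8.shape)

Mode 8 does not depend on a selected token, which is why the UI treats it (together with modes 2-6) as selection-agnostic when displaying the reference patch.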