9 changes: 9 additions & 0 deletions tests/conftest.py
@@ -2,6 +2,8 @@

import gc
import random
import tempfile
from pathlib import Path

import numpy as np
import pytest
@@ -49,3 +51,10 @@ def pytest_sessionfinish(session, exitstatus):
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()


@pytest.fixture
def temp_dir():
"""Create a temporary directory for test files."""
with tempfile.TemporaryDirectory() as tmpdir:
yield Path(tmpdir)
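
A usage sketch of the new fixture (the test name and file contents below are hypothetical): a test consumes it by naming it as a parameter, and pytest tears the directory down when the `with` block in the fixture exits.

    def test_writes_artifact(temp_dir):
        # temp_dir is a pathlib.Path into a fresh TemporaryDirectory.
        out_file = temp_dir / "artifact.json"
        out_file.write_text('{"ok": true}')
        assert out_file.exists()
        # Everything under temp_dir is removed automatically after the test.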
98 changes: 97 additions & 1 deletion transformer_lens/benchmarks/main_benchmark.py
@@ -1355,6 +1355,94 @@ def cleanup_model(model, model_name_str: str):
return results


def update_model_registry(model_name: str, results: List[BenchmarkResult]) -> bool:
"""Update the model registry with benchmark results.

Args:
model_name: The model that was benchmarked
results: List of benchmark results

Returns:
True if registry was updated successfully
"""
import json
from datetime import datetime
from pathlib import Path

registry_path = (
Path(__file__).parent.parent / "tools" / "model_registry" / "data" / "supported_models.json"
)

if not registry_path.exists():
print(f"Registry not found at {registry_path}")
return False

# Calculate phase scores (percentage of passed tests per phase)
phase_results: Dict[int, List[bool]] = {1: [], 2: [], 3: []}
for result in results:
if result.phase in phase_results and result.severity != BenchmarkSeverity.SKIPPED:
phase_results[result.phase].append(result.passed)

phase_scores = {}
for phase, passed_list in phase_results.items():
if passed_list:
phase_scores[phase] = round(sum(passed_list) / len(passed_list) * 100, 1)
else:
phase_scores[phase] = None

# Load registry
with open(registry_path) as f:
registry = json.load(f)

# Find and update the model entry
updated = False
for entry in registry.get("models", []):
if entry["model_id"] == model_name:
entry["verified"] = True
entry["verified_at"] = datetime.now().isoformat()
entry["phase1_score"] = phase_scores.get(1)
entry["phase2_score"] = phase_scores.get(2)
entry["phase3_score"] = phase_scores.get(3)
updated = True
break

if not updated:
# Model not in registry - add it
# Try to determine the architecture from the model's Hugging Face config
architecture_id = "Unknown"
try:
from transformers import AutoConfig

config = AutoConfig.from_pretrained(model_name)
archs = getattr(config, "architectures", []) or []
if archs:
architecture_id = archs[0]
except Exception:
pass

registry.setdefault("models", []).append(
{
"model_id": model_name,
"architecture_id": architecture_id,
"verified": True,
"verified_at": datetime.now().isoformat(),
"phase1_score": phase_scores.get(1),
"phase2_score": phase_scores.get(2),
"phase3_score": phase_scores.get(3),
}
)
updated = True

# Write back
with open(registry_path, "w") as f:
json.dump(registry, f, indent=2)

print(
f"Updated registry for {model_name}: P1={phase_scores.get(1)}%, P2={phase_scores.get(2)}%, P3={phase_scores.get(3)}%"
)
return updated


def main():
"""Run benchmarks from command line."""
import argparse
@@ -1392,10 +1480,15 @@ def main():
action="store_true",
help="Suppress verbose output",
)
parser.add_argument(
"--update-registry",
action="store_true",
help="Update model registry with benchmark results (default: false)",
)

args = parser.parse_args()

run_benchmark_suite(
results = run_benchmark_suite(
model_name=args.model,
device=args.device,
use_hf_reference=not args.no_hf_reference,
@@ -1404,6 +1497,9 @@
verbose=not args.quiet,
)

if args.update_registry:
update_model_registry(args.model, results)


if __name__ == "__main__":
main()
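
The new flag is opt-in. A typical run (the model name is illustrative, and this assumes the package is importable) would be `python -m transformer_lens.benchmarks.main_benchmark --model gpt2 --update-registry`. For a model not yet in supported_models.json, update_model_registry appends an entry of the following shape; the values here are illustrative, but the keys and JSON encoding follow the code above:

    {
      "model_id": "gpt2",
      "architecture_id": "GPT2LMHeadModel",
      "verified": true,
      "verified_at": "2025-01-01T00:00:00.000000",
      "phase1_score": 100.0,
      "phase2_score": 95.5,
      "phase3_score": null
    }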
12 changes: 12 additions & 0 deletions transformer_lens/config/TransformerBridgeConfig.py
@@ -85,6 +85,12 @@ def __init__(
eps_attr: str = "eps",
rmsnorm_uses_offset: bool = False,
attn_implementation: Optional[str] = None,
# Multimodal configuration
is_multimodal: bool = False,
vision_hidden_size: Optional[int] = None,
vision_num_layers: Optional[int] = None,
vision_num_heads: Optional[int] = None,
mm_tokens_per_image: Optional[int] = None,
**kwargs,
):
"""Initialize TransformerBridgeConfig."""
@@ -166,6 +172,12 @@ def __init__(
self.eps_attr = eps_attr
self.rmsnorm_uses_offset = rmsnorm_uses_offset
self.attn_implementation = attn_implementation
# Multimodal configuration
self.is_multimodal = is_multimodal
self.vision_hidden_size = vision_hidden_size
self.vision_num_layers = vision_num_layers
self.vision_num_heads = vision_num_heads
self.mm_tokens_per_image = mm_tokens_per_image

self.__post_init__()

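
A minimal construction sketch for the new fields, assuming the remaining TransformerBridgeConfig arguments keep their defaults; the vision numbers are illustrative placeholders, not values read from any real checkpoint:

    # Hypothetical: enable the multimodal fields added above.
    cfg = TransformerBridgeConfig(
        is_multimodal=True,
        vision_hidden_size=1152,
        vision_num_layers=27,
        vision_num_heads=16,
        mm_tokens_per_image=256,
    )
    assert cfg.is_multimodal and cfg.mm_tokens_per_image == 256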
2 changes: 2 additions & 0 deletions transformer_lens/factories/architecture_adapter_factory.py
@@ -29,6 +29,7 @@
Qwen2ArchitectureAdapter,
Qwen3ArchitectureAdapter,
QwenArchitectureAdapter,
StableLmArchitectureAdapter,
T5ArchitectureAdapter,
)

@@ -56,6 +57,7 @@
"QwenForCausalLM": QwenArchitectureAdapter,
"Qwen2ForCausalLM": Qwen2ArchitectureAdapter,
"Qwen3ForCausalLM": Qwen3ArchitectureAdapter,
"StableLmForCausalLM": StableLmArchitectureAdapter,
"T5ForConditionalGeneration": T5ArchitectureAdapter,
"NanoGPTForCausalLM": NanogptArchitectureAdapter,
"MinGPTForCausalLM": MingptArchitectureAdapter,
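
For illustration, the factory resolves adapters by the Hugging Face `architectures[0]` string via the mapping above. The surrounding dict and factory names are not visible in this hunk, so the sketch below labels them as assumptions:

    # Hypothetical sketch: ARCHITECTURE_ADAPTER_MAP stands in for the dict
    # shown above; the adapter is assumed to take the bridge config.
    adapter_cls = ARCHITECTURE_ADAPTER_MAP["StableLmForCausalLM"]
    adapter = adapter_cls(cfg)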
4 changes: 4 additions & 0 deletions transformer_lens/model_bridge/bridge.py
@@ -2141,3 +2141,7 @@ def get_params(self):
ValueError: If configuration is inconsistent (e.g., cfg.n_layers != len(blocks))
"""
return get_bridge_params(self)

# NOTE: list_supported_models and check_model_support are attached to this class
# dynamically by the transformer_lens.model_bridge.sources.transformers module.
# These are HuggingFace-specific methods that belong in the transformers source module.
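
A sketch of the dynamic-attachment pattern the note describes; only the method names come from the comment, while the class name TransformerBridge and the assignment wiring are assumptions about how the source module does it:

    # Hypothetical wiring in transformer_lens.model_bridge.sources.transformers:
    def list_supported_models() -> list:
        ...  # enumerate HF architectures with registered adapters

    TransformerBridge.list_supported_models = staticmethod(list_supported_models)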
10 changes: 10 additions & 0 deletions transformer_lens/model_bridge/generalized_components/__init__.py
@@ -35,6 +35,13 @@
BloomAttentionBridge,
)
from transformer_lens.model_bridge.generalized_components.bloom_mlp import BloomMLPBridge
from transformer_lens.model_bridge.generalized_components.siglip_vision_encoder import (
SiglipVisionEncoderBridge,
SiglipVisionEncoderLayerBridge,
)
from transformer_lens.model_bridge.generalized_components.vision_projection import (
VisionProjectionBridge,
)

__all__ = [
"AttentionBridge",
@@ -59,4 +66,7 @@
"SymbolicBridge",
"UnembeddingBridge",
"T5BlockBridge",
"SiglipVisionEncoderBridge",
"SiglipVisionEncoderLayerBridge",
"VisionProjectionBridge",
]
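
After this change the vision bridges import like the existing components:

    from transformer_lens.model_bridge.generalized_components import (
        SiglipVisionEncoderBridge,
        SiglipVisionEncoderLayerBridge,
        VisionProjectionBridge,
    )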
161 changes: 161 additions & 0 deletions transformer_lens/model_bridge/generalized_components/siglip_vision_encoder.py
@@ -0,0 +1,161 @@
"""SigLIP Vision Encoder bridge component.

This module contains the bridge component for SigLIP vision encoder layers
used in multimodal models like Gemma 3 and MedGemma.
"""
from typing import Any, Dict, Optional

import torch

from transformer_lens.hook_points import HookPoint
from transformer_lens.model_bridge.generalized_components.base import (
GeneralizedComponent,
)
from transformer_lens.model_bridge.generalized_components.normalization import (
NormalizationBridge,
)


class SiglipVisionEncoderLayerBridge(GeneralizedComponent):
"""Bridge for a single SigLIP encoder layer.

SigLIP encoder layers have:
- layer_norm1: LayerNorm
- self_attn: SiglipAttention
- layer_norm2: LayerNorm
- mlp: SiglipMLP
"""

is_list_item: bool = True
hook_aliases = {
"hook_resid_pre": "hook_in",
"hook_resid_post": "hook_out",
"hook_attn_in": "attn.hook_in",
"hook_attn_out": "attn.hook_out",
"hook_mlp_in": "mlp.hook_in",
"hook_mlp_out": "mlp.hook_out",
}

def __init__(
self,
name: str,
config: Optional[Any] = None,
submodules: Optional[Dict[str, GeneralizedComponent]] = None,
):
"""Initialize the SigLIP encoder layer bridge.

Args:
name: The name of this component (e.g., "encoder.layers")
config: Optional configuration object
submodules: Dictionary of submodules to register
"""
super().__init__(name, config, submodules=submodules or {})

def forward(self, hidden_states: torch.Tensor, **kwargs: Any) -> torch.Tensor:
"""Forward pass through the vision encoder layer.

Args:
hidden_states: Input hidden states from previous layer
**kwargs: Additional arguments (attention_mask, etc.)

Returns:
Output hidden states
"""
if self.original_component is None:
raise RuntimeError(
f"Original component not set for {self.name}. Call set_original_component() first."
)

hidden_states = self.hook_in(hidden_states)
output = self.original_component(hidden_states, **kwargs)

if isinstance(output, tuple):
output = (self.hook_out(output[0]),) + output[1:]
else:
output = self.hook_out(output)

return output


class SiglipVisionEncoderBridge(GeneralizedComponent):
"""Bridge for the complete SigLIP vision encoder.

The SigLIP vision tower consists of:
- vision_model.embeddings: Patch + position embeddings
- vision_model.encoder.layers[]: Stack of encoder layers
- post_layernorm: Final layer norm

This bridge wraps the entire vision tower to provide hooks for
interpretability of the vision processing pipeline.
"""

hook_aliases = {
"hook_vision_embed": "embeddings.hook_out",
"hook_vision_out": "hook_out",
}

def __init__(
self,
name: str,
config: Optional[Any] = None,
submodules: Optional[Dict[str, GeneralizedComponent]] = None,
):
"""Initialize the SigLIP vision encoder bridge.

Args:
name: The name of this component (e.g., "vision_tower")
config: Optional configuration object
submodules: Dictionary of submodules to register
"""
default_submodules = {
"embeddings": GeneralizedComponent(name="vision_model.embeddings"),
"encoder_layers": SiglipVisionEncoderLayerBridge(name="vision_model.encoder.layers"),
"post_layernorm": NormalizationBridge(
name="vision_model.post_layernorm", config=config
),
}

if submodules:
default_submodules.update(submodules)

super().__init__(name, config, submodules=default_submodules)

# Additional hooks for vision-specific processing
self.hook_patch_embed = HookPoint() # After patch embedding
self.hook_pos_embed = HookPoint() # After position embedding added

def forward(
self,
pixel_values: torch.Tensor,
**kwargs: Any,
) -> torch.Tensor:
"""Forward pass through the vision encoder.

Args:
pixel_values: Input image tensor [batch, channels, height, width]
**kwargs: Additional arguments

Returns:
Vision embeddings [batch, num_patches, hidden_size]
"""
if self.original_component is None:
raise RuntimeError(
f"Original component not set for {self.name}. Call set_original_component() first."
)

# Apply input hook to pixel values
pixel_values = self.hook_in(pixel_values)

# Forward through the vision tower
output = self.original_component(pixel_values, **kwargs)

# Handle tuple output (some models return (hidden_states, ...))
if isinstance(output, tuple):
output = (self.hook_out(output[0]),) + output[1:]
elif hasattr(output, "last_hidden_state"):
# Handle BaseModelOutput-like returns
output.last_hidden_state = self.hook_out(output.last_hidden_state)
else:
output = self.hook_out(output)

return output
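
A sketch of inspecting the vision tower through the new hooks. It assumes `vision_bridge` is a SiglipVisionEncoderBridge already wired to a loaded SigLIP tower via set_original_component(), and that HookPoint.add_hook takes a `(tensor, hook)` callback as it does elsewhere in TransformerLens:

    # Hypothetical: capture the vision-tower output (the hook_vision_out alias
    # resolves to hook_out per the alias table above).
    captured = {}

    def save_hook(tensor, hook):
        # Store a detached copy of the activation at this hook point.
        captured[hook.name] = tensor.detach()

    vision_bridge.hook_out.add_hook(save_hook)
    embeddings = vision_bridge(pixel_values)  # [batch, num_patches, hidden_size]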