
Commit 24c07b8

Added basic code for CETT threshold calculation and refactored activation capture.
Signed-off-by: Kira Selby <kaselby@uwaterloo.ca>
1 parent 1e4c444 commit 24c07b8

6 files changed (+133 additions, -538 deletions)

generate_dataset.py

Lines changed: 5 additions & 5 deletions

@@ -43,7 +43,7 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from transformers.trainer_utils import set_seed

-from src.activation_capture import ActivationCaptureTraining
+from src.activation_capture import Hook

 # Setup logging
 logging.basicConfig(level=logging.INFO)
@@ -120,14 +120,14 @@ def process_batch(
     hidden_states_dict = {}
     mlp_activations_dict = {}
     for layer_idx in range(num_layers):
-        hidden_state = model.activation_capture.get_hidden_states(layer_idx)[0]
+        hidden_state = model.activation_capture.mlp_activations[Hook.IN][layer_idx][0]
         hidden_states_dict[layer_idx] = (
             hidden_state.view(-1, hidden_state.shape[-1])
             .cpu()
             .numpy()
             .astype(np.float32)
         )
-        mlp_activation = model.activation_capture.get_gate_activations(layer_idx)
+        mlp_activation = model.activation_capture.mlp_activations[Hook.ACT][layer_idx]
         mlp_activations_dict[layer_idx] = (
             mlp_activation[0]
             .view(-1, mlp_activation.shape[-1])
@@ -172,8 +172,8 @@ def generate_dataset(
     model = model.to(device)

     model.eval()
-    model.activation_capture = ActivationCaptureTraining(model)
-    model.activation_capture.register_hooks()
+    model.activation_capture = model.ACTIVATION_CAPTURE(model)
+    model.activation_capture.register_hooks(hooks=[Hook.IN, Hook.ACT])

     # Get model dimensions
     hidden_dim = model.config.hidden_size
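In the refactored process_batch above, Hook.IN holds the input to each layer's MLP block, shaped (batch, seq_len, hidden_size), and Hook.ACT holds the input to mlp.act_fn, shaped (batch, seq_len, intermediate_size); the script then takes the first sequence of each batch and flattens its tokens into float32 rows. A minimal self-contained sketch of that flattening, using dummy tensors (the 2048/8192 sizes are placeholders, not values taken from this repo):

import numpy as np
import torch

from src.activation_capture import Hook

# Dummy captures for one layer, shaped the way the new hooks store them:
#   Hook.IN  -> input to the MLP block:  (batch, seq_len, hidden_size)
#   Hook.ACT -> input to mlp.act_fn:     (batch, seq_len, intermediate_size)
captures = {
    Hook.IN:  {0: torch.randn(1, 16, 2048)},
    Hook.ACT: {0: torch.randn(1, 16, 8192)},
}

layer_idx = 0
hidden_state = captures[Hook.IN][layer_idx][0]          # first sequence in the batch
hidden_rows = (
    hidden_state.view(-1, hidden_state.shape[-1]).cpu().numpy().astype(np.float32)
)

mlp_activation = captures[Hook.ACT][layer_idx]
act_rows = (
    mlp_activation[0].view(-1, mlp_activation.shape[-1]).cpu().numpy().astype(np.float32)
)

print(hidden_rows.shape, act_rows.shape)                # (16, 2048) (16, 8192)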

measure_contextual_sparsity.py

Lines changed: 7 additions & 16 deletions

@@ -13,7 +13,7 @@
 from transformers.trainer_utils import set_seed

 import matplotlib.pyplot as plt
-from src.activation_capture import ActivationCaptureDefault
+from src.activation_capture import Hook

 # Setup logging
 logging.basicConfig(level=logging.INFO)
@@ -28,16 +28,14 @@ def __init__(self, model, tokenizer, device):
         self.tokenizer = tokenizer
         self.device = device

-        model.activation_capture = ActivationCaptureDefault(model)
-        model.activation_capture.register_hooks()
+        model.activation_capture = model.ACTIVATION_CAPTURE(model)
+        model.activation_capture.register_hooks(hooks=[Hook.ACT])
         self.num_layers = len(self.model.activation_capture.get_layers())

         self.reset_buffers()

     def reset_buffers(self):
-        self.mlp_sparsity = {}
-        self.mlp_sparsity["gate"] = defaultdict(list)
-        self.mlp_sparsity["up"] = defaultdict(list)
+        self.mlp_sparsity = defaultdict(list)
         self.num_seqs = 0

     def process_batch(self, input_ids: torch.Tensor, attention_mask: torch.Tensor):
@@ -54,26 +52,19 @@ def process_batch(self, input_ids: torch.Tensor, attention_mask: torch.Tensor):

         # Compute sparsity
         for layer_idx in range(self.num_layers):
-            sparsity_masks_gate = (
-                self.model.activation_capture.get_gate_activations(layer_idx) <= 0
-            )
-            sparsity_masks_up = (
-                self.model.activation_capture.get_up_activations(layer_idx) <= 0
+            sparsity_masks = (
+                self.model.activation_capture.mlp_activations[Hook.ACT][layer_idx] <= 0
             )

             # Naive sparsity computation
             self.mlp_sparsity["gate"][layer_idx].append(
-                sparsity_masks_gate.float().mean().item()
-            )
-            self.mlp_sparsity["up"][layer_idx].append(
-                sparsity_masks_up.float().mean().item()
+                sparsity_masks.float().mean().item()
             )

             # Level of sparsity after union over batch dim
             # union_sparsity_mask = sparsity_masks.any(dim=0)
             # self.union_sparsity[batch_size][layer_idx].append(union_sparsity_mask.float().mean().item())

-            # TODO: Add HNSW sparsity computation for both attn heads and mlp neurons
             # TODO: Compute union sparsity over multiple different batch sizes

         # Clear GPU tensors from capture to free memory
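The per-layer statistic recorded above is the fraction of entries in the captured Hook.ACT tensor (the input to mlp.act_fn) that are non-positive; for a SiLU gate these are exactly the neurons whose gated output is also non-positive. A tiny self-contained illustration on a dummy capture:

import torch

# Dummy captured input to mlp.act_fn for one layer: (batch, seq_len, intermediate_size)
acts = torch.randn(2, 8, 256)

# Naive sparsity: fraction of non-positive pre-activations, averaged over
# batch, tokens, and neurons (the value appended per layer above).
sparsity_mask = acts <= 0
naive_sparsity = sparsity_mask.float().mean().item()
print(f"naive sparsity: {naive_sparsity:.3f}")   # roughly 0.5 for random normal inputs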

src/activation_capture.py

Lines changed: 65 additions & 116 deletions

@@ -1,54 +1,91 @@
-from typing_extensions import override
-import torch.nn.functional as F
-from abc import ABC, abstractmethod

+from enum import Enum
+from typing import List

-class ActivationCapture(ABC):
+class Hook(Enum):
+    IN = "IN"
+    ACT = "ACT"
+    UP = "UP"
+    OUT = "OUT"
+
+
+class ActivationCapture():
     """Helper class to capture activations from model layers."""
-    has_gate_proj: bool
-    has_up_proj: bool
+    hooks_available: List[Hook]

     def __init__(self, model):
         self.model = model
-        self.mlp_activations = {}
+        self.mlp_activations = {
+            hook: {} for hook in self.hooks_available
+        }
         self.handles = []

-    @abstractmethod
-    def _register_gate_hook(self, layer_idx, layer):
-        pass
+    def _register_in_hook(self, layer_idx, layer):
+        def hook(module, input, output):
+            # Just detach, don't clone or move to CPU yet
+            self.mlp_activations[Hook.IN][layer_idx] = input[0].clone().detach()
+            return output
+        handle = layer.mlp.register_forward_hook(hook)
+        return handle
+
+    def _register_act_hook(self, layer_idx, layer):
+        def hook(module, input, output):
+            # Just detach, don't clone or move to CPU yet
+            self.mlp_activations[Hook.ACT][layer_idx] = input[0].clone().detach()
+            return output
+        handle = layer.mlp.act_fn.register_forward_hook(hook)
+        return handle

-    @abstractmethod
     def _register_up_hook(self, layer_idx, layer):
-        pass
+        def hook(module, input, output):
+            # Just detach, don't clone or move to CPU yet
+            self.mlp_activations[Hook.UP][layer_idx] = input[0].clone().detach()
+            return output
+        handle = layer.mlp.down_proj.register_forward_hook(hook)
+        return handle
+
+    def _register_out_hook(self, layer_idx, layer):
+        def hook(module, input, output):
+            # Just detach, don't clone or move to CPU yet
+            self.mlp_activations[Hook.OUT][layer_idx] = output.clone().detach()
+            return output
+        handle = layer.mlp.register_forward_hook(hook)
+        return handle

-    @abstractmethod
     def get_layers(self):
-        pass
-
-
-    @abstractmethod
-    def get_gate_activations(self, layer_idx):
-        """Get combined MLP activations for a layer."""
-        pass
+        return self.model.get_decoder().layers

-    def register_hooks(self):
+    def register_hooks(self, hooks=(Hook.ACT, Hook.UP, Hook.OUT)):
         """Register forward hooks to capture activations."""
         # Clear any existing hooks
         self.remove_hooks()

         # Hook into each transformer layer
-        for i, layer in enumerate(self.get_layers()):
-            # Capture MLP gate activations (after activation function)
-            if self.has_gate_proj:
-                handle = self._register_gate_hook(i, layer)
+        for i, layer in enumerate(self.get_layers()):
+            # Hooks capturing inputs to the MLP layer
+            if Hook.IN in hooks and Hook.IN in self.hooks_available:
+                handle = self._register_in_hook(i, layer)
                 if handle is not None:
                     self.handles.append(handle)
-
-            # Also capture up_proj activations
-            if self.has_up_proj:
+
+            # Hooks capturing inputs to the activation function
+            if Hook.ACT in hooks and Hook.ACT in self.hooks_available:
+                handle = self._register_act_hook(i, layer)
+                if handle is not None:
+                    self.handles.append(handle)
+
+            # Hooks capturing inputs to the down projection
+            if Hook.UP in hooks and Hook.UP in self.hooks_available:
                 handle = self._register_up_hook(i, layer)
                 if handle is not None:
                     self.handles.append(handle)
+
+            # Hooks capturing the final MLP output
+            if Hook.OUT in hooks and Hook.OUT in self.hooks_available:
+                handle = self._register_out_hook(i, layer)
+                if handle is not None:
+                    self.handles.append(handle)
+

     def remove_hooks(self):
         """Remove all registered hooks."""
@@ -59,91 +96,3 @@ def remove_hooks(self):
     def clear_captures(self):
         """Clear captured activations."""
         self.mlp_activations = {}
-
-
-
-class ActivationCaptureDefault(ActivationCapture):
-    """Helper class to capture activations from model layers."""
-    has_gate_proj: bool = True
-    has_up_proj: bool = True
-
-    def get_layers(self):
-        return self.model.get_decoder().layers
-
-    def _create_mlp_hook(self, layer_idx, proj_type):
-        def hook(module, input, output):
-            key = f"{layer_idx}_{proj_type}"
-            # Just detach, don't clone or move to CPU yet
-            self.mlp_activations[key] = output.clone().detach()
-            return output
-        return hook
-
-    def _register_gate_hook(self, layer_idx, layer):
-        handle = layer.mlp.gate_proj.register_forward_hook(
-            self._create_mlp_hook(layer_idx, 'gate')
-        )
-        return handle
-
-    def _register_up_hook(self, layer_idx, layer):
-        handle = layer.mlp.up_proj.register_forward_hook(
-            self._create_mlp_hook(layer_idx, 'up')
-        )
-        return handle
-
-    def get_gate_activations(self, layer_idx):
-        gate_key = f"{layer_idx}_gate"
-        if gate_key in self.mlp_activations:
-            gate_act = self.mlp_activations[gate_key]
-            return F.silu(gate_act)
-        return None
-
-    def get_up_activations(self, layer_idx):
-        up_key = f"{layer_idx}_up"
-        if up_key in self.mlp_activations:
-            up_act = self.mlp_activations[up_key]
-            return up_act
-        return None
-
-class ActivationCaptureTraining(ActivationCaptureDefault):
-    """Additional Hidden State capture for training dataset generation"""
-    def __init__(self, model):
-        super().__init__(model)
-        self.hidden_states = {}
-
-    def _create_hidden_state_hook(self, layer_idx, layer):
-        def hook(module, args, kwargs, output):
-            # args[0] is the input hidden states to the layer
-            if len(args) > 0:
-                # Just detach, don't clone or move to CPU yet
-                self.hidden_states[layer_idx] = args[0].clone().detach()
-            return output
-        return hook
-
-    def _register_hidden_state_hook(self, layer_idx, layer):
-        handle = layer.register_forward_hook(
-            self._create_hidden_state_hook(layer_idx, layer),
-            with_kwargs=True
-        )
-        return handle
-
-    @override
-    def clear_captures(self):
-        """Clear captured activations."""
-        super().clear_captures()
-        self.hidden_states = {}
-
-    @override
-    def register_hooks(self):
-        """Register forward hooks to capture activations."""
-        # Clear any existing hooks
-        super().register_hooks()
-        # Hook into each transformer layer
-        for i, layer in enumerate(self.get_layers()):
-            # Capture hidden states before MLP
-            handle = self._register_hidden_state_hook(i, layer)
-            if handle is not None:
-                self.handles.append(handle)
-
-    def get_hidden_states(self, layer_idx):
-        """Get hidden states for a layer."""
-        return self.hidden_states[layer_idx]
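A usage sketch of the consolidated capture class (not part of this commit). hooks_available is annotated but not assigned on ActivationCapture itself, so the sketch pins it down in a small subclass; the checkpoint name is a placeholder for any Llama-style model whose MLP exposes act_fn and down_proj submodules.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from src.activation_capture import ActivationCapture, Hook

class FullCapture(ActivationCapture):
    # Hypothetical subclass: expose every hook point on this model family.
    hooks_available = [Hook.IN, Hook.ACT, Hook.UP, Hook.OUT]

model_name = "meta-llama/Llama-3.2-1B"   # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.eval()

capture = FullCapture(model)
capture.register_hooks(hooks=[Hook.ACT, Hook.UP])

inputs = tokenizer("The quick brown fox", return_tensors="pt")
with torch.no_grad():
    model(**inputs)

act_in = capture.mlp_activations[Hook.ACT][0]   # input to layer 0's act_fn
up_in = capture.mlp_activations[Hook.UP][0]     # input to layer 0's down_proj
print(act_in.shape, up_in.shape)

capture.remove_hooks()
capture.clear_captures()

Two semantic shifts relative to the removed classes are worth keeping in mind: the old get_gate_activations returned F.silu applied to the gate_proj output, whereas Hook.ACT now stores the raw input to act_fn, and Hook.UP captures the input to down_proj (the already-gated product) rather than the up_proj output.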

src/cett.py

Lines changed: 54 additions & 0 deletions

@@ -0,0 +1,54 @@
+
+
+import torch
+
+from src.activation_capture import ActivationCapture, Hook
+
+def calculate_threshold_one_token(neuron_outputs, cett_target, n_quantiles=1000):
+    norms = neuron_outputs.norm(dim=0)
+    quantiles = norms.quantile(torch.linspace(0,1,n_quantiles))
+    tot_norm = neuron_outputs.sum(dim=1).norm()
+
+    def CETT(threshold):
+        threshold_norm = ((norms < threshold) * neuron_outputs).sum(dim=1).norm()
+        return threshold_norm / tot_norm
+
+    left = 0
+    right = quantiles.size(0)
+    threshold = 0
+    while left < right:
+        mid = (left + right) // 2
+        cett = CETT(quantiles[mid])
+        if cett <= cett_target:
+            left = mid + 1
+            threshold = quantiles[mid]
+        else:
+            right = mid - 1
+    return threshold
+
+
+def find_threshold(model, dataloader, layer_idx, cett_target=0.2, n_quantiles=500):
+    model.activation_capture = model.ACTIVATION_CAPTURE(model)
+    model.activation_capture.register_hooks(hooks=[Hook.UP])
+
+    thresholds = []
+
+    with torch.no_grad():
+        for batch in dataloader:
+            input_ids = batch["input_ids"]
+            attention_mask = batch["attention_mask"]
+
+            model.activation_capture.clear_captures()
+
+            _ = model(input_ids=input_ids, attention_mask=attention_mask)
+
+            activations = model.activation_capture.mlp_activations[Hook.UP][layer_idx]
+            activations = activations.view(-1, activations.size(-1))
+
+            for i in range(activations.size(0)):
+                neuron_outputs = activations[i] * model.model.layers[0].mlp.down_proj.weight
+                threshold = calculate_threshold_one_token(neuron_outputs, cett_target=cett_target, n_quantiles=n_quantiles)
+                thresholds.append(threshold)
+
+    return sum(thresholds)/len(thresholds)
+
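In find_threshold, the Hook.UP capture is the input to down_proj, so multiplying one token's activation vector elementwise against down_proj.weight (shape (hidden_size, intermediate_size)) scales each column: column j of neuron_outputs is neuron j's contribution to that token's MLP output. calculate_threshold_one_token then binary-searches the quantiles of the per-neuron norms for the largest threshold at which CETT, the norm of the summed contributions of neurons whose norms fall below the threshold divided by the norm of the full summed output, stays at or below cett_target. A small self-contained check on random data (sizes are placeholders):

import torch

from src.cett import calculate_threshold_one_token

torch.manual_seed(0)
hidden_size, intermediate_size = 64, 256                        # placeholder sizes

down_proj_weight = torch.randn(hidden_size, intermediate_size)  # stand-in for mlp.down_proj.weight
acts = torch.randn(intermediate_size)                           # one token's Hook.UP capture

# Column j is neuron j's contribution to the MLP output for this token.
neuron_outputs = acts * down_proj_weight

threshold = calculate_threshold_one_token(neuron_outputs, cett_target=0.2, n_quantiles=500)
kept = (neuron_outputs.norm(dim=0) >= threshold).float().mean().item()
print(f"threshold = {float(threshold):.4f}, neurons kept = {kept:.1%}")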

src/modeling_skip.py

Lines changed: 2 additions & 2 deletions

@@ -20,7 +20,7 @@
 from transformers.utils.import_utils import is_torch_flex_attn_available

 from sparse_transformers import WeightCache, sparse_mlp_forward
-from src.activation_capture import ActivationCaptureDefault
+from src.activation_capture import ActivationCapture

 if is_torch_flex_attn_available():
     from torch.nn.attention.flex_attention import BlockMask
@@ -352,7 +352,7 @@ def forward(


 def build_skip_connection_model_for_causal_lm(pretrained_model_class: type[PreTrainedModel], base_model_class: type[PreTrainedModel]):
-    ACTIVATION_CAPTURE = ActivationCaptureDefault
+    ACTIVATION_CAPTURE = ActivationCapture

     class SkipConnectionModelForCausalLM(pretrained_model_class, GenerationMixin):
         _tied_weights_keys = ["lm_head.weight"]
