Commit ca57c3b

Kira Selby authored
Activation Capture refactor (#46)
* Refactor activation capture and fix issues with generate_dataset.
* Add simple script to plot sparsities
* Activation capture code for phi3
* Delete measure contextual sparsity and replace with updated version.

Signed-off-by: Kira Selby <kaselby@uwaterloo.ca>
1 parent 3d2ac8d commit ca57c3b
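
As the diffs below show, the refactor moves activation capture behind a per-model registry: each src.models.<arch> package registers an ActivationCapture subclass keyed by the Hugging Face model_type, and scripts look the class up instead of instantiating ActivationCapture directly. A minimal sketch of the new call pattern, assuming the relevant src.models package has been imported so the registry is populated (the checkpoint name and layer index are illustrative only, not part of this commit):

    from transformers import AutoModelForCausalLM
    import src.models.llama  # noqa: F401 -- importing the package registers its capture class (see diffs below)
    from src.activation_capture import ACTIVATION_CAPTURE

    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")  # illustrative checkpoint

    capture_cls = ACTIVATION_CAPTURE[model.config.model_type]  # e.g. ActivationCaptureDefault for 'llama'
    capture = capture_cls()
    capture.register_hooks(model)

    # ... run a forward pass, then read back per-layer activations ...
    num_layers = len(capture.get_layers(model))
    gate_acts = capture.get_gate_activations(0)  # silu(gate_proj(x)) for layer 0, or None if nothing was captured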

File tree: 8 files changed (+298, -506 lines)


generate_dataset.py

Lines changed: 9 additions & 8 deletions
@@ -49,10 +49,10 @@
 from datasets import load_dataset
 from torch.utils.data import DataLoader as TorchDataLoader
 from tqdm import tqdm
-from src.activation_capture import ActivationCapture
+from src.activation_capture import ACTIVATION_CAPTURE, ActivationCapture
 import csv
 import glob
-from src.predictor_trainer import get_sample_by_index
+from src.trainer import get_sample_by_index
 # Setup logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -149,7 +149,7 @@ def process_batch(
             hidden_states_dict[layer_idx].append(hidden_state)
 
             # Get last token's MLP activations
-            mlp_activation = capture.get_mlp_activations(layer_idx)
+            mlp_activation = capture.get_gate_activations(layer_idx)
             if mlp_activation is not None:
                 mlp_act = mlp_activation[batch_idx,-1,:].cpu().numpy().astype(np.float32)
                 mlp_activations_dict[layer_idx].append(mlp_act)
@@ -196,14 +196,15 @@ def generate_dataset(
 
     model.eval()
 
+    # Setup activation capture
+    capture_cls = ACTIVATION_CAPTURE[model.config.model_type]
+    capture = capture_cls()
+    capture.register_hooks(model)
+
     # Get model dimensions
     hidden_dim = model.config.hidden_size
     intermediate_dim = model.config.intermediate_size
-    num_layers = len(model.model.layers)
-
-    # Setup activation capture
-    capture = ActivationCapture()
-    capture.register_hooks(model)
+    num_layers = len(capture.get_layers(model))
 
     # Load dataset
     logger.info(f"Loading dataset: {dataset_name}")

measure_contextual_sparsity.py

Lines changed: 120 additions & 460 deletions
Large diffs are not rendered by default.

src/activation_capture.py

Lines changed: 98 additions & 38 deletions
@@ -1,60 +1,67 @@
 import torch.nn.functional as F
+from abc import ABC, abstractmethod
 
 
-class ActivationCapture:
+class ActivationCapture(ABC):
     """Helper class to capture activations from model layers."""
+    has_gate_proj: bool
+    has_up_proj: bool
 
     def __init__(self):
        self.hidden_states = {}
        self.mlp_activations = {}
        self.handles = []
-
+
+    @abstractmethod
+    def _register_gate_hook(self, layer_idx, layer):
+        pass
+
+    @abstractmethod
+    def _register_up_hook(self, layer_idx, layer):
+        pass
+
+    @abstractmethod
+    def get_layers(self, model):
+        pass
+
+    def _register_hidden_state_hook(self, layer_idx, layer):
+        def hook(module, args, kwargs, output):
+            # args[0] is the input hidden states to the layer
+            if len(args) > 0:
+                # Just detach, don't clone or move to CPU yet
+                self.hidden_states[layer_idx] = args[0].detach()
+            return output
+        handle = layer.register_forward_hook(
+            hook,
+            with_kwargs=True
+        )
+        return handle
+
     def register_hooks(self, model):
         """Register forward hooks to capture activations."""
         # Clear any existing hooks
         self.remove_hooks()
 
         # Hook into each transformer layer
-        for i, layer in enumerate(model.model.layers):
+        for i, layer in enumerate(self.get_layers(model)):
 
             # Capture hidden states before MLP
-            handle = layer.register_forward_hook(
-                self._create_hidden_state_hook(i),
-                with_kwargs=True
-            )
-            self.handles.append(handle)
+            handle = self._register_hidden_state_hook(i, layer)
+            if handle is not None:
+                self.handles.append(handle)
 
             # Capture MLP gate activations (after activation function)
-            if hasattr(layer.mlp, 'gate_proj'):
-                handle = layer.mlp.gate_proj.register_forward_hook(
-                    self._create_mlp_hook(i, 'gate')
-                )
-                self.handles.append(handle)
+            if self.has_gate_proj:
+                handle = self._register_gate_hook(i, layer)
+                if handle is not None:
+                    self.handles.append(handle)
 
             # Also capture up_proj activations
-            if hasattr(layer.mlp, 'up_proj'):
-                handle = layer.mlp.up_proj.register_forward_hook(
-                    self._create_mlp_hook(i, 'up')
-                )
-                self.handles.append(handle)
-
-    def _create_hidden_state_hook(self, layer_idx):
-        def hook(module, args, kwargs, output):
-            # args[0] is the input hidden states to the layer
-            if len(args) > 0:
-                # Just detach, don't clone or move to CPU yet
-                self.hidden_states[layer_idx] = args[0].detach()
-            return output
-        return hook
-
-    def _create_mlp_hook(self, layer_idx, proj_type):
-        def hook(module, input, output):
-            key = f"{layer_idx}_{proj_type}"
-            # Just detach, don't clone or move to CPU yet
-            self.mlp_activations[key] = output.detach()
-            return output
-        return hook
-
+            if self.has_up_proj:
+                handle = self._register_up_hook(i, layer)
+                if handle is not None:
+                    self.handles.append(handle)
+
     def remove_hooks(self):
         """Remove all registered hooks."""
         for handle in self.handles:
@@ -65,7 +72,46 @@ def clear_captures(self):
         """Clear captured activations."""
         self.hidden_states = {}
         self.mlp_activations = {}
-
+
+    @abstractmethod
+    def get_mlp_activations(self, layer_idx):
+        """Get combined MLP activations for a layer."""
+        pass
+
+    @abstractmethod
+    def get_gate_activations(self, layer_idx):
+        """Get combined MLP activations for a layer."""
+        return
+
+
+class ActivationCaptureDefault(ActivationCapture):
+    """Helper class to capture activations from model layers."""
+    has_gate_proj: bool = True
+    has_up_proj: bool = True
+
+    def get_layers(self, model):
+        return model.model.layers
+
+    def _create_mlp_hook(self, layer_idx, proj_type):
+        def hook(module, input, output):
+            key = f"{layer_idx}_{proj_type}"
+            # Just detach, don't clone or move to CPU yet
+            self.mlp_activations[key] = output.detach()
+            return output
+        return hook
+
+    def _register_gate_hook(self, layer_idx, layer):
+        handle = layer.mlp.gate_proj.register_forward_hook(
+            self._create_mlp_hook(layer_idx, 'gate')
+        )
+        return handle
+
+    def _register_up_hook(self, layer_idx, layer):
+        handle = layer.mlp.up_proj.register_forward_hook(
+            self._create_mlp_hook(layer_idx, 'up')
+        )
+        return handle
+
     def get_mlp_activations(self, layer_idx):
         """Get combined MLP activations for a layer."""
         gate_key = f"{layer_idx}_gate"
@@ -80,4 +126,18 @@ def get_mlp_activations(self, layer_idx):
             gated_act = F.silu(gate_act) * up_act
             return gated_act
 
-        return None
+        return None
+
+    def get_gate_activations(self, layer_idx):
+        """Get combined MLP activations for a layer."""
+        gate_key = f"{layer_idx}_gate"
+        if gate_key in self.mlp_activations:
+            gate_act = self.mlp_activations[gate_key]
+            return F.silu(gate_act)
+        return None
+
+
+ACTIVATION_CAPTURE = {}
+
+def register_activation_capture(model_name, activation_capture):
+    ACTIVATION_CAPTURE[model_name] = activation_capture
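
The registry at the bottom of this file is the extension point for new architectures: subclass ActivationCapture (or ActivationCaptureDefault when the MLP exposes gate_proj/up_proj) and register it under the model_type string, as the per-model __init__.py diffs below and the Phi-3 capture class do. A hypothetical sketch for an architecture that only stores its decoder layers under a different attribute (names here are assumptions, not part of this commit):

    from src.activation_capture import ActivationCaptureDefault, register_activation_capture

    class ActivationCaptureMyModel(ActivationCaptureDefault):
        def get_layers(self, model):
            # Only the layer container differs; the gate_proj/up_proj hooks are inherited unchanged.
            return model.transformer.blocks  # hypothetical attribute path for a hypothetical 'mymodel' type

    register_activation_capture('mymodel', ActivationCaptureMyModel)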

src/models/llama/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -7,4 +7,7 @@
 AutoConfig.register("llama-skip", LlamaSkipConnectionConfig)
 AutoModelForCausalLM.register(LlamaSkipConnectionConfig, LlamaSkipConnectionForCausalLM)
 
+from src.activation_capture import register_activation_capture, ActivationCaptureDefault
+register_activation_capture('llama', ActivationCaptureDefault)
+
 __all__ = [configuration_llama_skip, modelling_llama_skip]

src/models/mistral/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -7,4 +7,7 @@
 AutoConfig.register("mistral-skip", MistralSkipConnectionConfig)
 AutoModelForCausalLM.register(MistralSkipConnectionConfig, MistralSkipConnectionForCausalLM)
 
+from src.activation_capture import register_activation_capture, ActivationCaptureDefault
+register_activation_capture('mistral', ActivationCaptureDefault)
+
 __all__ = [configuration_mistral_skip, modelling_mistral_skip]

src/models/phi3/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -7,4 +7,8 @@
 AutoConfig.register("phi3-skip", Phi3SkipConnectionConfig)
 AutoModelForCausalLM.register(Phi3SkipConnectionConfig, Phi3SkipConnectionForCausalLM)
 
+from .activation_capture_phi import ActivationCapturePhi3
+from src.activation_capture import register_activation_capture
+register_activation_capture('phi3', ActivationCapturePhi3)
+
 __all__ = [configuration_phi_skip, modelling_phi_skip]
src/models/phi3/activation_capture_phi.py

Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
+from src.activation_capture import ActivationCapture
+import torch.nn.functional as F
+
+
+
+class ActivationCapturePhi3(ActivationCapture):
+    """Helper class to capture activations from model layers."""
+    has_gate_proj: bool = True
+    has_up_proj: bool = True
+
+    def get_layers(self, model):
+        return model.model.layers
+
+    def _register_gate_hook(self, layer_idx, layer):
+        def hook(module, input, output):
+            key1 = f"{layer_idx}_{'gate'}"
+            key2 = f"{layer_idx}_{'up'}"
+            # Just detach, don't clone or move to CPU yet
+            gate_outputs, up_outputs = output.chunk(2, dim=1)
+            self.mlp_activations[key1] = gate_outputs.detach()
+            self.mlp_activations[key2] = up_outputs.detach()
+            return output
+        handle = layer.mlp.gate_up_proj.register_forward_hook(hook)
+        return handle
+
+    def _register_up_hook(self, layer_idx, layer):
+        def hook(module, input, output):
+            key = f"{layer_idx}_{'up'}"
+            # Just detach, don't clone or move to CPU yet
+            up_outputs = output.chunk(2, dim=1)[1]
+            self.mlp_activations[key] = up_outputs.detach()
+            return output
+        handle = layer.mlp.gate_up_proj.register_forward_hook(hook)
+        return handle
+
+    def get_gate_activations(self, layer_idx):
+        """Get combined MLP activations for a layer."""
+        gate_key = f"{layer_idx}_gate"
+        if gate_key in self.mlp_activations:
+            gate_act = self.mlp_activations[gate_key]
+            return F.silu(gate_act)
+        return None
+
+    def get_mlp_activations(self, layer_idx):
+        """Get combined MLP activations for a layer."""
+        gate_key = f"{layer_idx}_gate"
+        up_key = f"{layer_idx}_up"
+
+        if gate_key in self.mlp_activations and up_key in self.mlp_activations:
+            # Compute gated activations: gate(x) * up(x)
+            gate_act = self.mlp_activations[gate_key]
+            up_act = self.mlp_activations[up_key]
+
+            # Apply SwiGLU activation: silu(gate) * up
+            gated_act = F.silu(gate_act) * up_act
+            return gated_act
+
+        return None
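
Unlike the Llama-style MLP, Phi-3 fuses the gate and up projections into a single gate_up_proj, so the hooks above capture one fused tensor and split it in half before the usual SwiGLU combination (silu(gate) * up). A small self-contained sketch of that split on a dummy linear layer (sizes and tensor names are illustrative, not taken from the model):

    import torch
    import torch.nn.functional as F

    hidden_dim, intermediate_dim = 8, 16                    # illustrative sizes
    gate_up_proj = torch.nn.Linear(hidden_dim, 2 * intermediate_dim, bias=False)

    x = torch.randn(4, hidden_dim)                          # 4 tokens
    fused = gate_up_proj(x)                                 # shape (4, 2 * intermediate_dim)
    gate_out, up_out = fused.chunk(2, dim=-1)               # split the fused projection into gate / up halves
    gated = F.silu(gate_out) * up_out                       # SwiGLU, as in get_mlp_activations above
    print(gated.shape)                                      # torch.Size([4, 16])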

src/models/qwen2/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -7,4 +7,7 @@
 AutoConfig.register("qwen2-skip", Qwen2SkipConnectionConfig)
 AutoModelForCausalLM.register(Qwen2SkipConnectionConfig, Qwen2SkipConnectionForCausalLM)
 
+from src.activation_capture import register_activation_capture, ActivationCaptureDefault
+register_activation_capture('qwen2', ActivationCaptureDefault)
+
 __all__ = [configuration_qwen_skip, modelling_qwen_skip]
