1- from typing_extensions import override
2- import torch.nn.functional as F
3- from abc import ABC, abstractmethod
4 1
2+ from enum import Enum
3+ from typing import List
5 4
6- class ActivationCapture(ABC):
5+ class Hook(Enum):
6+     IN = "IN"
7+     ACT = "ACT"
8+     UP = "UP"
9+     OUT = "OUT"
10+
11+
12+ class ActivationCapture():
7 13     """Helper class to capture activations from model layers."""
8-     has_gate_proj: bool
9-     has_up_proj: bool
14+     hooks_available: List[Hook] = [Hook.IN, Hook.ACT, Hook.UP, Hook.OUT]  # default: every hook point implemented below
10 15
11 16     def __init__(self, model):
12 17         self.model = model
13-         self.mlp_activations = {}
18+         self.mlp_activations = {
19+             hook: {} for hook in self.hooks_available
20+         }
14 21         self.handles = []
15 22
16-     @abstractmethod
17-     def _register_gate_hook(self, layer_idx, layer):
18-         pass
23+     def _register_in_hook(self, layer_idx, layer):
24+         def hook(module, input, output):
25+             # Clone and detach now; any move to CPU is deferred
26+             self.mlp_activations[Hook.IN][layer_idx] = input[0].clone().detach()
27+             return output
28+         handle = layer.mlp.register_forward_hook(hook)
29+         return handle
30+
31+     def _register_act_hook(self, layer_idx, layer):
32+         def hook(module, input, output):
33+             # Clone and detach now; any move to CPU is deferred
34+             self.mlp_activations[Hook.ACT][layer_idx] = input[0].clone().detach()
35+             return output
36+         handle = layer.mlp.act_fn.register_forward_hook(hook)
37+         return handle
19 38
20-     @abstractmethod
21 39     def _register_up_hook(self, layer_idx, layer):
22-         pass
40+         def hook(module, input, output):
41+             # Clone and detach now; any move to CPU is deferred
42+             self.mlp_activations[Hook.UP][layer_idx] = input[0].clone().detach()
43+             return output
44+         handle = layer.mlp.down_proj.register_forward_hook(hook)
45+         return handle
46+
47+     def _register_out_hook(self, layer_idx, layer):
48+         def hook(module, input, output):
49+             # Clone and detach now; any move to CPU is deferred
50+             self.mlp_activations[Hook.OUT][layer_idx] = output.clone().detach()
51+             return output
52+         handle = layer.mlp.register_forward_hook(hook)
53+         return handle
23 54
24-     @abstractmethod
25 55     def get_layers(self):
26-         pass
27-
28-
29-     @abstractmethod
30-     def get_gate_activations(self, layer_idx):
31-         """Get combined MLP activations for a layer."""
32-         pass
56+         return self.model.get_decoder().layers
33 57
34-     def register_hooks(self):
58+     def register_hooks(self, hooks=(Hook.ACT, Hook.UP, Hook.OUT)):
35 59         """Register forward hooks to capture activations."""
36 60         # Clear any existing hooks
37 61         self.remove_hooks()
38 62
39 63         # Hook into each transformer layer
40-         for i, layer in enumerate(self.get_layers()):
41-             # Capture MLP gate activations (after activation function)
42-             if self.has_gate_proj:
43-                 handle = self._register_gate_hook(i, layer)
64+         for i, layer in enumerate(self.get_layers()):
65+             # Hooks capturing inputs to the MLP layer
66+             if Hook.IN in hooks and Hook.IN in self.hooks_available:
67+                 handle = self._register_in_hook(i, layer)
44 68                 if handle is not None:
45 69                     self.handles.append(handle)
46-
47-             # Also capture up_proj activations
48-             if self.has_up_proj:
70+
71+             # Hooks capturing inputs to the activation function
72+             if Hook.ACT in hooks and Hook.ACT in self.hooks_available:
73+                 handle = self._register_act_hook(i, layer)
74+                 if handle is not None:
75+                     self.handles.append(handle)
76+
77+             # Hooks capturing inputs to the down projection
78+             if Hook.UP in hooks and Hook.UP in self.hooks_available:
49 79                 handle = self._register_up_hook(i, layer)
50 80                 if handle is not None:
51 81                     self.handles.append(handle)
82+
83+             # Hooks capturing the final MLP output
84+             if Hook.OUT in hooks and Hook.OUT in self.hooks_available:
85+                 handle = self._register_out_hook(i, layer)
86+                 if handle is not None:
87+                     self.handles.append(handle)
88+
52 89
53 90     def remove_hooks(self):
54 91         """Remove all registered hooks."""
@@ -59,91 +96,4 @@ def remove_hooks(self):
59 96     def clear_captures(self):
60 97         """Clear captured activations."""
61-         self.mlp_activations = {}
98+         # Reset to the same per-hook structure that __init__ builds
99+         self.mlp_activations = {hook: {} for hook in self.hooks_available}
62-
63-
64-
65- class ActivationCaptureDefault(ActivationCapture):
66-     """Helper class to capture activations from model layers."""
67-     has_gate_proj: bool = True
68-     has_up_proj: bool = True
69-
70-     def get_layers(self):
71-         return self.model.get_decoder().layers
72-
73-     def _create_mlp_hook(self, layer_idx, proj_type):
74-         def hook(module, input, output):
75-             key = f"{layer_idx}_{proj_type}"
76-             # Just detach, don't clone or move to CPU yet
77-             self.mlp_activations[key] = output.clone().detach()
78-             return output
79-         return hook
80-
81-     def _register_gate_hook(self, layer_idx, layer):
82-         handle = layer.mlp.gate_proj.register_forward_hook(
83-             self._create_mlp_hook(layer_idx, 'gate')
84-         )
85-         return handle
86-
87-     def _register_up_hook(self, layer_idx, layer):
88-         handle = layer.mlp.up_proj.register_forward_hook(
89-             self._create_mlp_hook(layer_idx, 'up')
90-         )
91-         return handle
92-
93-     def get_gate_activations(self, layer_idx):
94-         gate_key = f"{layer_idx}_gate"
95-         if gate_key in self.mlp_activations:
96-             gate_act = self.mlp_activations[gate_key]
97-             return F.silu(gate_act)
98-         return None
99-
100-     def get_up_activations(self, layer_idx):
101-         up_key = f"{layer_idx}_up"
102-         if up_key in self.mlp_activations:
103-             up_act = self.mlp_activations[up_key]
104-             return up_act
105-         return None
106-
107- class ActivationCaptureTraining(ActivationCaptureDefault):
108-     """Additional hidden state capture for training dataset generation"""
109-     def __init__(self, model):
110-         super().__init__(model)
111-         self.hidden_states = {}
112-
113-     def _create_hidden_state_hook(self, layer_idx, layer):
114-         def hook(module, args, kwargs, output):
115-             # args[0] is the input hidden states to the layer
116-             if len(args) > 0:
117-                 # Just detach, don't clone or move to CPU yet
118-                 self.hidden_states[layer_idx] = args[0].clone().detach()
119-             return output
120-         return hook
121-
122-     def _register_hidden_state_hook(self, layer_idx, layer):
123-         handle = layer.register_forward_hook(
124-             self._create_hidden_state_hook(layer_idx, layer),
125-             with_kwargs=True
126-         )
127-         return handle
128-
129-     @override
130-     def clear_captures(self):
131-         """Clear captured activations."""
132-         super().clear_captures()
133-         self.hidden_states = {}
134-
135-     @override
136-     def register_hooks(self):
137-         """Register forward hooks to capture activations."""
138-         # Clear any existing hooks
139-         super().register_hooks()
140-         # Hook into each transformer layer
141-         for i, layer in enumerate(self.get_layers()):
142-             # Capture hidden states before MLP
143-             handle = self._register_hidden_state_hook(i, layer)
144-             if handle is not None:
145-                 self.handles.append(handle)
146-
147-     def get_hidden_states(self, layer_idx):
148-         """Get hidden states for a layer."""
149-         return self.hidden_states[layer_idx]
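
For readers mapping the `Hook` names onto a gated MLP, the sketch below shows where each capture point sits. It assumes the Llama-style module layout (`gate_proj`, `up_proj`, `down_proj`, `act_fn`) that the hooks above attach to; it is an illustration, not code from this commit.

```python
import torch.nn as nn

class GatedMLP(nn.Module):
    """Reference gated MLP mirroring the assumed Llama-style layout."""
    def __init__(self, hidden_size, intermediate_size):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        # Hook.IN : x, captured as input[0] of the forward hook on `mlp`
        gate = self.gate_proj(x)            # Hook.ACT: input[0] of `act_fn`
        inter = self.act_fn(gate) * self.up_proj(x)
        # Hook.UP : inter, captured as input[0] of `down_proj`
        out = self.down_proj(inter)         # Hook.OUT: output of `mlp`
        return out
```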
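A minimal usage sketch of the new API, assuming a Transformers causal LM whose `get_decoder().layers[i].mlp` matches that layout; the checkpoint name and prompt are illustrative.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "meta-llama/Llama-3.2-1B"  # illustrative checkpoint
model = AutoModelForCausalLM.from_pretrained(name)
tokenizer = AutoTokenizer.from_pretrained(name)

capture = ActivationCapture(model)
capture.register_hooks(hooks=(Hook.ACT, Hook.UP))

inputs = tokenizer("The quick brown fox", return_tensors="pt")
with torch.no_grad():
    model(**inputs)

# Captures are keyed first by Hook, then by layer index.
act_in = capture.mlp_activations[Hook.ACT][0]  # input to act_fn, layer 0
up_in = capture.mlp_activations[Hook.UP][0]    # input to down_proj, layer 0

capture.remove_hooks()
capture.clear_captures()
```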