44"""
55from .model_ema import ModelEma
66import torch
7-
7+ import fnmatch
88
99def unwrap_model (model ):
1010 if isinstance (model , ModelEma ):
@@ -23,58 +23,62 @@ def avg_sq_ch_mean(model, input, output):
 
 
 def avg_ch_var(model, input, output):
+    "calculate average channel variance of output activations"
+    return torch.mean(output.var(axis=[0, 2, 3])).item()
+
+
+def avg_ch_var_residual(model, input, output):
     "calculate average channel variance of output activations"
     return torch.mean(output.var(axis=[0, 2, 3])).item()
 
 
 class ActivationStatsHook:
31- """Iterates through each of `model`'s modules and if module's class name
32- is present in `layer_names` then registers `hook_fns` inside that module
33- and stores activation stats inside `self.stats` .
36+ """Iterates through each of `model`'s modules and matches modules using unix pattern
37+ matching based on `layer_name` and `layer_type`. If there is match, this class adds
38+ creates a hook using `hook_fn` and adds it to the module .
 
     Arguments:
         model (nn.Module): model from which we will extract the activation stats
-        layer_names (List[str]): The layer name to look for to register forward
-            hook. Example, `BasicBlock`, `Bottleneck`
+        hook_fn_locs (List[str]): List of unix-style patterns of module names at
+            which to register forward hooks. Example, 'stem', 'stages'
         hook_fns (List[Callable]): List of hook functions to be registered at every
             module matched by `hook_fn_locs`.
 
     Inspiration from https://docs.fast.ai/callback.hook.html.
     """
 
-    def __init__(self, model, layer_names, hook_fns=[avg_sq_ch_mean, avg_ch_var]):
+    def __init__(self, model, hook_fn_locs, hook_fns):
         self.model = model
-        self.layer_names = layer_names
+        self.hook_fn_locs = hook_fn_locs
         self.hook_fns = hook_fns
         self.stats = dict((hook_fn.__name__, []) for hook_fn in hook_fns)
-        for hook_fn in hook_fns:
-            self.register_hook(layer_names, hook_fn)
+        for hook_fn_loc, hook_fn in zip(hook_fn_locs, hook_fns):
+            self.register_hook(hook_fn_loc, hook_fn)
 
     def _create_hook(self, hook_fn):
         def append_activation_stats(module, input, output):
             out = hook_fn(module, input, output)
             self.stats[hook_fn.__name__].append(out)
         return append_activation_stats
 
-    def register_hook(self, layer_names, hook_fn):
-        for layer in self.model.modules():
-            layer_name = layer.__class__.__name__
-            if layer_name not in layer_names:
+    def register_hook(self, hook_fn_loc, hook_fn):
+        for name, module in self.model.named_modules():
+            if not fnmatch.fnmatch(name, hook_fn_loc):
                 continue
-            layer.register_forward_hook(self._create_hook(hook_fn))
+            module.register_forward_hook(self._create_hook(hook_fn))
 
 
 def extract_spp_stats(model,
-                      layer_names,
-                      hook_fns=[avg_sq_ch_mean, avg_ch_var],
+                      hook_fn_locs,
+                      hook_fns,
                       input_shape=[8, 3, 224, 224]):
     """Extract average square channel mean and variance of activations during
     forward pass to plot Signal Propagation Plots (SPP).
 
     Paper: https://arxiv.org/abs/2101.08692
     """
     x = torch.normal(0., 1., input_shape)
-    hook = ActivationStatsHook(model, layer_names, hook_fns)
+    hook = ActivationStatsHook(model, hook_fn_locs=hook_fn_locs, hook_fns=hook_fns)
     _ = model(x)
     return hook.stats
 
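For reference, a minimal usage sketch of the new pattern-based API. This is not part of the diff: the toy `stem`/`stages` module names and the `timm.utils.model` import path are illustrative assumptions.

from collections import OrderedDict

import torch.nn as nn
from timm.utils.model import (  # assumed import path for the module above
    extract_spp_stats, avg_sq_ch_mean, avg_ch_var)

# Toy network; submodule names ('stem', 'stages.0', 'stages.1') are what the
# unix-style patterns in `hook_fn_locs` are matched against via fnmatch.
model = nn.Sequential(OrderedDict([
    ('stem', nn.Conv2d(3, 16, 3, stride=2, padding=1)),
    ('stages', nn.Sequential(
        nn.Conv2d(16, 32, 3, stride=2, padding=1),
        nn.Conv2d(32, 64, 3, stride=2, padding=1),
    )),
]))

# Patterns and hook fns are paired by zip order: avg_sq_ch_mean goes on 'stem',
# avg_ch_var on every module matching 'stages.*' ('stages.0' and 'stages.1').
stats = extract_spp_stats(
    model,
    hook_fn_locs=['stem', 'stages.*'],
    hook_fns=[avg_sq_ch_mean, avg_ch_var],
    input_shape=[8, 3, 224, 224],
)
print(stats['avg_sq_ch_mean'])  # 1 value, from the single 'stem' hook
print(stats['avg_ch_var'])      # 2 values, one per matched 'stages.*' module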
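Since the docstring frames this as producing Signal Propagation Plots, here is one way the collected stats might be plotted, continuing from the sketch above; the matplotlib code is illustrative, not part of this change.

import matplotlib.pyplot as plt

# Each key in `stats` is a hook_fn name; each value is the list of scalars the
# hooks appended, in forward execution order of the matched modules.
fig, axes = plt.subplots(1, len(stats), figsize=(10, 4))
for ax, (name, values) in zip(axes, stats.items()):
    ax.plot(values, marker='o')
    ax.set_title(name)
    ax.set_xlabel('module index')
plt.show()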