
Commit d2252aa

Add Gemma-3n (#50)
* Initial commit for Gemma-3n.
* Fixes to Gemma implementation.
* Add Gemma to __init__.
* Updated requirements for Gemma.
* Minor changes to benchmark and evaluate.
* Default for layer_idx set to None for set_mlp_train and set_mlp_inference in modeling_skip.
* Updated activation capture to work with refactor.
* Activation capture fix.

Signed-off-by: Kira Selby <kaselby@uwaterloo.ca>
1 parent 491a7cc commit d2252aa

15 files changed: +682 -32 lines

benchmark.py

Lines changed: 5 additions & 0 deletions
@@ -23,6 +23,8 @@ def parse_args() -> argparse.Namespace:
                         help='Verbose output')
     parser.add_argument('--config', type=str, default='configs/llama_skip_causal_3b.json',
                         help='Config file')
+    parser.add_argument('--max_response_length', type=int, default=-1,
+                        help='Maximum response tokens per prompt.')
     return parser.parse_args()
 
 
@@ -400,6 +402,9 @@ def main():
 
     # Get test prompts
     test_prompts = get_diverse_test_prompts()
+    if args.max_response_length > 0:
+        for prompt in test_prompts:
+            prompt['max_tokens'] = min(prompt['max_tokens'], args.max_response_length)
 
     print(f"\n🎯 Running comprehensive benchmark with {len(test_prompts)} diverse prompts...")
     print(f"📝 Test prompts: {[p['description'] for p in test_prompts]}")
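A minimal sketch of what the new --max_response_length flag does: it clamps each prompt's generation budget. The prompt dicts below are hypothetical; only the 'max_tokens' field mirrors the diff above.

# Hypothetical prompt dicts; only 'max_tokens' matters here.
test_prompts = [
    {"description": "short factual question", "max_tokens": 128},
    {"description": "long-form story", "max_tokens": 2048},
]

max_response_length = 256  # e.g. python benchmark.py --max_response_length 256
if max_response_length > 0:
    for prompt in test_prompts:
        prompt["max_tokens"] = min(prompt["max_tokens"], max_response_length)

print([p["max_tokens"] for p in test_prompts])  # [128, 256]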
Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
{
  "_name_or_path": "google/gemma-3n-E2B",
  "sparsity": 0.3,
  "architectures": ["Gemma3nSkipConnectionForCausalLM"],
  "activation_sparsity_pattern": [0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95, 0.95,
                                  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
                                  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
  "altup_active_idx": 0,
  "altup_coef_clip": 120.0,
  "altup_correct_scale": true,
  "altup_lr_multiplier": 1.0,
  "altup_num_inputs": 4,
  "attention_bias": false,
  "attention_dropout": 0.0,
  "final_logit_softcapping": 30.0,
  "head_dim": 256,
  "hidden_activation": "gelu_pytorch_tanh",
  "hidden_size": 2048,
  "hidden_size_per_layer_input": 256,
  "initializer_range": 0.02,
  "intermediate_size": 8192,
  "laurel_rank": 64,
  "layer_types": ["sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "full_attention",
                  "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "full_attention",
                  "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "full_attention",
                  "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "full_attention",
                  "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "full_attention",
                  "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "full_attention"],
  "max_position_embeddings": 32768,
  "model_type": "gemma3n-skip",
  "num_attention_heads": 8,
  "num_hidden_layers": 30,
  "num_key_value_heads": 2,
  "num_kv_shared_layers": 10,
  "query_pre_attn_scalar": 256,
  "rms_norm_eps": 1e-06,
  "rope_local_base_freq": 10000.0,
  "rope_scaling": null,
  "rope_theta": 1000000.0,
  "sliding_window": 512,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.53.0.dev0",
  "use_cache": true,
  "vocab_size": 262400,
  "vocab_size_per_layer_input": 262144
}
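A quick sanity check for this config is to load it through the AutoConfig registration added in src/models/gemma3n/__init__.py below. A minimal sketch, assuming the JSON above is saved under a hypothetical path in configs/:

from transformers import AutoConfig
import src.models.gemma3n  # noqa: F401 -- registers "gemma3n-skip" with AutoConfig

# Hypothetical file path for the JSON above.
cfg = AutoConfig.from_pretrained("configs/gemma3n_skip_causal_e2b.json")
assert cfg.model_type == "gemma3n-skip"
# One sparsity value and one attention type per decoder layer.
assert len(cfg.activation_sparsity_pattern) == cfg.num_hidden_layers == 30
assert len(cfg.layer_types) == cfg.num_hidden_layers
print(cfg.sparsity, cfg.num_kv_shared_layers)  # 0.3 10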

evaluate.py

Lines changed: 2 additions & 1 deletion
@@ -60,8 +60,9 @@ def main():
 
     wrapped_model = HFLM(
         pretrained=model,
+        backend="causal",
         batch_size=args.batch_size,
-        device=device
+        device=device,
     )
 
     logging.info("Beginning evaluation...")
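For context, HFLM is lm-evaluation-harness's wrapper around a Hugging Face model; pinning backend="causal" avoids architecture auto-detection for the custom skip classes. A rough standalone sketch of the same wrapping, with a tiny public model standing in for the preloaded skip model and an arbitrary task/limit choice:

import lm_eval
from lm_eval.models.huggingface import HFLM
from transformers import AutoModelForCausalLM

# Stand-in model; evaluate.py wraps the already-loaded skip-connection model instead.
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
wrapped_model = HFLM(
    pretrained=model,
    backend="causal",   # skip auto-detection, as in the diff above
    batch_size=8,
    device="cpu",
)
results = lm_eval.simple_evaluate(model=wrapped_model, tasks=["lambada_openai"], limit=10)
print(results["results"])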

requirements.txt

Lines changed: 4 additions & 2 deletions
@@ -1,6 +1,6 @@
 # Core ML/AI packages
 # conda install pytorch==2.5.0 torchvision==0.20.0 torchaudio==2.5.0 pytorch-cuda=12.4 -c pytorch -c nvidia
-transformers==4.52.4
+transformers==4.53.0
 numpy
 psutil
 optimum
@@ -11,4 +11,6 @@ datasets
 sentencepiece
 protobuf
 wandb
-ninja
+ninja
+timm
+pillow

src/activation_capture.py

Lines changed: 1 addition & 1 deletion
@@ -68,7 +68,7 @@ class ActivationCaptureDefault(ActivationCapture):
     has_up_proj: bool = True
 
     def get_layers(self):
-        return self.model.model.layers
+        return self.model.get_decoder().layers
 
     def _create_mlp_hook(self, layer_idx, proj_type):
        def hook(module, input, output):

src/modeling_skip.py

Lines changed: 19 additions & 19 deletions
@@ -52,14 +52,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.up(self.down(x))
 
 class SkipMLP(nn.Module):
-    def __init__(self, hidden_size: int, intermediate_size: int, sparsity: float, bias: bool = False):
+    def __init__(self, hidden_size: int, intermediate_size: int, sparsity: float, bias: bool = False, act_fn="silu"):
         super().__init__()
         self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
         self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
         self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=bias)
         self.sparsity = sparsity
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size
+        self.act_fn = act_fn
 
         # Initialize mask but defer WeightCache creation until post_init
         self.init_mask = torch.ones(intermediate_size, dtype=torch.bool)
@@ -101,7 +102,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             self.weight_cache.get_active_down_weight(),  # type: ignore
             self.down_proj_buffer,
             self.combined_proj_buffer,
-            "silu"
+            self.act_fn
         )
         return out
 
@@ -110,16 +111,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 class SkipDecoderLayer(ABC, GradientCheckpointingLayer):
     def __init__(self, config: PretrainedConfig, layer_idx: int):
         super().__init__()
+        self.config = config
         self.hidden_size = config.hidden_size
         self.layer_idx = layer_idx
         self.sparsity = config.sparsity
 
         self._init_components(config, layer_idx)
 
-        self.lora_size = int(config.intermediate_size * 0.04)
+        intermediate_size = config.intermediate_size[layer_idx] if isinstance(config.intermediate_size, list) \
+            else config.intermediate_size
+        self.lora_size = int(intermediate_size * 0.04)
         self.mlp_lora_proj = FastLoRAProjection(
             config.hidden_size,
-            config.intermediate_size,
+            intermediate_size,
             self.lora_size
         )
 
@@ -128,20 +132,20 @@ def __init__(self, config: PretrainedConfig, layer_idx: int):
         # Only initialize predictor training components if explicitly enabled
         if self.is_training_config:
             # Standard MLP for ground truth collection during training
-            self._set_mlp_train(config)
+            self._set_mlp_train(config, layer_idx)
         else:
-            self._set_mlp_inference(config)
+            self._set_mlp_inference(config, layer_idx)
 
     @abstractmethod
     def _init_components(self, config, layer_idx):
         pass
 
     @abstractmethod
-    def _set_mlp_train(self, config):
+    def _set_mlp_train(self, config, layer_idx=None):
         pass
 
     @abstractmethod
-    def _set_mlp_inference(self, config):
+    def _set_mlp_inference(self, config, layer_idx=None):
         pass
 
     @property
@@ -199,6 +203,12 @@ def forward(
         return outputs
 
 
+'''
+Note:
+    Now that the intermediate losses have been removed, almost all the actual changes are confined to SkipDecoderLayer and SkipMLP.
+    SkipConnectionModel/SkipConnectionForCausalLM may not even be necessary. It's possible at some point in the future we might want
+    to attempt a refactor here to simply extend from e.g. LlamaModel and just override the initialization.
+'''
 def build_skip_connection_model(pretrained_model_class: type[PreTrainedModel]) -> type[PreTrainedModel]:
     class SkipConnectionModel(ABC, pretrained_model_class):
         def __init__(self, config: PretrainedConfig):
@@ -336,17 +346,7 @@ def forward(
                 hidden_states=all_hidden_states,  # type: ignore
                 attentions=all_self_attns,
             )
-
-        @abstractmethod
-        def _update_causal_mask(
-            self,
-            attention_mask: Union[torch.Tensor, "BlockMask"],  # type: ignore
-            input_tensor: torch.Tensor,
-            cache_position: torch.Tensor,
-            past_key_values: Cache,
-            output_attentions: bool = False,
-        ):
-            pass
+
     return SkipConnectionModel
 
 
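The main wrinkle Gemma-3n adds in this file is that config.intermediate_size may be a per-layer list rather than a single int, so the LoRA predictor dimension has to be resolved per layer (the isinstance check in the diff above). A minimal sketch of that resolution logic in isolation, with illustrative values:

# Mirrors the per-layer intermediate_size resolution added above; values are illustrative.
def resolve_intermediate_size(intermediate_size, layer_idx):
    return intermediate_size[layer_idx] if isinstance(intermediate_size, list) else intermediate_size

scalar_cfg = 8192                     # e.g. the Gemma-3n E2B config in this commit
list_cfg = [8192] * 10 + [4096] * 20  # hypothetical per-layer variant

for layer_idx in (0, 15):
    size = resolve_intermediate_size(list_cfg, layer_idx)
    lora_size = int(size * 0.04)       # same 4% rule used for mlp_lora_proj
    print(layer_idx, size, lora_size)  # 0 8192 327, then 15 4096 163

print(resolve_intermediate_size(scalar_cfg, 0))  # 8192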

src/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -2,4 +2,5 @@
 from . import qwen2
 from . import mistral
 from . import phi3
+from . import gemma3n
 # from . import dia

src/models/gemma3n/__init__.py

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+from . import configuration_gemma_skip
+from . import modelling_gemma_skip
+
+from transformers import AutoConfig, AutoModelForCausalLM
+from .configuration_gemma_skip import Gemma3nSkipConnectionConfig
+from .modelling_gemma_skip import Gemma3nSkipConnectionForCausalLM
+AutoConfig.register("gemma3n-skip", Gemma3nSkipConnectionConfig)
+AutoModelForCausalLM.register(Gemma3nSkipConnectionConfig, Gemma3nSkipConnectionForCausalLM)
+
+__all__ = ["configuration_gemma_skip", "modelling_gemma_skip"]
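With these registrations in place, the skip variant can be constructed through the regular Auto classes. A hedged sketch (the config path is hypothetical, and from_config gives a randomly initialized model):

from transformers import AutoConfig, AutoModelForCausalLM
import src.models.gemma3n  # noqa: F401 -- performs the registrations shown above

cfg = AutoConfig.from_pretrained("configs/gemma3n_skip_causal_e2b.json")  # hypothetical path
model = AutoModelForCausalLM.from_config(cfg)  # dispatches to Gemma3nSkipConnectionForCausalLM
print(type(model).__name__)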
Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
+from src.activation_capture import ActivationCaptureDefault
+
+
+class ActivationCaptureGemma3n(ActivationCaptureDefault):
+    """Helper class to capture activations from model layers."""
+
+    def _register_gate_hook(self, layer_idx, layer):
+        handle = layer.mlp.act_fn.register_forward_hook(
+            self._create_mlp_hook(layer_idx, 'gate')
+        )
+        return handle
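The 'gate' capture for Gemma-3n hooks the output of layer.mlp.act_fn, i.e. the activation applied to the gate projection, rather than gate_proj itself. A standalone sketch of the same forward-hook pattern on a toy module (the class and field names below are illustrative, not the project's):

import torch
from torch import nn

class ToyMLP(nn.Module):
    # Toy stand-in for layer.mlp; only gate_proj -> act_fn matters for the hook.
    def __init__(self, d=8, h=16):
        super().__init__()
        self.gate_proj = nn.Linear(d, h)
        self.act_fn = nn.GELU()

    def forward(self, x):
        return self.act_fn(self.gate_proj(x))

captured = {}

def make_hook(layer_idx, proj_type):
    def hook(module, inputs, output):
        captured[(layer_idx, proj_type)] = output.detach()
    return hook

mlp = ToyMLP()
handle = mlp.act_fn.register_forward_hook(make_hook(0, "gate"))
mlp(torch.randn(2, 8))
assert captured[(0, "gate")].shape == (2, 16)
handle.remove()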
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+from transformers import Gemma3nTextConfig
+from src.configuration_skip import build_skip_config
+
+Gemma3nSkipConnectionConfig = build_skip_config(Gemma3nTextConfig, "gemma3n-skip")
