Commit d70d2b6

New models (#38)
* Move modeling_utils into modeling_skip. Signed-off-by: Kira Selby <kaselby@uwaterloo.ca>
* Add mistral and update requirements to add sentencepiece and protobuf. Signed-off-by: Kira Selby <kaselby@uwaterloo.ca>
* Add mistral and update requirements to add sentencepiece and protobuf. Signed-off-by: Kira Selby <kaselby@uwaterloo.ca>
* Add base code for phi3. Signed-off-by: Kira Selby <kaselby@uwaterloo.ca>
* Add config files for mistral and phi. Signed-off-by: Kira Selby <kaselby@uwaterloo.ca>
* Moved initialization for FastLoRAProjection from run_benchmark to from_pretrained and added functionality for initializing Phi3SkipMLP weights. Signed-off-by: Kira Selby <kaselby@uwaterloo.ca>
1 parent bf955c7 commit d70d2b6

15 files changed: +956 / -89 lines changed
Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+{
+  "_name_or_path": "mistralai/Mistral-7B-Instruct-v0.3",
+  "sparsity": 0.3,
+  "architectures": [
+    "MistralSkipConnectionForCausalLM"
+  ],
+  "attention_dropout": 0.0,
+  "bos_token_id": 1,
+  "eos_token_id": 2,
+  "hidden_act": "silu",
+  "hidden_size": 4096,
+  "initializer_range": 0.02,
+  "intermediate_size": 14336,
+  "max_position_embeddings": 32768,
+  "model_type": "mistral-skip",
+  "num_attention_heads": 32,
+  "num_hidden_layers": 32,
+  "num_key_value_heads": 8,
+  "rms_norm_eps": 1e-05,
+  "rope_theta": 1000000.0,
+  "sliding_window": null,
+  "tie_word_embeddings": false,
+  "torch_dtype": "bfloat16",
+  "transformers_version": "4.42.0.dev0",
+  "use_cache": true,
+  "vocab_size": 32768
+}

configs/phi3_skip_causal_3.8b.json

Lines changed: 144 additions & 0 deletions
@@ -0,0 +1,144 @@
+{
+  "_name_or_path": "microsoft/Phi-4-mini-instruct",
+  "sparsity": 0.3,
+  "architectures": [
+    "Phi3SkipConnectionForCausalLM"
+  ],
+  "attention_bias": false,
+  "attention_dropout": 0.0,
+  "auto_map": {
+    "AutoTokenizer": "Xenova/gpt-4o"
+  },
+  "bos_token_id": 199999,
+  "embd_pdrop": 0.0,
+  "eos_token_id": 199999,
+  "full_attn_mod": 1,
+  "hidden_act": "silu",
+  "hidden_size": 3072,
+  "initializer_range": 0.02,
+  "intermediate_size": 8192,
+  "interpolate_factor": 1,
+  "lm_head_bias": false,
+  "max_position_embeddings": 131072,
+  "mlp_bias": false,
+  "model_type": "phi3-skip",
+  "num_attention_heads": 24,
+  "num_hidden_layers": 32,
+  "num_key_value_heads": 8,
+  "original_max_position_embeddings": 4096,
+  "pad_token_id": 199999,
+  "partial_rotary_factor": 0.75,
+  "resid_pdrop": 0.0,
+  "rms_norm_eps": 1e-05,
+  "rope_scaling": {
+    "long_factor": [
+      1,
+      1.118320672,
+      1.250641126,
+      1.398617824,
+      1.564103225,
+      1.74916897,
+      1.956131817,
+      2.187582649,
+      2.446418898,
+      2.735880826,
+      3.059592084,
+      3.421605075,
+      3.826451687,
+      4.279200023,
+      4.785517845,
+      5.351743533,
+      5.984965424,
+      6.693110555,
+      7.485043894,
+      8.370679318,
+      9.36110372,
+      10.4687158,
+      11.70738129,
+      13.09260651,
+      14.64173252,
+      16.37415215,
+      18.31155283,
+      20.47818807,
+      22.90118105,
+      25.61086418,
+      28.64115884,
+      32.03,
+      32.1,
+      32.13,
+      32.23,
+      32.6,
+      32.61,
+      32.64,
+      32.66,
+      32.7,
+      32.71,
+      32.93,
+      32.97,
+      33.28,
+      33.49,
+      33.5,
+      44.16,
+      47.77
+    ],
+    "short_factor": [
+      1.0,
+      1.0,
+      1.0,
+      1.0,
+      1.0,
+      1.0,
+      1.0,
+      1.0,
+      1.0,
+      1.0,
+      1.0,
+      1.0,
+      1.0,
+      1.0,
+      1.0,
+      1.0,
+      1.0,
+      1.0,
+      1.0,
+      1.0,
+      1.0,
+      1.0,
+      1.0,
+      1.0,
+      1.0,
+      1.0,
+      1.0,
+      1.0,
+      1.0,
+      1.0,
+      1.0,
+      1.0,
+      1.0,
+      1.0,
+      1.0,
+      1.0,
+      1.0,
+      1.0,
+      1.0,
+      1.0,
+      1.0,
+      1.0,
+      1.0,
+      1.0,
+      1.0,
+      1.0,
+      1.0,
+      1.0
+    ],
+    "type": "longrope"
+  },
+  "rope_theta": 10000.0,
+  "sliding_window": 262144,
+  "tie_word_embeddings": true,
+  "torch_dtype": "bfloat16",
+  "transformers_version": "4.45.0",
+  "use_cache": true,
+  "vocab_size": 200064
+}
+
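For orientation, the two new configs above are meant to be loaded through the standard transformers Auto* entry points, as done in run_benchmark.py further down. A minimal sketch, assuming `import src.models` registers the custom `mistral-skip`/`phi3-skip` model types and that the matching base checkpoint is the one named in `_name_or_path`:

```python
import src.models  # adds the skip models (llama, qwen2, mistral, phi3) to the registry
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

# Assumed pairing: the new skip config plus the base checkpoint from "_name_or_path".
config = AutoConfig.from_pretrained("configs/phi3_skip_causal_3.8b.json")
checkpoint = "microsoft/Phi-4-mini-instruct"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# The from_pretrained override added in src/modeling_skip.py below re-initializes any
# FastLoRAProjection weights that were left on the meta device during loading.
skip_model = AutoModelForCausalLM.from_pretrained(checkpoint, config=config)
skip_model.tie_weights()
```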

requirements.txt

Lines changed: 2 additions & 0 deletions
@@ -8,4 +8,6 @@ attrs
 scikit-learn
 accelerate
 datasets
+sentencepiece
+protobuf
 #wandb
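The two new dependencies are most likely there for the sentencepiece-based tokenizers of the added models, though the commit does not say so. A quick sanity check, using the Mistral checkpoint from the config above as an assumed example:

```python
# Illustrative guess at why sentencepiece/protobuf were added: the slow (sentencepiece)
# Mistral tokenizer needs both to load and convert.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3", use_fast=False)
print(tok.tokenize("skip connections"))
```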

run_benchmark.py

Lines changed: 0 additions & 7 deletions
@@ -8,7 +8,6 @@
 
 import torch
 from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
-from src.modeling_utils import FastLoRAProjection
 from src.utilities.cuda_utils import GPUMonitor, setup_cuda_debugging
 from src.utilities.sys_utils import print_system_info
 import src.models # adds models to registry
@@ -407,12 +406,6 @@ def main():
 
     # Always run SkipLLaMA benchmark with HuggingFace
     skip_model = AutoModelForCausalLM.from_pretrained(checkpoint, config=config)
-    for module in skip_model.modules():
-        if any(hasattr(p, 'is_meta') and p.is_meta for p in module.parameters()) and isinstance(module, FastLoRAProjection):
-            module = module.to_empty(device="cpu")
-            with torch.no_grad():
-                torch.nn.init.xavier_normal_(module.down.weight)
-                torch.nn.init.zeros_(module.up.weight) # Initialize up projection to zeros for stable training
     skip_model.tie_weights()
 
     skip_name = "Skip-%s" % model_name

src/modeling_skip.py

Lines changed: 75 additions & 4 deletions
@@ -36,12 +36,74 @@
     approx_topk_threshold
 )
 
-from src.modeling_utils import (
-    FastLoRAProjection, BaseModelOutputWithPastAndPredictorLoss
-)
-
 logger = logging.get_logger(__name__)
 
+@dataclass
+class BaseModelOutputWithPastAndPredictorLoss(ModelOutput):
+    loss: Optional[torch.FloatTensor] = None
+    last_hidden_state: Optional[torch.FloatTensor] = None
+    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+class FastLoRAProjection(nn.Module):
+    def __init__(self, hidden_size, intermediate_size, lora_size):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.lora_size = lora_size
+        # Force creation of linear layers with actual tensors (not meta tensors)
+        self.down = nn.Linear(hidden_size, lora_size, bias=False)
+        self.up = nn.Linear(lora_size, intermediate_size, bias=False)
+        # Pre-allocate buffers on CPU initially
+        self.register_buffer('intermediate', torch.zeros(1, lora_size))
+        self.register_buffer('output', torch.zeros(1, intermediate_size))
+
+    def to(self, *args, **kwargs):
+        # Move buffers to same device as model when .to() is called
+        device = args[0] if args else kwargs.get('device')
+
+        if device:
+            self.intermediate = self.intermediate.to(device)
+            self.output = self.output.to(device)
+        return super().to(*args, **kwargs)
+
+    def _fix_unloaded_weights(self):
+        out = self.to_empty(device="cpu")
+        with torch.no_grad():
+            torch.nn.init.xavier_normal_(out.down.weight)
+            torch.nn.init.zeros_(out.up.weight) # Initialize up projection to zeros for stable training
+        return out
+
+    def _resize_buffers(self, batch_size: int, dtype: torch.dtype):
+        if self.intermediate.size(0) != batch_size:
+            self.intermediate.resize_(batch_size, self.lora_size)
+            self.intermediate = self.intermediate.to(dtype=dtype)
+            self.intermediate.fill_(0.0) # Explicitly initialize with zeros
+            self.output.resize_(batch_size, self.intermediate_size)
+            self.output = self.output.to(dtype=dtype)
+            self.output.fill_(0.0) # Explicitly initialize with zeros
+
+    def forward(self, x):
+        batch_size = x.size(0)
+
+        # Check if gradients are required (training mode)
+        if self.training:
+            # Use regular matrix multiplication for gradient computation
+            intermediate = torch.mm(x, self.down.weight.t())
+            output = torch.mm(intermediate, self.up.weight.t())
+            return output
+        else:
+            # # Use optimized in-place operations for inference
+            # intermediate = torch.mm(x, self.down.weight.t())
+            # output = torch.mm(intermediate, self.up.weight.t())
+            # return output
+
+            self._resize_buffers(batch_size, x.dtype)
+            torch.mm(x, self.down.weight.t(), out=self.intermediate)
+            torch.mm(self.intermediate, self.up.weight.t(), out=self.output)
+            return self.output
 
 class SkipMLP(nn.Module):
     def __init__(self, hidden_size: int, intermediate_size: int, sparsity: float, bias: bool = False):
@@ -415,6 +477,15 @@ def __init__(self, config):
 
         # Initialize weights and apply final processing
         self.post_init()
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        out = super(SkipConnectionModelForCausalLM, cls).from_pretrained(*args, **kwargs)
+        for module in out.modules():
+            if any(hasattr(p, 'is_meta') and p.is_meta for p in module.parameters()) and \
+                    hasattr(module, '_fix_unloaded_weights'):
+                module = module._fix_unloaded_weights()
+        return out
 
     def get_input_embeddings(self):
         return self.model.embed_tokens
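With FastLoRAProjection now defined in src/modeling_skip.py, it can be exercised on its own. A small sketch of the two forward paths shown in the hunk above; the sizes are toy values for illustration, not ones taken from the configs:

```python
import torch
from src.modeling_skip import FastLoRAProjection

# Toy dimensions; the real values come from the model config (e.g. hidden_size=4096,
# intermediate_size=14336 in the Mistral skip config above).
proj = FastLoRAProjection(hidden_size=64, intermediate_size=256, lora_size=8)

x = torch.randn(4, 64)  # (batch, hidden_size); forward expects a 2-D input

proj.train()
y = proj(x)             # plain matmuls so autograd can track gradients
print(y.shape)          # torch.Size([4, 256])

proj.eval()
with torch.no_grad():   # the out= matmuls below are not autograd-friendly
    y = proj(x)         # writes into the pre-allocated, batch-resized buffers
print(y.shape)          # torch.Size([4, 256])
```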

src/modeling_utils.py

Lines changed: 0 additions & 72 deletions
This file was deleted.

src/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,5 @@
 from . import llama
 from . import qwen2
+from . import mistral
+from . import phi3
 # from . import dia
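The registration side of `import src.models` is not part of this diff; a hypothetical sketch of what the new src/models/mistral package might do on import, assuming it mirrors the llama module and uses the standard AutoConfig/AutoModelForCausalLM registration API (the module paths and the MistralSkipConnectionConfig name are assumed):

```python
# Hypothetical registration sketch -- the actual contents of src/models/mistral are not shown here.
from transformers import AutoConfig, AutoModelForCausalLM

from src.models.mistral.configuration_mistral_skip import MistralSkipConnectionConfig   # assumed path/name
from src.models.mistral.modelling_mistral_skip import MistralSkipConnectionForCausalLM  # assumed path

# "mistral-skip" matches the model_type in the new Mistral config above.
AutoConfig.register("mistral-skip", MistralSkipConnectionConfig)
AutoModelForCausalLM.register(MistralSkipConnectionConfig, MistralSkipConnectionForCausalLM)
```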

src/models/llama/modelling_llama_skip.py

Lines changed: 0 additions & 3 deletions
@@ -37,9 +37,6 @@
 )
 
 from src.models.llama.configuration_llama_skip import LlamaSkipConnectionConfig
-from src.modeling_utils import (
-    FastLoRAProjection, BaseModelOutputWithPastAndPredictorLoss
-)
 from src.modeling_skip import SkipMLP, SkipDecoderLayer, build_skip_connection_model, build_skip_connection_model_for_causal_lm
 
 logger = logging.get_logger(__name__)
