
Commit 491a7cc

vkkhare and Varun Khare authored
Training predictors for Llama 3.2 3B (#47)
* add pos_weight balancing for loss
* move dataset collection to activation fn and take full samples
* increase dataset generation parallelization
* correct metric averaging
* remove max_batch filter for validation and reduce val size
* remove redundant capture hooks
* remove redundant sigmoid
* update sparsity measurement with activation capture syntax
* correct sparsity plots to percentage
* add full sparsity values in json dump
* clean sparsity measurement in benchmarks
* add hyper parameter tuning
* add parallel training for predictors with grid search

---------

Signed-off-by: Varun Khare <varun.khare@niimbleedgehq.ai>
Signed-off-by: Varun Khare <varun.khare@nimbleedgehq.ai>
Co-authored-by: Varun Khare <varun.khare@niimbleedgehq.ai>
1 parent 50692df commit 491a7cc
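
The first bullet ("add pos_weight balancing for loss") and "remove redundant sigmoid" together suggest the predictors are trained as per-neuron binary classifiers with a class-imbalanced target. Below is a minimal sketch of what that balancing typically looks like with PyTorch's `BCEWithLogitsLoss`; the function name, shapes, and sparsity ratio are illustrative assumptions, not code from this repo.

```python
import torch
import torch.nn as nn

# Hypothetical sketch: balance a binary activation-prediction loss when
# "active" neurons are much rarer than inactive ones. Not the repo's code.
def make_balanced_bce(labels: torch.Tensor) -> nn.BCEWithLogitsLoss:
    """labels: (num_samples, num_neurons) float tensor of 0/1 targets."""
    num_pos = labels.sum()
    num_neg = labels.numel() - num_pos
    # pos_weight > 1 up-weights the rarer positive class; the clamp guards
    # against division by zero if a split happens to contain no positives.
    pos_weight = num_neg / num_pos.clamp(min=1.0)
    # BCEWithLogitsLoss fuses the sigmoid into the loss for numerical
    # stability, which is also why an explicit sigmoid before it would be
    # redundant (cf. the "remove redundant sigmoid" bullet).
    return nn.BCEWithLogitsLoss(pos_weight=pos_weight)

# Illustrative usage with ~20% active neurons:
labels = (torch.rand(4096, 8192) < 0.2).float()
criterion = make_balanced_bce(labels)
logits = torch.randn(4096, 8192)
loss = criterion(logits, labels)
```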

20 files changed: +1618 −932 lines

.gitignore — 2 additions, 0 deletions

@@ -23,6 +23,8 @@ wheels/
 .cursorrules
 trained_predictors/
 wandb/
+data/
+logs/
 # CUDA
 *.i
 *.ii

benchmarks/llama1b/summary.json — 0 additions, 34 deletions (file deleted)

benchmarks/llama3b/summary.json — 0 additions, 46 deletions (file deleted)

benchmarks/sparsity.json — 290 additions, 0 deletions (new file)

{
  "gate": {
    "Llama-3.2-3b-instruct": [
      57.71754124164581,
      65.11172428131104,
      69.67459416389465,
      71.62409243583679,
      78.40418763160706,
      82.82658765316009,
      80.40621650218964,
      79.38426294326783,
      79.12274203300476,
      78.16661398410797,
      78.77435421943665,
      75.91446607112884,
      78.38958513736725,
      79.7872543811798,
      82.00644743442535,
      82.67491750717163,
      81.88930082321167,
      80.33126723766327,
      78.93885064125061,
      79.20614948272706,
      76.78508996963501,
      74.55349669456481,
      74.1950053691864,
      75.14114892482758,
      75.33072142601013,
      72.14878454208375,
      68.71426742076873,
      53.61509549617767
    ],
    "Qwen2-1.5b": [
      75.93807740211487,
      78.41059468984604,
      77.49380288124084,
      76.96765702962875,
      78.98415230512619,
      86.08542959094048,
      88.52588576078415,
      83.89434084892272,
      84.28610481023789,
      83.33846287727356,
      80.04174522161483,
      80.4451719880104,
      81.62335629463196,
      78.81875599622727,
      77.72430129647255,
      78.92266703248023,
      78.07974159121514,
      80.27849124073983,
      81.47149959802627,
      82.16273649334907,
      83.16114392280579,
      88.331414437294,
      91.13149715065956,
      89.35581904649734,
      88.99205344319344,
      84.47382251024246,
      77.68900873661042,
      87.42764226794243
    ],
    "Llama-3.2-1b-instruct": [
      59.69295618534088,
      65.60881190299988,
      68.45192451477051,
      73.27476217746735,
      77.34748628139496,
      78.48554470539094,
      76.55607290267945,
      76.5926207780838,
      78.33601603507995,
      78.57956645488738,
      75.88382995128632,
      71.21690571308136,
      62.98756067752838,
      57.79038770198822,
      59.951990723609924,
      48.912491416931154
    ],
    "Deepseek-r1-distill-qwen-1.5b": [
      70.90480007529258,
      99.88893811702728,
      99.62486154437065,
      98.0037947654724,
      88.75373544692994,
      99.32015951871873,
      83.04518195986748,
      82.57140514850616,
      83.40730249881744,
      80.53620984554291,
      80.21069329977036,
      81.77259366512298,
      79.94535800814629,
      78.14672967195511,
      76.97014119625092,
      77.28835244774818,
      80.44708899259567,
      79.18341723680496,
      76.43398404121399,
      76.02891492843628,
      80.82670731544495,
      84.99062582850456,
      87.17172874212265,
      89.40225526690483,
      88.23681662678719,
      85.66482260227204,
      82.24173642396927,
      88.99977133274078
    ],
    "Mistral-7b-instruct-v0.3": [
      56.43942987620831,
      54.710825201869014,
      62.7378715634346,
      66.16027640700341,
      70.28030137121678,
      73.51264708936215,
      74.50522061884404,
      77.24974319338799,
      76.7814871430397,
      75.24307530522347,
      74.88057094812393,
      74.35959818959236,
      75.53367558717727,
      74.40153402686119,
      75.85964927077293,
      77.37865363955498,
      80.11023232936859,
      79.6199797809124,
      80.4066299200058,
      81.47304484248161,
      81.74000000357628,
      83.81049889922141,
      82.68584475517272,
      83.27997506260871,
      84.72612695097924,
      84.80276737213134,
      83.67284808158874,
      83.53098798394203,
      82.97702146172523,
      79.14664803743362,
      74.10781040191651,
      67.48937454819679
    ]
  },
  "up": {
    "Llama-3.2-3b-instruct": [
      49.94701652526855,
      50.25883557796478,
      49.861748695373535,
      50.212551021575926,
      50.13673641681671,
      50.268478298187254,
      50.23670208454132,
      49.97434947490692,
      49.92012085914612,
      49.747323226928714,
      50.086516761779784,
      50.03763332366943,
      49.882375383377074,
      49.97260744571686,
      50.26340198516846,
      50.02290711402893,
      50.01335778236389,
      49.795312118530276,
      50.07261395454407,
      50.176619386672975,
      49.79739108085632,
      50.07133948802948,
      49.866306567192076,
      50.10840601921082,
      50.19871940612793,
      50.12352337837219,
      49.68174865245819,
      49.686979389190675
    ],
    "Qwen2-1.5b": [
      49.966607823967934,
      50.07665805220604,
      49.93570843935013,
      49.86715441942215,
      50.08377824127674,
      50.042170682549475,
      50.093804562091826,
      49.93413372039795,
      49.97023822367191,
      50.044850525259974,
      49.94358977377415,
      50.04388883113861,
      50.019992855191234,
      49.94634571969509,
      49.9959951788187,
      49.87701366841793,
      50.004302775859834,
      49.91442177593708,
      50.008744248747824,
      49.8828303784132,
      50.030293348431584,
      50.190612328052524,
      50.29008415937424,
      49.93244623243809,
      49.84627003967762,
      50.10836774408817,
      50.078974387049676,
      49.67422745227814
    ],
    "Llama-3.2-1b-instruct": [
      50.28850176334381,
      50.08826148509979,
      50.095251607894895,
      50.172497272491455,
      50.13044228553772,
      49.83322007656098,
      50.11254615783692,
      49.992566442489625,
      49.95180006027222,
      49.9925139427185,
      49.866587257385255,
      49.80795896053314,
      50.02192153930664,
      49.8885840177536,
      49.555590176582335,
      49.536004185676575
    ],
    "Deepseek-r1-distill-qwen-1.5b": [
      49.97109459936619,
      49.79854706823826,
      49.94962645471096,
      50.32353394627571,
      49.95378584265709,
      50.01829433739185,
      49.94671969115734,
      49.712460374832155,
      49.80023949444294,
      50.004930776357654,
      50.27471421062946,
      49.87664130330086,
      50.02333268523216,
      50.15037988126278,
      49.89214723408222,
      50.09376830756664,
      49.73434842824936,
      49.89638189673424,
      50.114672395586965,
      50.29278555512428,
      50.12480843365192,
      49.612388187646864,
      49.845171093940735,
      49.835383167862894,
      49.97257161438465,
      49.602730265259744,
      49.772731468081474,
      49.93913755714893
    ],
    "Mistral-7b-instruct-v0.3": [
      49.95004552304745,
      49.97675357460976,
      49.856904023885726,
      49.98892204463482,
      50.04397348165512,
      49.949986496567725,
      49.971290555596354,
      49.94282107949257,
      49.918927201628684,
      50.167772099375725,
      50.05580841600895,
      50.072830599546435,
      50.14451886415482,
      49.94904280006885,
      49.99744117259979,
      49.98993978500366,
      50.03798618912697,
      49.967672461271285,
      49.7877659201622,
      50.13071630299091,
      50.073597100377086,
      50.017177081108095,
      49.89373693466187,
      50.065289315581325,
      50.132445031404494,
      49.73462550640106,
      50.15417894423008,
      49.84468738734722,
      49.918292918801306,
      50.050922787189485,
      50.356959053874014,
      50.07077682316303
    ]
  }
}
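
Each array holds one value per transformer layer (28 for Llama-3.2-3B and Qwen2-1.5B, 16 for Llama-3.2-1B, 32 for Mistral-7B), and the dump shows gate-projection sparsity ranging roughly 49–99% while up-projection sparsity stays near 50% everywhere. A minimal sketch for inspecting the file, assuming the `benchmarks/sparsity.json` layout above (the plotting choices are illustrative, not the repo's benchmark script):

```python
import json
import matplotlib.pyplot as plt

# Sketch: plot per-layer sparsity for the gate and up projections,
# reading the benchmarks/sparsity.json layout shown above.
with open("benchmarks/sparsity.json") as f:
    sparsity = json.load(f)

fig, axes = plt.subplots(1, 2, figsize=(12, 4), sharey=True)
for ax, proj in zip(axes, ("gate", "up")):
    for model, per_layer in sparsity[proj].items():
        # x-axis is the layer index; values are percentages.
        ax.plot(range(len(per_layer)), per_layer, label=model)
    ax.set_title(f"{proj} projection")
    ax.set_xlabel("layer")
axes[0].set_ylabel("sparsity (%)")
axes[0].legend(fontsize=8)
plt.tight_layout()
plt.savefig("sparsity_per_layer.png", dpi=150)
```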
